"""
Minimal DPO (Direct Preference Optimization) Implementation for Ranking
Across a distribution of candidates, when a user picks one, we learn to rank that one more often
Key components:
1. PolicyState: Input features for each candidate item
2. PolicyHead: Neural network that scores candidates
3. PreferenceSample: Records user preference (A preferred over B)
4. DPOTrainer: Implements the DPO loss to update the policy
Based on: https://arxiv.org/abs/2305.18290
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class PolicyState:
"""Input features for ranking candidates.
In this example:
- z_img: image embedding (single vector)
- z_locs: location embeddings for each candidate (N vectors)
- sims: similarity scores between image and each location
- base_probs: initial ranking scores from a base model
"""
z_img: torch.Tensor # shape: (embedding_dim,)
z_locs: torch.Tensor # shape: (num_candidates, embedding_dim)
sims: torch.Tensor # shape: (num_candidates,)
base_probs: torch.Tensor # shape: (num_candidates,)
class PolicyHead(nn.Module):
"""Neural network that scores each candidate item.
Input: concatenation of [image_embedding, location_embedding, similarity, base_prob]
Output: scalar score for each candidate
"""
def __init__(
self,
embedding_dim: int = 512,
hidden_dim: int = 256,
) -> None:
super().__init__()
# Input = image_emb + location_emb + similarity + base_prob
input_dim = embedding_dim * 2 + 2
self.rank_mlp = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1), # Output single score per candidate
)
def forward(self, state: PolicyState) -> torch.Tensor:
"""Score all candidates given the policy state.
Returns: shape (num_candidates,) - one score per candidate
"""
z_img = state.z_img
if z_img.ndim == 1:
z_img = z_img.unsqueeze(0) # (1, embedding_dim)
z_img_expanded = z_img.expand_as(state.z_locs) # (num_candidates, embedding_dim)
# Concatenate all features for each candidate
rank_input = torch.cat(
[
z_img_expanded, # image embedding repeated
state.z_locs, # location embedding
state.sims.unsqueeze(1), # similarity score
state.base_probs.unsqueeze(1), # base model probability
],
dim=1,
) # shape: (num_candidates, input_dim)
return self.rank_mlp(rank_input).squeeze(-1) # (num_candidates,)
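# A minimal shape-check sketch for PolicyHead (the small dimensions below are arbitrary
# illustrative choices, unrelated to the defaults used elsewhere in this file):
def _policy_head_shape_example() -> torch.Tensor:
    head = PolicyHead(embedding_dim=8, hidden_dim=16)
    state = PolicyState(
        z_img=torch.randn(8),  # one image embedding
        z_locs=torch.randn(5, 8),  # five candidate embeddings
        sims=torch.randn(5),  # one similarity per candidate
        base_probs=torch.softmax(torch.randn(5), dim=0),  # base ranking probabilities
    )
    scores = head(state)  # shape (5,): one scalar score per candidate
    assert scores.shape == (5,)
    return scores
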
@dataclass
class PreferenceSample:
"""Records a human preference: some items were preferred over others.
Example: User clicked on candidate 2 when shown [0,1,2,3,4] ranked by policy.
This means candidate 2 is preferred over candidates 3 and 4 (ranked below it).
pos_indices: indices of preferred candidates
neg_indices: indices of dispreferred candidates
reference_logits: scores from the frozen reference policy (for DPO loss)
"""
state: PolicyState
pos_indices: torch.Tensor # shape: (num_pairs,) - preferred items
neg_indices: torch.Tensor # shape: (num_pairs,) - dispreferred items
reference_logits: torch.Tensor # shape: (num_candidates,) - frozen reference scores
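# Continuing the docstring example above: the policy showed candidates in ranked order
# [0, 1, 2, 3, 4] and the user clicked candidate 2, so candidate 2 beats the two items
# ranked below it. The resulting pair tensors would look like:
#
#     pos_indices = torch.tensor([2, 2])  # the clicked candidate, repeated once per pair
#     neg_indices = torch.tensor([3, 4])  # the candidates ranked below the click
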
class DPOTrainer:
"""Direct Preference Optimization trainer for ranking policy.
DPO Key Idea:
- Maintain two models: policy (trainable) and reference (frozen)
- When user prefers A over B:
* Increase policy score gap: score(A) - score(B)
* But penalize deviating too far from reference policy
- This prevents overfitting to individual preferences
Math:
loss = -log_sigmoid(beta * [(policy(A) - policy(B)) - (ref(A) - ref(B))])
Args:
beta: Controls how much we allow policy to deviate from reference
Higher beta = stay closer to reference (more conservative)
"""
def __init__(
self,
policy_head: PolicyHead,
reference_head: PolicyHead,
optimizer: torch.optim.Optimizer,
device: torch.device,
beta: float = 1.0,
grad_clip: Optional[float] = None,
) -> None:
self.policy_head = policy_head
self.reference_head = reference_head
self.optimizer = optimizer
self.device = device
self.beta = beta
self.grad_clip = grad_clip
# Move models to device
self.policy_head.to(device)
self.reference_head.to(device)
# Freeze reference policy - it should never be updated during DPO
for param in self.reference_head.parameters():
param.requires_grad_(False)
self.reference_head.eval()
# Buffer stores preference samples until we do a training update
self.buffer: List[PreferenceSample] = []
def store_preference(self, sample: PreferenceSample) -> None:
"""Add a preference sample to the buffer."""
self.buffer.append(sample)
    def update(self) -> Dict[str, float]:
        """Perform one DPO training step on all buffered preferences.

        For each preference pair (preferred_item, dispreferred_item):
        1. Compute current policy scores for both items
        2. Compute reference policy scores for both items (frozen)
        3. Calculate how much the policy changed relative to the reference:
           delta = [policy(preferred) - policy(dispreferred)] - [ref(preferred) - ref(dispreferred)]
        4. The DPO loss encourages delta > 0 (prefer the preferred item MORE than the reference does):
           loss = -log_sigmoid(beta * delta)

        Returns:
            Dictionary with training metrics (loss, grad_norm, etc.)
        """
        if not self.buffer:
            return {"buffer_size": 0, "pairs": 0, "loss": 0.0, "grad_norm": 0.0}

        buffer_size = len(self.buffer)
        self.policy_head.train()
        self.optimizer.zero_grad(set_to_none=True)

        losses: List[torch.Tensor] = []
        total_pairs = 0

        # Process each preference sample in the buffer
        for sample in self.buffer:
            pos_indices = sample.pos_indices.to(self.device)
            neg_indices = sample.neg_indices.to(self.device)
            if pos_indices.numel() == 0 or neg_indices.numel() == 0:
                continue

            # Get current policy scores
            state = PolicyState(
                z_img=sample.state.z_img.to(self.device),
                z_locs=sample.state.z_locs.to(self.device),
                sims=sample.state.sims.to(self.device),
                base_probs=sample.state.base_probs.to(self.device),
            )
            current_logits = self.policy_head(state)  # Current policy scores

            # Get reference policy scores (frozen)
            reference_logits = sample.reference_logits.to(self.device)

            # Extract scores for preferred and dispreferred items
            pos_scores = current_logits.gather(0, pos_indices)  # Policy scores for preferred items
            neg_scores = current_logits.gather(0, neg_indices)  # Policy scores for dispreferred items
            ref_pos = reference_logits.gather(0, pos_indices)  # Reference scores for preferred items
            ref_neg = reference_logits.gather(0, neg_indices)  # Reference scores for dispreferred items

            # DPO loss: encourage the policy to prefer preferred items MORE than the reference does
            # delta > 0 means the policy ranks the preferred item higher than the reference does
            delta = (pos_scores - neg_scores) - (ref_pos - ref_neg)
            loss = -F.logsigmoid(self.beta * delta)
            losses.append(loss)
            total_pairs += int(loss.numel())

        if not losses:
            self.buffer.clear()
            return {"buffer_size": buffer_size, "pairs": 0, "loss": 0.0, "grad_norm": 0.0}

        # Backprop and optimize
        total_loss = torch.cat(losses).mean()
        total_loss.backward()

        # Optional gradient clipping
        if self.grad_clip is not None:
            grad_norm_tensor = torch.nn.utils.clip_grad_norm_(
                self.policy_head.parameters(), self.grad_clip
            )
            grad_norm = float(grad_norm_tensor.detach().cpu().item())
        else:
            # Compute gradient norm for logging
            grad_sq_sum = 0.0
            for parameter in self.policy_head.parameters():
                if parameter.grad is None:
                    continue
                grad_sq_sum += parameter.grad.detach().pow(2).sum().item()
            grad_norm = grad_sq_sum**0.5 if grad_sq_sum > 0.0 else 0.0

        self.optimizer.step()
        self.optimizer.zero_grad(set_to_none=True)
        self.policy_head.eval()
        self.buffer.clear()

        return {
            "buffer_size": buffer_size,
            "pairs": total_pairs,
            "loss": float(total_loss.detach().cpu().item()),
            "grad_norm": float(grad_norm),
        }

    def refresh_reference(self) -> None:
        """Periodically sync the reference policy with the current policy.

        In practice, you might do this every N updates to keep the reference
        from getting too stale. Refreshing re-anchors the penalty to the current
        policy, so later updates are constrained relative to recent behavior
        rather than the original initialization.
        """
        self.reference_head.load_state_dict(self.policy_head.state_dict())
        self.reference_head.to(self.device)
        self.reference_head.eval()
        for param in self.reference_head.parameters():
            param.requires_grad_(False)

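# A quick numeric sanity check of the DPO loss above (beta = 1.0; values rounded):
#
#     delta =  0.0  ->  loss = -log_sigmoid( 0.0) = ln(2)          ~= 0.693  (policy agrees with reference)
#     delta =  1.5  ->  loss = -log_sigmoid( 1.5) = ln(1 + e^-1.5) ~= 0.201  (policy widened the preferred gap)
#     delta = -1.5  ->  loss = -log_sigmoid(-1.5) = ln(1 + e^1.5)  ~= 1.701  (policy narrowed it; heavily penalized)
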
# ============================================================================
# Example Usage
# ============================================================================
def example_usage():
    """Demonstrate how to use the DPO trainer for ranking.

    Scenario: We have 10 candidate locations for an image.
    The user clicks on candidate 2, which is ranked 5th by our current policy.
    This creates preference pairs: (2 > 5), (2 > 6), (2 > 7), (2 > 8), (2 > 9),
    meaning candidate 2 should be ranked higher than all items below it.
    """
    # Initialize models
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    policy = PolicyHead(embedding_dim=512, hidden_dim=256)
    reference = PolicyHead(embedding_dim=512, hidden_dim=256)
    reference.load_state_dict(policy.state_dict())  # Start with the same weights
    optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)

    trainer = DPOTrainer(
        policy_head=policy,
        reference_head=reference,
        optimizer=optimizer,
        device=device,
        beta=1.0,  # Controls KL penalty
        grad_clip=1.0,
    )

    # Create dummy data for one ranking scenario
    num_candidates = 10
    embedding_dim = 512
    state = PolicyState(
        z_img=torch.randn(embedding_dim),  # Image embedding
        z_locs=torch.randn(num_candidates, embedding_dim),  # Location embeddings
        sims=torch.randn(num_candidates),  # Similarities
        base_probs=torch.softmax(torch.randn(num_candidates), dim=0),  # Base probabilities
    )

    # Get reference scores (frozen)
    with torch.no_grad():
        state_on_device = PolicyState(
            z_img=state.z_img.to(device),
            z_locs=state.z_locs.to(device),
            sims=state.sims.to(device),
            base_probs=state.base_probs.to(device),
        )
        reference_logits = reference(state_on_device).cpu()

    # Simulate a user preference: clicked candidate 2 when it was ranked 5th
    # This means candidate 2 is preferred over the candidates ranked below it
    clicked_index = 2
    current_ranking = torch.tensor([0, 1, 3, 4, 2, 5, 6, 7, 8, 9])  # Current rank order
    position = (current_ranking == clicked_index).nonzero(as_tuple=True)[0].item()

    # Create preference pairs: clicked item vs. all items ranked below it
    dispreferred_indices = current_ranking[position + 1 :]  # Items ranked worse
    preferred_indices = torch.full_like(dispreferred_indices, fill_value=clicked_index)

    sample = PreferenceSample(
        state=state,
        pos_indices=preferred_indices,  # [2, 2, 2, 2, 2]
        neg_indices=dispreferred_indices,  # [5, 6, 7, 8, 9]
        reference_logits=reference_logits,
    )

    # Store the preference and update
    trainer.store_preference(sample)
    metrics = trainer.update()

    print(f"Training metrics: {metrics}")
    print(f"Loss: {metrics['loss']:.4f}")
    print(f"Pairs trained: {metrics['pairs']}")

    # Optionally refresh the reference policy after N updates
    # trainer.refresh_reference()


if __name__ == "__main__":
    example_usage()
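

# A minimal sketch of a longer run, assuming some outside code collects PreferenceSample
# objects over time; the batching and refresh intervals below are illustrative choices,
# not values prescribed by this gist:
def _training_loop_sketch(trainer: DPOTrainer, samples: List[PreferenceSample]) -> None:
    refresh_every = 100  # assumed interval; tune for your setting
    for step, sample in enumerate(samples, start=1):
        trainer.store_preference(sample)
        if step % 10 == 0:  # run a DPO update on each small batch of feedback
            trainer.update()
        if step % refresh_every == 0:  # periodically re-anchor the reference policy
            trainer.refresh_reference()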