| """ | |
| Minimal DPO (Direct Preference Optimization) Implementation for Ranking | |
| Across a distribution of candidates, when a user picks one, we learn to rank that one more often | |
| Key components: | |
| 1. PolicyState: Input features for each candidate item | |
| 2. PolicyHead: Neural network that scores candidates | |
| 3. PreferenceSample: Records user preference (A preferred over B) | |
| 4. DPOTrainer: Implements the DPO loss to update the policy | |
| Based on: https://arxiv.org/abs/2305.18290 | |
| """ | |

from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class PolicyState:
    """Input features for ranking candidates.

    In this example:
    - z_img: image embedding (single vector)
    - z_locs: location embeddings for each candidate (N vectors)
    - sims: similarity scores between image and each location
    - base_probs: initial ranking scores from a base model
    """

    z_img: torch.Tensor  # shape: (embedding_dim,)
    z_locs: torch.Tensor  # shape: (num_candidates, embedding_dim)
    sims: torch.Tensor  # shape: (num_candidates,)
    base_probs: torch.Tensor  # shape: (num_candidates,)


class PolicyHead(nn.Module):
    """Neural network that scores each candidate item.

    Input: concatenation of [image_embedding, location_embedding, similarity, base_prob]
    Output: scalar score for each candidate
    """

    def __init__(
        self,
        embedding_dim: int = 512,
        hidden_dim: int = 256,
    ) -> None:
        super().__init__()
        # Input = image_emb + location_emb + similarity + base_prob
        input_dim = embedding_dim * 2 + 2
        self.rank_mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),  # Output single score per candidate
        )

    def forward(self, state: PolicyState) -> torch.Tensor:
        """Score all candidates given the policy state.

        Returns: shape (num_candidates,) - one score per candidate
        """
        z_img = state.z_img
        if z_img.ndim == 1:
            z_img = z_img.unsqueeze(0)  # (1, embedding_dim)
        z_img_expanded = z_img.expand_as(state.z_locs)  # (num_candidates, embedding_dim)
        # Concatenate all features for each candidate
        rank_input = torch.cat(
            [
                z_img_expanded,  # image embedding repeated
                state.z_locs,  # location embedding
                state.sims.unsqueeze(1),  # similarity score
                state.base_probs.unsqueeze(1),  # base model probability
            ],
            dim=1,
        )  # shape: (num_candidates, input_dim)
        return self.rank_mlp(rank_input).squeeze(-1)  # (num_candidates,)
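

# A small helper (not part of the original gist) sketching how PolicyHead scores
# become the rank order referenced further down (e.g. the `current_ranking` built
# by hand in example_usage). It assumes a higher score means a better rank.
def rank_candidates(policy_head: PolicyHead, state: PolicyState) -> torch.Tensor:
    """Return candidate indices sorted from best to worst policy score."""
    with torch.no_grad():
        scores = policy_head(state)  # (num_candidates,)
    return torch.argsort(scores, descending=True)  # best candidate first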


@dataclass
class PreferenceSample:
    """Records a human preference: some items were preferred over others.

    Example: User clicked on candidate 2 when shown [0, 1, 2, 3, 4] ranked by policy.
    This means candidate 2 is preferred over candidates 3 and 4 (ranked below it).

    pos_indices: indices of preferred candidates
    neg_indices: indices of dispreferred candidates
    reference_logits: scores from the frozen reference policy (for DPO loss)
    """

    state: PolicyState
    pos_indices: torch.Tensor  # shape: (num_pairs,) - preferred items
    neg_indices: torch.Tensor  # shape: (num_pairs,) - dispreferred items
    reference_logits: torch.Tensor  # shape: (num_candidates,) - frozen reference scores
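

# A hedged sketch (not in the original gist) of how a click becomes the
# (pos_indices, neg_indices) pairs described above: the clicked candidate is
# preferred over every candidate ranked below it. `ranking` is assumed to list
# candidate indices from best to worst under the current policy.
def pairs_from_click(ranking: torch.Tensor, clicked_index: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (pos_indices, neg_indices) pairing the clicked item against all lower-ranked items."""
    position = (ranking == clicked_index).nonzero(as_tuple=True)[0].item()
    neg_indices = ranking[position + 1:]  # everything ranked below the clicked item
    pos_indices = torch.full_like(neg_indices, clicked_index)
    return pos_indices, neg_indices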


class DPOTrainer:
    """Direct Preference Optimization trainer for ranking policy.

    DPO Key Idea:
    - Maintain two models: policy (trainable) and reference (frozen)
    - When user prefers A over B:
        * Increase policy score gap: score(A) - score(B)
        * But penalize deviating too far from reference policy
    - This prevents overfitting to individual preferences

    Math:
        loss = -log_sigmoid(beta * [(policy(A) - policy(B)) - (ref(A) - ref(B))])

    Args:
        beta: Controls how much we allow policy to deviate from reference
            Higher beta = stay closer to reference (more conservative)
    """

    def __init__(
        self,
        policy_head: PolicyHead,
        reference_head: PolicyHead,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        beta: float = 1.0,
        grad_clip: Optional[float] = None,
    ) -> None:
        self.policy_head = policy_head
        self.reference_head = reference_head
        self.optimizer = optimizer
        self.device = device
        self.beta = beta
        self.grad_clip = grad_clip
        # Move models to device
        self.policy_head.to(device)
        self.reference_head.to(device)
        # Freeze reference policy - it should never be updated during DPO
        for param in self.reference_head.parameters():
            param.requires_grad_(False)
        self.reference_head.eval()
        # Buffer stores preference samples until we do a training update
        self.buffer: List[PreferenceSample] = []

    def store_preference(self, sample: PreferenceSample) -> None:
        """Add a preference sample to the buffer."""
        self.buffer.append(sample)

    def update(self) -> Dict[str, float]:
        """Perform one DPO training step on all buffered preferences.

        For each preference pair (preferred_item, dispreferred_item):
        1. Compute current policy scores for both items
        2. Compute reference policy scores for both items (frozen)
        3. Calculate how much policy changed relative to reference:
            delta = [policy(preferred) - policy(dispreferred)] - [ref(preferred) - ref(dispreferred)]
        4. DPO loss encourages delta > 0 (prefer preferred item MORE than reference does):
            loss = -log_sigmoid(beta * delta)

        Returns:
            Dictionary with training metrics (loss, grad_norm, etc.)
        """
        if not self.buffer:
            return {"buffer_size": 0, "pairs": 0, "loss": 0.0, "grad_norm": 0.0}
        buffer_size = len(self.buffer)
        self.policy_head.train()
        self.optimizer.zero_grad(set_to_none=True)
        losses: List[torch.Tensor] = []
        total_pairs = 0
        # Process each preference sample in the buffer
        for sample in self.buffer:
            pos_indices = sample.pos_indices.to(self.device)
            neg_indices = sample.neg_indices.to(self.device)
            if pos_indices.numel() == 0 or neg_indices.numel() == 0:
                continue
            # Get current policy scores
            state = PolicyState(
                z_img=sample.state.z_img.to(self.device),
                z_locs=sample.state.z_locs.to(self.device),
                sims=sample.state.sims.to(self.device),
                base_probs=sample.state.base_probs.to(self.device),
            )
            current_logits = self.policy_head(state)  # Current policy scores
            # Get reference policy scores (frozen)
            reference_logits = sample.reference_logits.to(self.device)
            # Extract scores for preferred and dispreferred items
            pos_scores = current_logits.gather(0, pos_indices)  # Policy scores for preferred items
            neg_scores = current_logits.gather(0, neg_indices)  # Policy scores for dispreferred items
            ref_pos = reference_logits.gather(0, pos_indices)  # Reference scores for preferred items
            ref_neg = reference_logits.gather(0, neg_indices)  # Reference scores for dispreferred items
            # DPO Loss: encourage policy to prefer preferred items MORE than reference does
            # delta > 0 means policy ranks preferred item higher than reference does
            delta = (pos_scores - neg_scores) - (ref_pos - ref_neg)
            loss = -F.logsigmoid(self.beta * delta)
            losses.append(loss)
            total_pairs += int(loss.numel())
        if not losses:
            self.buffer.clear()
            return {"buffer_size": buffer_size, "pairs": 0, "loss": 0.0, "grad_norm": 0.0}
        # Backprop and optimize
        total_loss = torch.cat(losses).mean()
        total_loss.backward()
        # Optional gradient clipping
        if self.grad_clip is not None:
            grad_norm_tensor = torch.nn.utils.clip_grad_norm_(
                self.policy_head.parameters(), self.grad_clip
            )
            grad_norm = float(grad_norm_tensor.detach().cpu().item())
        else:
            # Compute gradient norm for logging
            grad_sq_sum = 0.0
            for parameter in self.policy_head.parameters():
                if parameter.grad is None:
                    continue
                grad_sq_sum += parameter.grad.detach().pow(2).sum().item()
            grad_norm = grad_sq_sum**0.5 if grad_sq_sum > 0.0 else 0.0
        self.optimizer.step()
        self.optimizer.zero_grad(set_to_none=True)
        self.policy_head.eval()
        self.buffer.clear()
        return {
            "buffer_size": buffer_size,
            "pairs": total_pairs,
            "loss": float(total_loss.detach().cpu().item()),
            "grad_norm": float(grad_norm),
        }

    def refresh_reference(self) -> None:
        """Periodically sync reference policy with current policy.

        In practice, you might do this every N updates to keep the reference
        from getting too stale, so the DPO penalty anchors the policy to a
        recent snapshot rather than to an increasingly outdated one.
        """
        self.reference_head.load_state_dict(self.policy_head.state_dict())
        self.reference_head.to(self.device)
        self.reference_head.eval()
        for param in self.reference_head.parameters():
            param.requires_grad_(False)
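

# A hedged sketch (not from the original gist) of how refresh_reference might be
# used over a longer run: sync the frozen reference to the current policy every
# `refresh_every` updates, as suggested in the docstring above. `collect_preferences`
# is a hypothetical callable standing in for whatever produces PreferenceSamples.
def training_loop_sketch(trainer: DPOTrainer, collect_preferences, num_updates: int = 100, refresh_every: int = 10) -> None:
    for step in range(num_updates):
        for sample in collect_preferences():  # hypothetical source of PreferenceSamples
            trainer.store_preference(sample)
        metrics = trainer.update()  # one DPO step over the buffered preferences
        if (step + 1) % refresh_every == 0:
            trainer.refresh_reference()  # keep the reference from getting too stale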


# ============================================================================
# Example Usage
# ============================================================================


def example_usage():
    """Demonstrate how to use the DPO trainer for ranking.

    Scenario: We have 10 candidate locations for an image.
    User clicks on candidate 2, which is ranked 5th by our current policy.
    This creates preference pairs: (2 > 5), (2 > 6), (2 > 7), (2 > 8), (2 > 9)
    meaning candidate 2 should be ranked higher than all items below it.
    """
    # Initialize models
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    policy = PolicyHead(embedding_dim=512, hidden_dim=256)
    reference = PolicyHead(embedding_dim=512, hidden_dim=256)
    reference.load_state_dict(policy.state_dict())  # Start with same weights
    optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
    trainer = DPOTrainer(
        policy_head=policy,
        reference_head=reference,
        optimizer=optimizer,
        device=device,
        beta=1.0,  # Controls KL penalty
        grad_clip=1.0,
    )

    # Create dummy data for one ranking scenario
    num_candidates = 10
    embedding_dim = 512
    state = PolicyState(
        z_img=torch.randn(embedding_dim),  # Image embedding
        z_locs=torch.randn(num_candidates, embedding_dim),  # Location embeddings
        sims=torch.randn(num_candidates),  # Similarities
        base_probs=torch.softmax(torch.randn(num_candidates), dim=0),  # Base probabilities
    )

    # Get reference scores (frozen)
    with torch.no_grad():
        state_on_device = PolicyState(
            z_img=state.z_img.to(device),
            z_locs=state.z_locs.to(device),
            sims=state.sims.to(device),
            base_probs=state.base_probs.to(device),
        )
        reference_logits = reference(state_on_device).cpu()

    # Simulate user preference: clicked candidate 2 when it was ranked 5th
    # This means candidate 2 is preferred over candidates ranked 6-10
    clicked_index = 2
    current_ranking = torch.tensor([0, 1, 3, 4, 2, 5, 6, 7, 8, 9])  # Current rank order
    position = (current_ranking == clicked_index).nonzero(as_tuple=True)[0].item()

    # Create preference pairs: clicked item vs all items ranked below it
    dispreferred_indices = current_ranking[position + 1:]  # Items ranked worse
    preferred_indices = torch.full_like(dispreferred_indices, fill_value=clicked_index)

    sample = PreferenceSample(
        state=state,
        pos_indices=preferred_indices,  # [2, 2, 2, 2, 2]
        neg_indices=dispreferred_indices,  # [5, 6, 7, 8, 9]
        reference_logits=reference_logits,
    )

    # Store preference and update
    trainer.store_preference(sample)
    metrics = trainer.update()

    print(f"Training metrics: {metrics}")
    print(f"Loss: {metrics['loss']:.4f}")
    print(f"Pairs trained: {metrics['pairs']}")

    # Optionally refresh reference policy after N updates
    # trainer.refresh_reference()


if __name__ == "__main__":
    example_usage()