Эх сурвалжийг харах

Minor DPO fixes (#617)

* minor dpo fixes

* Update dpo-from-scratch.ipynb

metadata diff
casinca 7 сар өмнө
parent
commit
1b242d01a5

+ 4 - 5
ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb

@@ -1876,7 +1876,6 @@
     "        reference_chosen_logprobs: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)\n",
     "        reference_rejected_logprobs: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)\n",
     "        beta: Temperature parameter for the DPO loss; typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.\n",
-    "        label_smoothing: conservativeness for DPO loss.\n",
     "\n",
     "    Returns:\n",
     "        A tuple of three tensors: (loss, chosen_rewards, rejected_rewards).\n",
@@ -1998,7 +1997,7 @@
     "        selected_log_probs = selected_log_probs * mask\n",
     "\n",
     "        # Calculate the average log probability excluding padding tokens\n",
-    "        # This averages over the tokens, so the shape is (batch_size, num_tokens)\n",
+    "        # This averages over the tokens, so the shape is (batch_size,)\n",
     "        avg_log_prob = selected_log_probs.sum(-1) / mask.sum(-1)\n",
     "\n",
     "        return avg_log_prob\n",
@@ -2439,7 +2438,7 @@
     "    for epoch in range(num_epochs):\n",
     "        policy_model.train()  # Set model to training mode\n",
     "\n",
-    "        for batch_idx, batch in enumerate(train_loader):\n",
+    "        for batch in train_loader:\n",
     "\n",
     "            optimizer.zero_grad()  # Reset loss gradients from previous batch iteration\n",
     "\n",
@@ -3113,7 +3112,7 @@
    "provenance": []
   },
   "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
+   "display_name": ".venv",
    "language": "python",
    "name": "python3"
   },
@@ -3127,7 +3126,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.16"
+   "version": "3.12.6"
   }
  },
  "nbformat": 4,