@@ -1554,11 +1554,11 @@
"if not finetuned_model_path.exists():\n",
"\n",
" # Try finding the model checkpoint locally:\n",
- " relative_path = Path(\"..\") / \"ch07\" / finetuned_model_path\n",
+ " relative_path = Path(\"..\") / \"01_main-chapter-code\" / finetuned_model_path\n",
" if relative_path.exists():\n",
" shutil.copy(relative_path, \".\")\n",
"\n",
- " # If this notebook is run on Google Colab, get it from a Googe Drive folder\n",
+ " # If this notebook is run on Google Colab, get it from a Google Drive folder\n",
" elif \"COLAB_GPU\" in os.environ or \"COLAB_TPU_ADDR\" in os.environ:\n",
" from google.colab import drive\n",
" drive.mount(\"/content/drive\")\n",
@@ -1875,10 +1875,10 @@
"- Keeping this in mind, let's go through some of the steps (we will calculate the `logprobs` using a separate function later)\n",
"- Let's start with the lines\n",
"\n",
- "```python\n",
- "model_logratios = model_chosen_logprobs - model_rejected_logprobs\n",
- "reference_logratios = reference_chosen_logprobs - reference_rejected_logprobs\n",
- "```\n",
+ " ```python\n",
+ " model_logratios = model_chosen_logprobs - model_rejected_logprobs\n",
+ " reference_logratios = reference_chosen_logprobs - reference_rejected_logprobs\n",
+ " ```\n",
"\n",
"- These lines above calculate the difference in log probabilities (logits) for the chosen and rejected samples for both the policy model and the reference model (this is due to $\\log\\left(\\frac{a}{b}\\right) = \\log a - \\log b$):\n",
"\n",
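For intuition, here is a minimal sketch of how such log-ratios typically feed into the DPO loss, assuming tensor-valued log-probabilities of shape `(batch_size,)` and the usual `beta` scaling hyperparameter (the function name and default below are illustrative):

```python
import torch.nn.functional as F

def dpo_loss_from_logprobs(
    model_chosen_logprobs, model_rejected_logprobs,
    reference_chosen_logprobs, reference_rejected_logprobs,
    beta=0.1,
):
    # Log-ratio of chosen vs. rejected responses under each model
    # (uses log(a/b) = log a - log b)
    model_logratios = model_chosen_logprobs - model_rejected_logprobs
    reference_logratios = reference_chosen_logprobs - reference_rejected_logprobs

    # How much more the policy prefers the chosen response than the reference does
    logits = model_logratios - reference_logratios

    # DPO objective: minimize the negative log-sigmoid of the scaled difference
    return -F.logsigmoid(beta * logits).mean()
```

A larger `beta` keeps the policy closer to the reference model; values around 0.1 to 0.5 are commonly used.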
@@ -1936,7 +1936,7 @@
"\n",
" Args:\n",
" logits: Tensor of shape (batch_size, num_tokens, vocab_size)\n",
- " labels: Tensor of shape (batch_size, snum_tokens)\n",
+ " labels: Tensor of shape (batch_size, num_tokens)\n",
-    "        selection_mask: Tensor for shape (batch_size, num_tokens)\n",
+    "        selection_mask: Tensor of shape (batch_size, num_tokens)\n",
"\n",
" Returns:\n",
@@ -1981,7 +1981,7 @@
"id": "cf6a71ac-3fcc-44a4-befc-1c56bbd378d7"
},
"source": [
- "- Note that this function above might look a bit intimidating at first due to the `torch.gather` function, but it's pretty similar to what happens under the hood in PyTorch's `cross_entropy` function\n",
+ "- Note that this function above might look a bit intimidating at first due to the `torch.gather` function, but it's pretty similar to what happens under the hood in PyTorch's `cross_entropy` function\n",
"- For example, consider the following example:"
]
},
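To make that similarity concrete, here is a small toy example (hypothetical 3-token vocabulary, two token positions; not taken from the notebook) showing that gathering the label positions from log-softmaxed logits yields the same per-token log-probabilities that `cross_entropy` negates:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.1],
                       [0.5, 2.5, 0.3]])   # shape: (num_tokens, vocab_size)
labels = torch.tensor([0, 1])              # target token IDs

log_probs = F.log_softmax(logits, dim=-1)

# Select the log-probability assigned to each label token via torch.gather
selected = torch.gather(log_probs, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

# cross_entropy computes the same values, just negated (and usually averaged)
print(selected)
print(-F.cross_entropy(logits, labels, reduction="none"))
```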
@@ -2264,7 +2264,7 @@
"id": "852e4c09-d285-44d5-be12-d29769950cb6"
},
"source": [
- "- Why a specified `num_batches`? That's purely for efficiency reasons (because calculating the loss on the whole dataset each time would slow down the training significantly"
+ "- Why a specified `num_batches`? That's purely for efficiency reasons (because calculating the loss on the whole dataset each time would slow down the training significantly)"
]
},
{
@@ -2354,7 +2354,7 @@
"source": [
"- After setting up the DPO loss functions in the previous section, we can now finally train the model\n",
"- Note that this training function is the same one we used for pretraining and instruction finetuning, with minor differences:\n",
- " - we swap the cross entropy loss with our new DPO loss function\n",
+ " - we swap the cross-entropy loss with our new DPO loss function\n",
" - we also track the rewards and reward margins, which are commonly used in RLHF and DPO contexts to track the training progress\n"
]
},
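As a rough sketch of what such reward tracking commonly looks like, assuming tensor-valued log-probabilities (exact definitions vary between implementations, some additionally scale by `beta`, and the helper below is illustrative rather than the notebook's own code):

```python
def reward_stats(model_chosen_logprobs, model_rejected_logprobs,
                 reference_chosen_logprobs, reference_rejected_logprobs):
    # Implicit "reward": how far the policy's log-probability has moved away
    # from the reference model's log-probability for each response
    chosen_rewards = (model_chosen_logprobs - reference_chosen_logprobs).detach()
    rejected_rewards = (model_rejected_logprobs - reference_rejected_logprobs).detach()

    # The margin between chosen and rejected rewards should grow during training
    reward_margin = (chosen_rewards - rejected_rewards).mean()
    return chosen_rewards.mean(), rejected_rewards.mean(), reward_margin
```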
@@ -2394,7 +2394,7 @@
"\n",
" for batch_idx, batch in enumerate(train_loader):\n",
"\n",
- " optimizer.zero_grad() # Reset loss gradients from previous epoch\n",
+ " optimizer.zero_grad() # Reset loss gradients from previous batch iteration\n",
"\n",
" loss, chosen_rewards, rejected_rewards = compute_dpo_loss_batch(\n",
" batch=batch,\n",
@@ -3088,7 +3088,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.4"
+ "version": "3.11.9"
}
},
"nbformat": 4,