|
@@ -1844,13 +1844,32 @@
|
|
|
"source": [
|
|
"source": [
|
|
|
"new_logits = torch.where(\n",
|
|
"new_logits = torch.where(\n",
|
|
|
" condition=next_token_logits < top_logits[-1],\n",
|
|
" condition=next_token_logits < top_logits[-1],\n",
|
|
|
- " input=torch.tensor(float('-inf')), \n",
|
|
|
|
|
|
|
+ " input=torch.tensor(float(\"-inf\")), \n",
|
|
|
" other=next_token_logits\n",
|
|
" other=next_token_logits\n",
|
|
|
")\n",
|
|
")\n",
|
|
|
"\n",
|
|
"\n",
|
|
|
"print(new_logits)"
|
|
"print(new_logits)"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
|
+ "id": "dfa6fa49-6e99-459d-a517-d7d0f51c4f00",
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "> NOTE: \n",
|
|
|
|
|
+ ">\n",
|
|
|
|
|
+ "> An alternative, slightly more efficient implementation of the previous code cell is the following:\n",
|
|
|
|
|
+ ">\n",
|
|
|
|
|
+ "> ```python\n",
|
|
|
|
|
+ "> new_logits = torch.full_like( # create tensor containing -inf values\n",
|
|
|
|
|
+ "> next_token_logits, -torch.inf\n",
|
|
|
|
|
+ ">) \n",
|
|
|
|
|
+ "> new_logits[top_pos] = next_token_logits[top_pos] # copy top k values into the -inf tensor\n",
|
|
|
|
|
+ "> ```\n",
|
|
|
|
|
+ "> <br>\n",
|
|
|
|
|
+ "> For more details, see https://github.com/rasbt/LLMs-from-scratch/discussions/326\n"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
"execution_count": 39,
|
|
"execution_count": 39,
|
|
@@ -1908,7 +1927,7 @@
|
|
|
" # Keep only top_k values\n",
|
|
" # Keep only top_k values\n",
|
|
|
" top_logits, _ = torch.topk(logits, top_k)\n",
|
|
" top_logits, _ = torch.topk(logits, top_k)\n",
|
|
|
" min_val = top_logits[:, -1]\n",
|
|
" min_val = top_logits[:, -1]\n",
|
|
|
- " logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)\n",
|
|
|
|
|
|
|
+ " logits = torch.where(logits < min_val, torch.tensor(float(\"-inf\")).to(logits.device), logits)\n",
|
|
|
"\n",
|
|
"\n",
|
|
|
" # New: Apply temperature scaling\n",
|
|
" # New: Apply temperature scaling\n",
|
|
|
" if temperature > 0.0:\n",
|
|
" if temperature > 0.0:\n",
|
|
@@ -2485,7 +2504,7 @@
|
|
|
"name": "python",
|
|
"name": "python",
|
|
|
"nbconvert_exporter": "python",
|
|
"nbconvert_exporter": "python",
|
|
|
"pygments_lexer": "ipython3",
|
|
"pygments_lexer": "ipython3",
|
|
|
- "version": "3.11.4"
|
|
|
|
|
|
|
+ "version": "3.12.2"
|
|
|
}
|
|
}
|
|
|
},
|
|
},
|
|
|
"nbformat": 4,
|
|
"nbformat": 4,
|