
topk comment

Sebastian Raschka 1 year ago
commit c2117ca073
1 changed file with 22 additions and 3 deletions
  1. ch05/01_main-chapter-code/ch05.ipynb (+22 -3)

+ 22 - 3  ch05/01_main-chapter-code/ch05.ipynb

@@ -1844,13 +1844,32 @@
    "source": [
    "source": [
     "new_logits = torch.where(\n",
     "new_logits = torch.where(\n",
     "    condition=next_token_logits < top_logits[-1],\n",
     "    condition=next_token_logits < top_logits[-1],\n",
-    "    input=torch.tensor(float('-inf')), \n",
+    "    input=torch.tensor(float(\"-inf\")), \n",
     "    other=next_token_logits\n",
     "    other=next_token_logits\n",
     ")\n",
     ")\n",
     "\n",
     "\n",
     "print(new_logits)"
     "print(new_logits)"
    ]
    ]
   },
   },
+  {
+   "cell_type": "markdown",
+   "id": "dfa6fa49-6e99-459d-a517-d7d0f51c4f00",
+   "metadata": {},
+   "source": [
+    "> NOTE:  \n",
+    ">\n",
+    ">  An alternative, slightly more efficient implementation of the previous code cell is the following:\n",
+    ">\n",
+    "> ```python\n",
+    "> new_logits = torch.full_like( # create tensor containing -inf values\n",
+    ">    next_token_logits, -torch.inf\n",
+    ">)   \n",
+    "> new_logits[top_pos] = next_token_logits[top_pos] # copy top k values into the -inf tensor\n",
+    "> ```\n",
+    "> <br>\n",
+    "> For more details, see https://github.com/rasbt/LLMs-from-scratch/discussions/326\n"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 39,
@@ -1908,7 +1927,7 @@
     "            # Keep only top_k values\n",
     "            # Keep only top_k values\n",
     "            top_logits, _ = torch.topk(logits, top_k)\n",
     "            top_logits, _ = torch.topk(logits, top_k)\n",
     "            min_val = top_logits[:, -1]\n",
     "            min_val = top_logits[:, -1]\n",
-    "            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)\n",
+    "            logits = torch.where(logits < min_val, torch.tensor(float(\"-inf\")).to(logits.device), logits)\n",
     "\n",
     "\n",
     "        # New: Apply temperature scaling\n",
     "        # New: Apply temperature scaling\n",
     "        if temperature > 0.0:\n",
     "        if temperature > 0.0:\n",
@@ -2485,7 +2504,7 @@
    "name": "python",
    "name": "python",
    "nbconvert_exporter": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "pygments_lexer": "ipython3",
-   "version": "3.11.4"
+   "version": "3.12.2"
   }
   }
  },
  },
  "nbformat": 4,
  "nbformat": 4,