Sfoglia il codice sorgente

Ch06 classifier function asserts (#703)

Sebastian Raschka 5 mesi fa
parent
commit
4014bdd520
1 ha cambiato i file con 11 aggiunte e 1 eliminazioni
  1. 11 1
      ch06/01_main-chapter-code/ch06.ipynb

+ 11 - 1
ch06/01_main-chapter-code/ch06.ipynb

@@ -2353,7 +2353,17 @@
     "\n",
     "    # Truncate sequences if they too long\n",
     "    input_ids = input_ids[:min(max_length, supported_context_length)]\n",
-    "\n",
+    "    assert max_length is not None, (\n",
+    "        \"max_length must be specified. If you want to use the full model context, \"\n",
+    "        \"pass max_length=model.pos_emb.weight.shape[0].\"\n",
+    "    )\n",
+    "    assert max_length <= supported_context_length, (\n",
+    "        f\"max_length ({max_length}) exceeds model's supported context length ({supported_context_length}).\"\n",
+    "    )    \n",
+    "    # Alternatively, a more robust version is the following one, which handles the max_length=None case better\n",
+    "    # max_len = min(max_length,supported_context_length) if max_length else supported_context_length\n",
+    "    # input_ids = input_ids[:max_len]\n",
+    "    \n",
     "    # Pad sequences to the longest sequence\n",
     "    input_ids += [pad_token_id] * (max_length - len(input_ids))\n",
     "    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0) # add batch dimension\n",