浏览代码

add check for small validation sets

rasbt 1 年之前
父节点
当前提交
861a2788f3
共有 1 个文件被更改,包括 50 次插入25 次删除
  1. 50 25
      ch05/01_main-chapter-code/ch05.ipynb

+ 50 - 25
ch05/01_main-chapter-code/ch05_draft.ipynb → ch05/01_main-chapter-code/ch05.ipynb

@@ -20,8 +20,9 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "numpy version: 1.26.0\n",
-      "matplotlib version: 3.8.2\n",
+      "numpy version: 1.25.2\n",
+      "matplotlib version: 3.7.2\n",
+      "numpy version: 1.25.2\n",
       "tiktoken version: 0.5.1\n",
       "torch version: 2.2.1\n"
      ]
@@ -916,8 +917,11 @@
     }
    ],
    "source": [
-    "print(\"Characters:\", len(text_data))\n",
-    "print(\"Tokens:\", len(tokenizer.encode(text_data)))"
+    "total_char = len(text_data)\n",
+    "total_tokens = len(tokenizer.encode(text_data))\n",
+    "\n",
+    "print(\"Characters:\", total_char)\n",
+    "print(\"Tokens:\", total_tokens)"
    ]
   },
   {
@@ -962,7 +966,6 @@
     "train_ratio = 0.90\n",
     "split_idx = int(train_ratio * len(text_data))\n",
     "\n",
-    "\n",
     "torch.manual_seed(123)\n",
     "\n",
     "train_loader = create_dataloader_v1(\n",
@@ -984,6 +987,26 @@
     ")"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "id": "f37b3eb0-854e-4895-9898-fa7d1e67566e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sanity check\n",
+    "\n",
+    "if total_tokens * (train_ratio) < GPT_CONFIG_124M[\"ctx_len\"]:\n",
+    "    print(\"Not enough tokens for the training loader. \"\n",
+    "          \"Try to lower the `GPT_CONFIG_124M['ctx_len']` or \"\n",
+    "          \"increase the `training_ratio`\")\n",
+    "\n",
+    "if total_tokens * (1-train_ratio) < GPT_CONFIG_124M[\"ctx_len\"]:\n",
+    "    print(\"Not enough tokens for the validation loader. \"\n",
+    "          \"Try to lower the `GPT_CONFIG_124M['ctx_len']` or \"\n",
+    "          \"decrease the `training_ratio`\")"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "e7ac3296-a4d1-4303-9ac5-376518960c33",
@@ -1003,7 +1026,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 24,
    "id": "ca0116d0-d229-472c-9fbf-ebc229331c3e",
    "metadata": {},
    "outputs": [
@@ -1047,7 +1070,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 25,
    "id": "eb860488-5453-41d7-9870-23b723f742a0",
    "metadata": {
     "colab": {
@@ -1092,7 +1115,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 26,
    "id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc",
    "metadata": {
     "id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc"
@@ -1133,7 +1156,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": 27,
    "id": "56f5b0c9-1065-4d67-98b9-010e42fc1e2a",
    "metadata": {},
    "outputs": [
@@ -1195,7 +1218,7 @@
    },
    "outputs": [],
    "source": [
-    "def train_model_simple(model, train_loader, val_loader, optimizer, device, n_epochs,\n",
+    "def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,\n",
     "                       eval_freq, eval_iter, start_context):\n",
     "    # Initialize lists to track losses and tokens seen\n",
     "    train_losses, val_losses, track_tokens_seen = [], [], []\n",
@@ -1203,7 +1226,7 @@
     "    global_step = -1\n",
     "\n",
     "    # Main training loop\n",
-    "    for epoch in range(n_epochs):\n",
+    "    for epoch in range(num_epochs):\n",
     "        model.train()  # Set model to training mode\n",
     "        \n",
     "        for input_batch, target_batch in train_loader:\n",
@@ -1246,8 +1269,10 @@
     "    context_size = model.pos_emb.weight.shape[0]\n",
     "    encoded = text_to_token_ids(start_context, tokenizer).to(device)\n",
     "    with torch.no_grad():\n",
-    "        token_ids = generate_text_simple(model=model, idx=encoded,\n",
-    "                                   max_new_tokens=50, context_size=context_size)\n",
+    "        token_ids = generate_text_simple(\n",
+    "            model=model, idx=encoded,\n",
+    "            max_new_tokens=50, context_size=context_size\n",
+    "        )\n",
     "        decoded_text = token_ids_to_text(token_ids, tokenizer)\n",
     "        print(decoded_text.replace(\"\\n\", \" \"))  # Compact print format\n",
     "    model.train()"
@@ -1314,10 +1339,10 @@
     "model.to(device)\n",
     "optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.1)\n",
     "\n",
-    "n_epochs = 10\n",
+    "num_epochs = 10\n",
     "train_losses, val_losses, tokens_seen = train_model_simple(\n",
     "    model, train_loader, val_loader, optimizer, device,\n",
-    "    n_epochs=n_epochs, eval_freq=5, eval_iter=1,\n",
+    "    num_epochs=num_epochs, eval_freq=5, eval_iter=1,\n",
     "    start_context=\"Every effort moves you\",\n",
     ")"
    ]
@@ -1368,7 +1393,7 @@
     "    plt.show()\n",
     "\n",
     "\n",
-    "epochs_tensor = torch.linspace(0, n_epochs, len(train_losses))\n",
+    "epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))\n",
     "plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)"
    ]
   },
@@ -1959,7 +1984,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "torch.save(model.state_dict(), 'model.pth')"
+    "torch.save(model.state_dict(), \"model.pth\")"
    ]
   },
   {
@@ -1978,7 +2003,7 @@
    "outputs": [],
    "source": [
     "model = GPTModel(GPT_CONFIG_124M)\n",
-    "model.load_state_dict(torch.load('model.pth'))\n",
+    "model.load_state_dict(torch.load(\"model.pth\"))\n",
     "model.eval();"
    ]
   },
@@ -1999,10 +2024,10 @@
    "outputs": [],
    "source": [
     "torch.save({\n",
-    "    'model_state_dict': model.state_dict(),\n",
-    "    'optimizer_state_dict': optimizer.state_dict(),\n",
+    "    \"model_state_dict\": model.state_dict(),\n",
+    "    \"optimizer_state_dict\": optimizer.state_dict(),\n",
     "    }, \n",
-    "    'model_and_optimizer.pth'\n",
+    "    \"model_and_optimizer.pth\"\n",
     ")"
    ]
   },
@@ -2013,9 +2038,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "checkpoint = torch.load('model_and_optimizer.pth')\n",
-    "model.load_state_dict(checkpoint['model_state_dict'])\n",
-    "optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
+    "checkpoint = torch.load(\"model_and_optimizer.pth\")\n",
+    "model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
+    "optimizer.load_state_dict(checkpoint[\"optimizer_state_dict\"])\n",
     "model.train();"
    ]
   },
@@ -2474,7 +2499,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.12"
+   "version": "3.11.4"
   }
  },
  "nbformat": 4,