Răsfoiți Sursa

simplify calc_loss_loader

rasbt 1 an în urmă
părinte
comite
3cb5a52a1b

+ 2 - 3
appendix-D/01_main-chapter-code/previous_chapters.py

@@ -256,17 +256,16 @@ def calc_loss_batch(input_batch, target_batch, model, device):
 
 
 def calc_loss_loader(data_loader, model, device, num_batches=None):
-    total_loss, batches_seen = 0., 0.
+    total_loss = 0.
     if num_batches is None:
         num_batches = len(data_loader)
     for i, (input_batch, target_batch) in enumerate(data_loader):
         if i < num_batches:
             loss = calc_loss_batch(input_batch, target_batch, model, device)
             total_loss += loss.item()
-            batches_seen += 1
         else:
             break
-    return total_loss / batches_seen
+    return total_loss / num_batches
 
 
 def evaluate_model(model, train_loader, val_loader, device, eval_iter):

+ 21 - 55
ch05/01_main-chapter-code/ch05.ipynb

@@ -764,7 +764,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 23,
    "id": "654fde37-b2a9-4a20-a8d3-0206c056e2ff",
    "metadata": {},
    "outputs": [],
@@ -785,43 +785,6 @@
     "        text_data = file.read()"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": 18,
-   "id": "0959c855-f860-4358-8b98-bc654f047578",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from previous_chapters import create_dataloader_v1\n",
-    "\n",
-    "# Train/validation ratio\n",
-    "train_ratio = 0.90\n",
-    "split_idx = int(train_ratio * len(text_data))\n",
-    "train_data = text_data[:split_idx]\n",
-    "val_data = text_data[split_idx:]\n",
-    "\n",
-    "\n",
-    "torch.manual_seed(123)\n",
-    "\n",
-    "train_loader = create_dataloader_v1(\n",
-    "    train_data,\n",
-    "    batch_size=2,\n",
-    "    max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
-    "    stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
-    "    drop_last=True,\n",
-    "    shuffle=True\n",
-    ")\n",
-    "\n",
-    "val_loader = create_dataloader_v1(\n",
-    "    val_data,\n",
-    "    batch_size=2,\n",
-    "    max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
-    "    stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
-    "    drop_last=False,\n",
-    "    shuffle=False\n",
-    ")"
-   ]
-  },
   {
    "cell_type": "markdown",
    "id": "379330f1-80f4-4e34-8724-41d892b04cee",
@@ -832,7 +795,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 24,
    "id": "6kgJbe4ehI4q",
    "metadata": {
     "colab": {
@@ -858,7 +821,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 25,
    "id": "j2XPde_ThM_e",
    "metadata": {
     "colab": {
@@ -884,7 +847,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 27,
    "id": "6b46a952-d50a-4837-af09-4095698f7fd1",
    "metadata": {
     "colab": {
@@ -940,22 +903,24 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
-   "id": "fd0e963b-d282-4c97-b004-8772f4b1bd8f",
+   "execution_count": 35,
+   "id": "0959c855-f860-4358-8b98-bc654f047578",
    "metadata": {},
    "outputs": [],
    "source": [
     "from previous_chapters import create_dataloader_v1\n",
     "\n",
-    "\n",
     "# Train/validation ratio\n",
     "train_ratio = 0.90\n",
     "split_idx = int(train_ratio * len(text_data))\n",
+    "train_data = text_data[:split_idx]\n",
+    "val_data = text_data[split_idx:]\n",
+    "\n",
     "\n",
     "torch.manual_seed(123)\n",
     "\n",
     "train_loader = create_dataloader_v1(\n",
-    "    text_data[:split_idx],\n",
+    "    train_data,\n",
     "    batch_size=2,\n",
     "    max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
     "    stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
@@ -964,7 +929,7 @@
     ")\n",
     "\n",
     "val_loader = create_dataloader_v1(\n",
-    "    text_data[split_idx:],\n",
+    "    val_data,\n",
     "    batch_size=2,\n",
     "    max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
     "    stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
@@ -975,7 +940,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 36,
    "id": "f37b3eb0-854e-4895-9898-fa7d1e67566e",
    "metadata": {},
    "outputs": [],
@@ -1012,7 +977,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 37,
    "id": "ca0116d0-d229-472c-9fbf-ebc229331c3e",
    "metadata": {},
    "outputs": [
@@ -1056,7 +1021,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 38,
    "id": "eb860488-5453-41d7-9870-23b723f742a0",
    "metadata": {
     "colab": {
@@ -1101,7 +1066,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": 49,
    "id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc",
    "metadata": {
     "id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc"
@@ -1118,17 +1083,16 @@
     "\n",
     "\n",
     "def calc_loss_loader(data_loader, model, device, num_batches=None):\n",
-    "    total_loss, batches_seen = 0., 0.\n",
+    "    total_loss = 0.\n",
     "    if num_batches is None:\n",
     "        num_batches = len(data_loader)\n",
     "    for i, (input_batch, target_batch) in enumerate(data_loader):\n",
     "        if i < num_batches:\n",
     "            loss = calc_loss_batch(input_batch, target_batch, model, device)\n",
     "            total_loss += loss.item()\n",
-    "            batches_seen += 1\n",
     "        else:\n",
     "            break\n",
-    "    return total_loss / batches_seen"
+    "    return total_loss / num_batches"
    ]
   },
   {
@@ -1142,7 +1106,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 27,
+   "execution_count": 55,
    "id": "56f5b0c9-1065-4d67-98b9-010e42fc1e2a",
    "metadata": {},
    "outputs": [
@@ -1150,7 +1114,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Training loss: 11.002305030822754\n",
+      "Training loss: 10.98758347829183\n",
       "Validation loss: 10.98110580444336\n"
      ]
     }
@@ -1159,6 +1123,8 @@
     "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
     "model.to(device) # no assignment model = model.to(device) necessary for nn.Module classes\n",
     "\n",
+    "\n",
+    "torch.manual_seed(123) # For reproducibility due to the shuffling in the data loader\n",
     "train_loss = calc_loss_loader(train_loader, model, device, num_batches=1)\n",
     "val_loss = calc_loss_loader(val_loader, model, device, num_batches=1)\n",
     "\n",

+ 2 - 3
ch05/01_main-chapter-code/train.py

@@ -31,17 +31,16 @@ def calc_loss_batch(input_batch, target_batch, model, device):
 
 
 def calc_loss_loader(data_loader, model, device, num_batches=None):
-    total_loss, batches_seen = 0., 0.
+    total_loss = 0.
     if num_batches is None:
         num_batches = len(data_loader)
     for i, (input_batch, target_batch) in enumerate(data_loader):
         if i < num_batches:
             loss = calc_loss_batch(input_batch, target_batch, model, device)
             total_loss += loss.item()
-            batches_seen += 1
         else:
             break
-    return total_loss / batches_seen
+    return total_loss / num_batches
 
 
 def evaluate_model(model, train_loader, val_loader, device, eval_iter):

+ 2 - 3
ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py

@@ -250,17 +250,16 @@ def calc_loss_batch(input_batch, target_batch, model, device):
 
 
 def calc_loss_loader(data_loader, model, device, num_batches=None):
-    total_loss, batches_seen = 0., 0.
+    total_loss = 0.
     if num_batches is None:
         num_batches = len(data_loader)
     for i, (input_batch, target_batch) in enumerate(data_loader):
         if i < num_batches:
             loss = calc_loss_batch(input_batch, target_batch, model, device)
             total_loss += loss.item()
-            batches_seen += 1
         else:
             break
-    return total_loss / batches_seen
+    return total_loss / num_batches
 
 
 def evaluate_model(model, train_loader, val_loader, device, eval_iter):

+ 6 - 7
ch05/05_bonus_hparam_tuning/hparam_search.py

@@ -23,18 +23,17 @@ HPARAM_GRID = {
 }
 
 
-def calc_loss_loader(data_loader, model, device, num_iters=None):
-    total_loss, num_batches = 0., 0
-    if num_iters is None:
-        num_iters = len(data_loader)
+def calc_loss_loader(data_loader, model, device, num_batches=None):
+    total_loss = 0.
+    if num_batches is None:
+        num_batches = len(data_loader)
     for i, (input_batch, target_batch) in enumerate(data_loader):
-        if i < num_iters:
+        if i < num_batches:
             loss = calc_loss_batch(input_batch, target_batch, model, device)
             total_loss += loss.item()
-            num_batches += 1
         else:
             break
-    return total_loss
+    return total_loss / num_batches
 
 
 def calc_loss_batch(input_batch, target_batch, model, device):