瀏覽代碼

add toggle for qkv_bias

rasbt 1 年之前
父節點
當前提交
92896d817c

+ 82 - 69
ch02/01_main-chapter-code/ch02.ipynb

@@ -26,7 +26,7 @@
      "name": "stdout",
      "name": "stdout",
      "output_type": "stream",
      "output_type": "stream",
      "text": [
      "text": [
-      "torch version: 2.0.1\n",
+      "torch version: 2.1.0\n",
       "tiktoken version: 0.5.1\n"
       "tiktoken version: 0.5.1\n"
      ]
      ]
     }
     }
@@ -76,7 +76,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 2,
    "id": "8a769e87-470a-48b9-8bdb-12841b416198",
    "id": "8a769e87-470a-48b9-8bdb-12841b416198",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -109,7 +109,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 3,
    "id": "737dd5b0-9dbb-4a97-9ae4-3482c8c04be7",
    "id": "737dd5b0-9dbb-4a97-9ae4-3482c8c04be7",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -140,7 +140,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 4,
    "id": "ea02489d-01f9-4247-b7dd-a0d63f62ef07",
    "id": "ea02489d-01f9-4247-b7dd-a0d63f62ef07",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -168,7 +168,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 5,
    "id": "4d8a6fb7-2e62-4a12-ad06-ccb04f25fed7",
    "id": "4d8a6fb7-2e62-4a12-ad06-ccb04f25fed7",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -196,7 +196,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 6,
    "id": "902f0d9c-9828-4c46-ba32-8fe810c3840a",
    "id": "902f0d9c-9828-4c46-ba32-8fe810c3840a",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -226,7 +226,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 7,
    "id": "8c567caa-8ff5-49a8-a5cc-d365b0a78a99",
    "id": "8c567caa-8ff5-49a8-a5cc-d365b0a78a99",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -254,7 +254,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 8,
    "id": "35db7b5e-510b-4c45-995f-f5ad64a8e19c",
    "id": "35db7b5e-510b-4c45-995f-f5ad64a8e19c",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -288,7 +288,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 9,
    "id": "7fdf0533-5ab6-42a5-83fa-a3b045de6396",
    "id": "7fdf0533-5ab6-42a5-83fa-a3b045de6396",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -309,7 +309,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 10,
    "id": "77d00d96-881f-4691-bb03-84fec2a75a26",
    "id": "77d00d96-881f-4691-bb03-84fec2a75a26",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
@@ -327,7 +327,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 11,
    "id": "e1c5de4a-aa4e-4aec-b532-10bb364039d6",
    "id": "e1c5de4a-aa4e-4aec-b532-10bb364039d6",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -406,7 +406,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 12,
    "id": "f531bf46-7c25-4ef8-bff8-0d27518676d5",
    "id": "f531bf46-7c25-4ef8-bff8-0d27518676d5",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
@@ -440,7 +440,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 13,
    "id": "647364ec-7995-4654-9b4a-7607ccf5f1e4",
    "id": "647364ec-7995-4654-9b4a-7607ccf5f1e4",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -470,7 +470,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 14,
    "id": "01d8c8fb-432d-4a49-b332-99f23b233746",
    "id": "01d8c8fb-432d-4a49-b332-99f23b233746",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -480,7 +480,7 @@
        "'\" It\\' s the last he painted, you know,\" Mrs. Gisburn said with pardonable pride.'"
        "'\" It\\' s the last he painted, you know,\" Mrs. Gisburn said with pardonable pride.'"
       ]
       ]
      },
      },
-     "execution_count": 13,
+     "execution_count": 14,
      "metadata": {},
      "metadata": {},
      "output_type": "execute_result"
      "output_type": "execute_result"
     }
     }
@@ -490,9 +490,22 @@
    ]
    ]
   },
   },
   {
   {
-   "cell_type": "markdown",
-   "id": "75f21efe-c4d6-4323-839b-6061972810d2",
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "54f6aa8b-9827-412e-9035-e827296ab0fe",
    "metadata": {},
    "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'\" It\\' s the last he painted, you know,\" Mrs. Gisburn said with pardonable pride.'"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
    "source": [
     "tokenizer.decode(tokenizer.encode(text))"
     "tokenizer.decode(tokenizer.encode(text))"
    ]
    ]
@@ -534,7 +547,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 16,
    "id": "d5767eff-440c-4de1-9289-f789349d6b85",
    "id": "d5767eff-440c-4de1-9289-f789349d6b85",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -545,9 +558,9 @@
      "traceback": [
      "traceback": [
       "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
       "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
       "\u001b[0;31mKeyError\u001b[0m                                  Traceback (most recent call last)",
       "\u001b[0;31mKeyError\u001b[0m                                  Traceback (most recent call last)",
-      "Cell \u001b[0;32mIn[14], line 5\u001b[0m\n\u001b[1;32m      1\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m SimpleTokenizerV1(vocab)\n\u001b[1;32m      3\u001b[0m text \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mHello, do you like tea. Is this-- a test?\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 5\u001b[0m \u001b[43mtokenizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtext\u001b[49m\u001b[43m)\u001b[49m\n",
-      "Cell \u001b[0;32mIn[11], line 9\u001b[0m, in \u001b[0;36mSimpleTokenizerV1.encode\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m      7\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m([,.?_!\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m()\u001b[39m\u001b[38;5;130;01m\\'\u001b[39;00m\u001b[38;5;124m]|--|\u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124ms)\u001b[39m\u001b[38;5;124m'\u001b[39m, text)\n\u001b[1;32m      8\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m [item\u001b[38;5;241m.\u001b[39mstrip() \u001b[38;5;28;01mfor\u001b[39;00m item \u001b[38;5;129;01min\u001b[39;00m preprocessed \u001b[38;5;28;01mif\u001b[39;00m item\u001b[38;5;241m.\u001b[39mstrip()]\n\u001b[0;32m----> 9\u001b[0m ids \u001b[38;5;241m=\u001b[39m \u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstr_to_int\u001b[49m\u001b[43m[\u001b[49m\u001b[43ms\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43ms\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mpreprocessed\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m     10\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ids\n",
-      "Cell \u001b[0;32mIn[11], line 9\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m      7\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m([,.?_!\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m()\u001b[39m\u001b[38;5;130;01m\\'\u001b[39;00m\u001b[38;5;124m]|--|\u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124ms)\u001b[39m\u001b[38;5;124m'\u001b[39m, text)\n\u001b[1;32m      8\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m [item\u001b[38;5;241m.\u001b[39mstrip() \u001b[38;5;28;01mfor\u001b[39;00m item \u001b[38;5;129;01min\u001b[39;00m preprocessed \u001b[38;5;28;01mif\u001b[39;00m item\u001b[38;5;241m.\u001b[39mstrip()]\n\u001b[0;32m----> 9\u001b[0m ids \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstr_to_int\u001b[49m\u001b[43m[\u001b[49m\u001b[43ms\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m s \u001b[38;5;129;01min\u001b[39;00m preprocessed]\n\u001b[1;32m     10\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ids\n",
+      "Cell \u001b[0;32mIn[16], line 5\u001b[0m\n\u001b[1;32m      1\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m SimpleTokenizerV1(vocab)\n\u001b[1;32m      3\u001b[0m text \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mHello, do you like tea. Is this-- a test?\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 5\u001b[0m \u001b[43mtokenizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtext\u001b[49m\u001b[43m)\u001b[49m\n",
+      "Cell \u001b[0;32mIn[12], line 9\u001b[0m, in \u001b[0;36mSimpleTokenizerV1.encode\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m      7\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m([,.?_!\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m()\u001b[39m\u001b[38;5;130;01m\\'\u001b[39;00m\u001b[38;5;124m]|--|\u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124ms)\u001b[39m\u001b[38;5;124m'\u001b[39m, text)\n\u001b[1;32m      8\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m [item\u001b[38;5;241m.\u001b[39mstrip() \u001b[38;5;28;01mfor\u001b[39;00m item \u001b[38;5;129;01min\u001b[39;00m preprocessed \u001b[38;5;28;01mif\u001b[39;00m item\u001b[38;5;241m.\u001b[39mstrip()]\n\u001b[0;32m----> 9\u001b[0m ids \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstr_to_int[s] \u001b[38;5;28;01mfor\u001b[39;00m s \u001b[38;5;129;01min\u001b[39;00m preprocessed]\n\u001b[1;32m     10\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ids\n",
+      "Cell \u001b[0;32mIn[12], line 9\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m      7\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m([,.?_!\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m()\u001b[39m\u001b[38;5;130;01m\\'\u001b[39;00m\u001b[38;5;124m]|--|\u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124ms)\u001b[39m\u001b[38;5;124m'\u001b[39m, text)\n\u001b[1;32m      8\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m [item\u001b[38;5;241m.\u001b[39mstrip() \u001b[38;5;28;01mfor\u001b[39;00m item \u001b[38;5;129;01min\u001b[39;00m preprocessed \u001b[38;5;28;01mif\u001b[39;00m item\u001b[38;5;241m.\u001b[39mstrip()]\n\u001b[0;32m----> 9\u001b[0m ids \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstr_to_int\u001b[49m\u001b[43m[\u001b[49m\u001b[43ms\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m s \u001b[38;5;129;01min\u001b[39;00m preprocessed]\n\u001b[1;32m     10\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ids\n",
       "\u001b[0;31mKeyError\u001b[0m: 'Hello'"
       "\u001b[0;31mKeyError\u001b[0m: 'Hello'"
      ]
      ]
     }
     }
@@ -572,7 +585,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 17,
    "id": "ce9df29c-6c5b-43f1-8c1a-c7f7b79db78f",
    "id": "ce9df29c-6c5b-43f1-8c1a-c7f7b79db78f",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
@@ -588,7 +601,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 18,
    "id": "57c3143b-e860-4d3b-a22a-de22b547a6a9",
    "id": "57c3143b-e860-4d3b-a22a-de22b547a6a9",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -598,7 +611,7 @@
        "1161"
        "1161"
       ]
       ]
      },
      },
-     "execution_count": 16,
+     "execution_count": 18,
      "metadata": {},
      "metadata": {},
      "output_type": "execute_result"
      "output_type": "execute_result"
     }
     }
@@ -609,7 +622,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 19,
    "id": "50e51bb1-ae05-4aa8-a9ff-455b65ed1959",
    "id": "50e51bb1-ae05-4aa8-a9ff-455b65ed1959",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -640,7 +653,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 20,
    "id": "948861c5-3f30-4712-a234-725f20d26f68",
    "id": "948861c5-3f30-4712-a234-725f20d26f68",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
@@ -676,7 +689,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 21,
    "id": "effcef79-e0a5-4f4a-a43a-31dd94b9250a",
    "id": "effcef79-e0a5-4f4a-a43a-31dd94b9250a",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -701,7 +714,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 22,
    "id": "ddfe7346-398d-4bf8-99f1-5b071244ce95",
    "id": "ddfe7346-398d-4bf8-99f1-5b071244ce95",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -726,7 +739,7 @@
        " 7]"
        " 7]"
       ]
       ]
      },
      },
-     "execution_count": 20,
+     "execution_count": 22,
      "metadata": {},
      "metadata": {},
      "output_type": "execute_result"
      "output_type": "execute_result"
     }
     }
@@ -737,7 +750,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 23,
    "id": "0c350ff6-2734-4e84-9ec7-d578baa4ae1b",
    "id": "0c350ff6-2734-4e84-9ec7-d578baa4ae1b",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -747,7 +760,7 @@
        "'<|unk|>, do you like tea? <|endoftext|> In the sunlit terraces of the <|unk|>.'"
        "'<|unk|>, do you like tea? <|endoftext|> In the sunlit terraces of the <|unk|>.'"
       ]
       ]
      },
      },
-     "execution_count": 21,
+     "execution_count": 23,
      "metadata": {},
      "metadata": {},
      "output_type": "execute_result"
      "output_type": "execute_result"
     }
     }
@@ -779,7 +792,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 24,
    "id": "ede1d41f-934b-4bf4-8184-54394a257a94",
    "id": "ede1d41f-934b-4bf4-8184-54394a257a94",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
@@ -789,7 +802,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 25,
    "id": "48967a77-7d17-42bf-9e92-fc619d63a59e",
    "id": "48967a77-7d17-42bf-9e92-fc619d63a59e",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -810,7 +823,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 26,
    "id": "6ad3312f-a5f7-4efc-9d7d-8ea09d7b5128",
    "id": "6ad3312f-a5f7-4efc-9d7d-8ea09d7b5128",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
@@ -820,7 +833,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 27,
    "id": "5ff2cd85-7cfb-4325-b390-219938589428",
    "id": "5ff2cd85-7cfb-4325-b390-219938589428",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -842,7 +855,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": 28,
    "id": "d26a48bb-f82e-41a8-a955-a1c9cf9d50ab",
    "id": "d26a48bb-f82e-41a8-a955-a1c9cf9d50ab",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -870,7 +883,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 27,
+   "execution_count": 29,
    "id": "ce25cf25-a2bb-44d2-bac1-cb566f433f98",
    "id": "ce25cf25-a2bb-44d2-bac1-cb566f433f98",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -889,7 +902,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 28,
+   "execution_count": 30,
    "id": "3e224f96-41d0-4074-ac6e-f7db2490f806",
    "id": "3e224f96-41d0-4074-ac6e-f7db2490f806",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -913,7 +926,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 29,
+   "execution_count": 31,
    "id": "766bcf29-64bf-47ca-9b65-4ae8e607d580",
    "id": "766bcf29-64bf-47ca-9b65-4ae8e607d580",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -940,7 +953,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 30,
+   "execution_count": 32,
    "id": "848d5ade-fd1f-46c3-9e31-1426e315c71b",
    "id": "848d5ade-fd1f-46c3-9e31-1426e315c71b",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -971,7 +984,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 31,
+   "execution_count": 33,
    "id": "e84424a7-646d-45b6-99e3-80d15fb761f2",
    "id": "e84424a7-646d-45b6-99e3-80d15fb761f2",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
@@ -981,7 +994,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 32,
+   "execution_count": 34,
    "id": "dfbff852-a92f-48c8-a46d-143a0f109f40",
    "id": "dfbff852-a92f-48c8-a46d-143a0f109f40",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -1014,7 +1027,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 33,
+   "execution_count": 35,
    "id": "d97b031e-ed55-409d-95f2-aeb38c6fe366",
    "id": "d97b031e-ed55-409d-95f2-aeb38c6fe366",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -1039,7 +1052,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 34,
+   "execution_count": 36,
    "id": "f57bd746-dcbf-4433-8e24-ee213a8c34a1",
    "id": "f57bd746-dcbf-4433-8e24-ee213a8c34a1",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -1081,7 +1094,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 35,
+   "execution_count": 37,
    "id": "e1770134-e7f3-4725-a679-e04c3be48cac",
    "id": "e1770134-e7f3-4725-a679-e04c3be48cac",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -1089,7 +1102,7 @@
      "name": "stdout",
      "name": "stdout",
      "output_type": "stream",
      "output_type": "stream",
      "text": [
      "text": [
-      "PyTorch version: 2.0.1\n"
+      "PyTorch version: 2.1.0\n"
      ]
      ]
     }
     }
    ],
    ],
@@ -1108,7 +1121,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 36,
+   "execution_count": 38,
    "id": "74b41073-4c9f-46e2-a1bd-d38e4122b375",
    "id": "74b41073-4c9f-46e2-a1bd-d38e4122b375",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
@@ -1141,12 +1154,12 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 37,
+   "execution_count": 39,
    "id": "5eb30ebe-97b3-43c5-9ff1-a97d621b3c4e",
    "id": "5eb30ebe-97b3-43c5-9ff1-a97d621b3c4e",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
-    "def create_dataloader(txt, batch_size=4, max_length=256, stride=128):\n",
+    "def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True):\n",
     "    # Initialize the tokenizer\n",
     "    # Initialize the tokenizer\n",
     "    tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
     "    tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
     "\n",
     "\n",
@@ -1154,7 +1167,7 @@
     "    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)\n",
     "    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)\n",
     "\n",
     "\n",
     "    # Create dataloader\n",
     "    # Create dataloader\n",
-    "    dataloader = DataLoader(dataset, batch_size=batch_size)\n",
+    "    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)\n",
     "\n",
     "\n",
     "    return dataloader"
     "    return dataloader"
    ]
    ]
@@ -1169,7 +1182,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 38,
+   "execution_count": 40,
    "id": "df31d96c-6bfd-4564-a956-6192242d7579",
    "id": "df31d96c-6bfd-4564-a956-6192242d7579",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
@@ -1180,7 +1193,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 39,
+   "execution_count": 41,
    "id": "9226d00c-ad9a-4949-a6e4-9afccfc7214f",
    "id": "9226d00c-ad9a-4949-a6e4-9afccfc7214f",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -1193,7 +1206,7 @@
     }
     }
    ],
    ],
    "source": [
    "source": [
-    "dataloader = create_dataloader(raw_text, batch_size=1, max_length=4, stride=1)\n",
+    "dataloader = create_dataloader(raw_text, batch_size=1, max_length=4, stride=1, shuffle=False)\n",
     "\n",
     "\n",
     "data_iter = iter(dataloader)\n",
     "data_iter = iter(dataloader)\n",
     "first_batch = next(data_iter)\n",
     "first_batch = next(data_iter)\n",
@@ -1202,7 +1215,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 40,
+   "execution_count": 42,
    "id": "10deb4bc-4de1-4d20-921e-4b1c7a0e1a6d",
    "id": "10deb4bc-4de1-4d20-921e-4b1c7a0e1a6d",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -1230,7 +1243,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 41,
+   "execution_count": 43,
    "id": "1916e7a6-f03d-4f09-91a6-d0bdbac5a58c",
    "id": "1916e7a6-f03d-4f09-91a6-d0bdbac5a58c",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -1261,7 +1274,7 @@
     }
     }
    ],
    ],
    "source": [
    "source": [
-    "dataloader = create_dataloader(raw_text, batch_size=8, max_length=4, stride=5)\n",
+    "dataloader = create_dataloader(raw_text, batch_size=8, max_length=4, stride=5, shuffle=False)\n",
     "\n",
     "\n",
     "data_iter = iter(dataloader)\n",
     "data_iter = iter(dataloader)\n",
     "inputs, targets = next(data_iter)\n",
     "inputs, targets = next(data_iter)\n",
@@ -1297,7 +1310,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 42,
+   "execution_count": 44,
    "id": "15a6304c-9474-4470-b85d-3991a49fa653",
    "id": "15a6304c-9474-4470-b85d-3991a49fa653",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
@@ -1315,7 +1328,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 43,
+   "execution_count": 45,
    "id": "93cb2cee-9aa6-4bb8-8977-c65661d16eda",
    "id": "93cb2cee-9aa6-4bb8-8977-c65661d16eda",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
@@ -1337,7 +1350,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 44,
+   "execution_count": 46,
    "id": "a686eb61-e737-4351-8f1c-222913d47468",
    "id": "a686eb61-e737-4351-8f1c-222913d47468",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -1378,7 +1391,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 45,
+   "execution_count": 47,
    "id": "e43600ba-f287-4746-8ddf-d0f71a9023ca",
    "id": "e43600ba-f287-4746-8ddf-d0f71a9023ca",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -1405,7 +1418,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 46,
+   "execution_count": 48,
    "id": "50280ead-0363-44c8-8c35-bb885d92c8b7",
    "id": "50280ead-0363-44c8-8c35-bb885d92c8b7",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -1443,7 +1456,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 47,
+   "execution_count": 49,
    "id": "0b9e344d-03a6-4f2c-b723-67b6a20c5041",
    "id": "0b9e344d-03a6-4f2c-b723-67b6a20c5041",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
@@ -1465,20 +1478,20 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 48,
+   "execution_count": 50,
    "id": "ad56a263-3d2e-4d91-98bf-d0b68d3c7fc3",
    "id": "ad56a263-3d2e-4d91-98bf-d0b68d3c7fc3",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
     "max_length = 4\n",
     "max_length = 4\n",
-    "dataloader = create_dataloader(raw_text, batch_size=8, max_length=max_length, stride=5)\n",
+    "dataloader = create_dataloader(raw_text, batch_size=8, max_length=max_length, stride=5, shuffle=False)\n",
     "data_iter = iter(dataloader)\n",
     "data_iter = iter(dataloader)\n",
     "inputs, targets = next(data_iter)"
     "inputs, targets = next(data_iter)"
    ]
    ]
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 49,
+   "execution_count": 51,
    "id": "84416b60-3707-4370-bcbc-da0b62f2b64d",
    "id": "84416b60-3707-4370-bcbc-da0b62f2b64d",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -1508,7 +1521,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 50,
+   "execution_count": 52,
    "id": "7766ec38-30d0-4128-8c31-f49f063c43d1",
    "id": "7766ec38-30d0-4128-8c31-f49f063c43d1",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -1535,7 +1548,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 51,
+   "execution_count": 53,
    "id": "cc048e20-7ac8-417e-81f5-8fe6f9a4fe07",
    "id": "cc048e20-7ac8-417e-81f5-8fe6f9a4fe07",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
@@ -1546,7 +1559,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 52,
+   "execution_count": 54,
    "id": "c369a1e7-d566-4b53-b398-d6adafb44105",
    "id": "c369a1e7-d566-4b53-b398-d6adafb44105",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -1573,7 +1586,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 53,
+   "execution_count": 55,
    "id": "b22fab89-526e-43c8-9035-5b7018e34288",
    "id": "b22fab89-526e-43c8-9035-5b7018e34288",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [

+ 6 - 13
ch02/01_main-chapter-code/dataloader.ipynb

@@ -20,7 +20,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 1,
    "id": "93804da5-372b-45ff-9ef4-8398ba1dd78e",
    "id": "93804da5-372b-45ff-9ef4-8398ba1dd78e",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -28,7 +28,7 @@
      "name": "stdout",
      "name": "stdout",
      "output_type": "stream",
      "output_type": "stream",
      "text": [
      "text": [
-      "torch version: 2.0.1\n",
+      "torch version: 2.1.0\n",
       "tiktoken version: 0.5.1\n"
       "tiktoken version: 0.5.1\n"
      ]
      ]
     }
     }
@@ -78,7 +78,7 @@
     "        return self.input_ids[idx], self.target_ids[idx]\n",
     "        return self.input_ids[idx], self.target_ids[idx]\n",
     "\n",
     "\n",
     "\n",
     "\n",
-    "def create_dataloader(txt, batch_size=4, max_length=256, stride=128):\n",
+    "def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True):\n",
     "    # Initialize the tokenizer\n",
     "    # Initialize the tokenizer\n",
     "    tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
     "    tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
     "\n",
     "\n",
@@ -86,11 +86,12 @@
     "    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)\n",
     "    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)\n",
     "\n",
     "\n",
     "    # Create dataloader\n",
     "    # Create dataloader\n",
-    "    dataloader = DataLoader(dataset, batch_size=batch_size)\n",
+    "    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)\n",
     "\n",
     "\n",
     "    return dataloader\n",
     "    return dataloader\n",
     "\n",
     "\n",
     "\n",
     "\n",
+    "\n",
     "with open(\"the-verdict.txt\", \"r\", encoding=\"utf-8\") as f:\n",
     "with open(\"the-verdict.txt\", \"r\", encoding=\"utf-8\") as f:\n",
     "    raw_text = f.read()\n",
     "    raw_text = f.read()\n",
     "\n",
     "\n",
@@ -144,14 +145,6 @@
    "source": [
    "source": [
     "print(input_embeddings.shape)"
     "print(input_embeddings.shape)"
    ]
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "2773c09d-c136-4372-a2be-04b58d292842",
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
   }
  ],
  ],
  "metadata": {
  "metadata": {
@@ -170,7 +163,7 @@
    "name": "python",
    "name": "python",
    "nbconvert_exporter": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "pygments_lexer": "ipython3",
-   "version": "3.11.4"
+   "version": "3.10.12"
   }
   }
  },
  },
  "nbformat": 4,
  "nbformat": 4,

+ 14 - 14
ch03/01_main-chapter-code/ch03.ipynb

@@ -971,12 +971,12 @@
    "source": [
    "source": [
     "class SelfAttention_v2(nn.Module):\n",
     "class SelfAttention_v2(nn.Module):\n",
     "\n",
     "\n",
-    "    def __init__(self, d_in, d_out):\n",
+    "    def __init__(self, d_in, d_out, qkv_bias=False):\n",
     "        super().__init__()\n",
     "        super().__init__()\n",
     "        self.d_out = d_out\n",
     "        self.d_out = d_out\n",
-    "        self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
-    "        self.W_key   = nn.Linear(d_in, d_out, bias=False)\n",
-    "        self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
+    "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
+    "        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
+    "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
     "\n",
     "\n",
     "    def forward(self, x):\n",
     "    def forward(self, x):\n",
     "        keys = self.W_key(x)\n",
     "        keys = self.W_key(x)\n",
@@ -1397,12 +1397,12 @@
    "source": [
    "source": [
     "class CausalAttention(nn.Module):\n",
     "class CausalAttention(nn.Module):\n",
     "\n",
     "\n",
-    "    def __init__(self, d_in, d_out, block_size, dropout):\n",
+    "    def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):\n",
     "        super().__init__()\n",
     "        super().__init__()\n",
     "        self.d_out = d_out\n",
     "        self.d_out = d_out\n",
-    "        self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
-    "        self.W_key   = nn.Linear(d_in, d_out, bias=False)\n",
-    "        self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
+    "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
+    "        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
+    "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
     "        self.dropout = nn.Dropout(dropout) # New\n",
     "        self.dropout = nn.Dropout(dropout) # New\n",
     "        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
     "        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
     "\n",
     "\n",
@@ -1504,10 +1504,10 @@
    "source": [
    "source": [
     "class MultiHeadAttentionWrapper(nn.Module):\n",
     "class MultiHeadAttentionWrapper(nn.Module):\n",
     "\n",
     "\n",
-    "    def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
+    "    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
     "        super().__init__()\n",
     "        super().__init__()\n",
     "        self.heads = nn.ModuleList(\n",
     "        self.heads = nn.ModuleList(\n",
-    "            [CausalAttention(d_in, d_out, block_size, dropout) \n",
+    "            [CausalAttention(d_in, d_out, block_size, dropout, qkv_bias) \n",
     "             for _ in range(num_heads)]\n",
     "             for _ in range(num_heads)]\n",
     "        )\n",
     "        )\n",
     "\n",
     "\n",
@@ -1623,7 +1623,7 @@
    ],
    ],
    "source": [
    "source": [
     "class MultiHeadAttention(nn.Module):\n",
     "class MultiHeadAttention(nn.Module):\n",
-    "    def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
+    "    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
     "        super().__init__()\n",
     "        super().__init__()\n",
     "        assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
     "        assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
     "\n",
     "\n",
@@ -1631,9 +1631,9 @@
     "        self.num_heads = num_heads\n",
     "        self.num_heads = num_heads\n",
     "        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim\n",
     "        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim\n",
     "\n",
     "\n",
-    "        self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
-    "        self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
-    "        self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
+    "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
+    "        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
+    "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
     "        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs\n",
     "        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs\n",
     "        self.dropout = nn.Dropout(dropout)\n",
     "        self.dropout = nn.Dropout(dropout)\n",
     "        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))\n",
     "        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))\n",

+ 12 - 12
ch03/01_main-chapter-code/multihead-attention.ipynb

@@ -62,7 +62,7 @@
     "        return self.input_ids[idx], self.target_ids[idx]\n",
     "        return self.input_ids[idx], self.target_ids[idx]\n",
     "\n",
     "\n",
     "\n",
     "\n",
-    "def create_dataloader(txt, batch_size=4, max_length=256, stride=128):\n",
+    "def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True):\n",
     "    # Initialize the tokenizer\n",
     "    # Initialize the tokenizer\n",
     "    tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
     "    tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
     "\n",
     "\n",
@@ -70,7 +70,7 @@
     "    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)\n",
     "    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)\n",
     "\n",
     "\n",
     "    # Create dataloader\n",
     "    # Create dataloader\n",
-    "    dataloader = DataLoader(dataset, batch_size=batch_size)\n",
+    "    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)\n",
     "\n",
     "\n",
     "    return dataloader\n",
     "    return dataloader\n",
     "\n",
     "\n",
@@ -155,12 +155,12 @@
    "source": [
    "source": [
     "class CausalSelfAttention(nn.Module):\n",
     "class CausalSelfAttention(nn.Module):\n",
     "\n",
     "\n",
-    "    def __init__(self, d_in, d_out, block_size, dropout):\n",
+    "    def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):\n",
     "        super().__init__()\n",
     "        super().__init__()\n",
     "        self.d_out = d_out\n",
     "        self.d_out = d_out\n",
-    "        self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
-    "        self.W_key   = nn.Linear(d_in, d_out, bias=False)\n",
-    "        self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
+    "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
+    "        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
+    "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
     "        self.dropout = nn.Dropout(dropout) # New\n",
     "        self.dropout = nn.Dropout(dropout) # New\n",
     "        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
     "        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
     "\n",
     "\n",
@@ -181,10 +181,10 @@
     "\n",
     "\n",
     "\n",
     "\n",
     "class MultiHeadAttentionWrapper(nn.Module):\n",
     "class MultiHeadAttentionWrapper(nn.Module):\n",
-    "    def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
+    "    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
     "        super().__init__()\n",
     "        super().__init__()\n",
     "        self.heads = nn.ModuleList(\n",
     "        self.heads = nn.ModuleList(\n",
-    "            [CausalSelfAttention(d_in, d_out, block_size, dropout) \n",
+    "            [CausalSelfAttention(d_in, d_out, block_size, dropout, qkv_bias) \n",
     "             for _ in range(num_heads)]\n",
     "             for _ in range(num_heads)]\n",
     "        )\n",
     "        )\n",
     "        self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)\n",
     "        self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)\n",
@@ -241,7 +241,7 @@
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
     "class MultiHeadAttention(nn.Module):\n",
     "class MultiHeadAttention(nn.Module):\n",
-    "    def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
+    "    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
     "        super().__init__()\n",
     "        super().__init__()\n",
     "        assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
     "        assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
     "\n",
     "\n",
@@ -249,9 +249,9 @@
     "        self.num_heads = num_heads\n",
     "        self.num_heads = num_heads\n",
     "        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim\n",
     "        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim\n",
     "\n",
     "\n",
-    "        self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
-    "        self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
-    "        self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
+    "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
+    "        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
+    "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
     "        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs\n",
     "        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs\n",
     "        self.dropout = nn.Dropout(dropout)\n",
     "        self.dropout = nn.Dropout(dropout)\n",
     "        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))\n",
     "        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))\n",