
show normalization explicitely

rasbt 1 year ago
parent
commit
e113075a16
1 changed file with 65 additions and 31 deletions

+ 65 - 31
ch03/01_main-chapter-code/ch03.ipynb
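
The diff below renames `mask_naive` to `mask_simple` and adds a cell that explicitly re-normalizes the causally masked attention weights so each row sums to 1 again. A minimal standalone sketch of that idea (the 6x6 shape and the random scores are illustrative assumptions, not the notebook's actual inputs):

    import torch

    torch.manual_seed(123)

    # Toy attention weights: softmax over random scores for a sequence of 6 tokens
    attn_scores = torch.randn(6, 6)
    attn_weights = torch.softmax(attn_scores, dim=1)

    # Zero out everything above the diagonal (future tokens), as in the notebook
    block_size = attn_scores.shape[0]
    mask_simple = torch.tril(torch.ones(block_size, block_size))
    masked_simple = attn_weights * mask_simple

    # Masking breaks the softmax normalization, so renormalize each row to sum to 1
    row_sums = masked_simple.sum(dim=1, keepdim=True)
    masked_simple_norm = masked_simple / row_sums

    print(masked_simple_norm.sum(dim=1))  # every row sums to 1.0 again

The end of the diff also points to a more efficient alternative: masking the unnormalized attention scores with negative infinity before the softmax, which makes this explicit renormalization step unnecessary.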

@@ -159,7 +159,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 2,
    "id": "22b9556a-aaf8-4ab4-a5b4-973372b0b2c3",
    "metadata": {},
    "outputs": [],
@@ -187,7 +187,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 3,
    "id": "6fb5b2f8-dd2c-4a6d-94ef-a0e9ad163951",
    "metadata": {},
    "outputs": [
@@ -219,7 +219,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 4,
    "id": "9842f39b-1654-410e-88bf-d1b899bf0241",
    "metadata": {},
    "outputs": [
@@ -253,7 +253,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 5,
    "id": "e3ccc99c-33ce-4f11-b7f2-353cf1cbdaba",
    "metadata": {},
    "outputs": [
@@ -284,7 +284,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 6,
    "id": "07b2e58d-a6ed-49f0-a1cd-2463e8d53a20",
    "metadata": {},
    "outputs": [
@@ -318,7 +318,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 7,
    "id": "2d99cac4-45ea-46b3-b3c1-e000ad16e158",
    "metadata": {},
    "outputs": [
@@ -348,7 +348,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 8,
    "id": "8fcb96f0-14e5-4973-a50e-79ea7c6af99f",
    "metadata": {},
    "outputs": [
@@ -407,7 +407,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 9,
    "id": "04004be8-07a1-468b-ab33-32e16a551b45",
    "metadata": {},
    "outputs": [
@@ -444,7 +444,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 10,
    "id": "2cea69d0-9a47-45da-8d5a-47ceef2df673",
    "metadata": {},
    "outputs": [
@@ -476,7 +476,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 11,
    "id": "fa4ef062-de81-47ee-8415-bfe1708c81b8",
    "metadata": {},
    "outputs": [
@@ -508,7 +508,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 12,
    "id": "112b492c-fb6f-4e6d-8df5-518ae83363d5",
    "metadata": {},
    "outputs": [
@@ -538,7 +538,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 13,
    "id": "ba8eafcf-f7f7-4989-b8dc-61b50c4f81dc",
    "metadata": {},
    "outputs": [
@@ -570,7 +570,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 14,
    "id": "2570eb7d-aee1-457a-a61e-7544478219fa",
    "metadata": {},
    "outputs": [
@@ -649,7 +649,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 15,
    "id": "8250fdc6-6cd6-4c5b-b9c0-8c643aadb7db",
    "metadata": {},
    "outputs": [],
@@ -669,7 +669,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 16,
    "id": "bfd7259a-f26c-4cea-b8fc-282b5cae1e00",
    "metadata": {},
    "outputs": [],
@@ -691,7 +691,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 17,
    "id": "73cedd62-01e1-4196-a575-baecc6095601",
    "metadata": {},
    "outputs": [
@@ -721,7 +721,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 18,
    "id": "8c1c3949-fc08-4d19-a41e-1c235b4e631b",
    "metadata": {},
    "outputs": [
@@ -760,7 +760,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 19,
    "id": "64cbc253-a182-4490-a765-246979ea0a28",
    "metadata": {},
    "outputs": [
@@ -788,7 +788,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 20,
    "id": "b14e44b5-d170-40f9-8847-8990804af26d",
    "metadata": {},
    "outputs": [
@@ -824,7 +824,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 21,
    "id": "146f5587-c845-4e30-9894-c7ed3a248153",
    "metadata": {},
    "outputs": [
@@ -859,7 +859,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 22,
    "id": "e138f033-fa7e-4e3a-8764-b53a96b26397",
    "metadata": {},
    "outputs": [
@@ -894,7 +894,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 23,
    "id": "51590326-cdbe-4e62-93b1-17df71c11ee4",
    "metadata": {},
    "outputs": [
@@ -950,7 +950,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 24,
    "id": "73f411e3-e231-464a-89fe-0a9035e5f839",
    "metadata": {},
    "outputs": [
@@ -1046,7 +1046,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 25,
    "id": "1933940d-0fa5-4b17-a3ce-388e5314a1bb",
    "metadata": {},
    "outputs": [
@@ -1078,7 +1078,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 28,
    "id": "43f3d2e3-185b-4184-9f98-edde5e6df746",
    "metadata": {},
    "outputs": [
@@ -1097,8 +1097,8 @@
    ],
    "source": [
     "block_size = attn_scores.shape[0]\n",
-    "mask_naive = torch.tril(torch.ones(block_size, block_size))\n",
-    "print(mask_naive)"
+    "mask_simple = torch.tril(torch.ones(block_size, block_size))\n",
+    "print(mask_simple)"
    ]
   },
   {
@@ -1111,7 +1111,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": 29,
    "id": "9f531e2e-f4d2-4fea-a87f-4c132e48b9e7",
    "metadata": {},
    "outputs": [
@@ -1129,8 +1129,8 @@
     }
    ],
    "source": [
-    "masked_naive = attn_weights*mask_naive\n",
-    "print(masked_naive)"
+    "masked_simple = attn_weights*mask_simple\n",
+    "print(masked_simple)"
    ]
   },
   {
@@ -1141,12 +1141,46 @@
     "- However, if the mask were applied after softmax, like above, it would disrupt the probability distribution created by softmax. Softmax ensures that all output values sum to 1. Masking after softmax would require re-normalizing the outputs to sum to 1 again, which complicates the process and might lead to unintended effects."
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "94db92d7-c397-4e42-bd8a-6a2b3e237e0f",
+   "metadata": {},
+   "source": [
+    "- To make sure that the rows sum to 1, we can normalize the attention weights as follows:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "id": "6d392083-fd81-4f70-9bdf-8db985e673d6",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
+      "        [0.4056, 0.5944, 0.0000, 0.0000, 0.0000, 0.0000],\n",
+      "        [0.2566, 0.3741, 0.3693, 0.0000, 0.0000, 0.0000],\n",
+      "        [0.2176, 0.2823, 0.2796, 0.2205, 0.0000, 0.0000],\n",
+      "        [0.1826, 0.2178, 0.2191, 0.1689, 0.2115, 0.0000],\n",
+      "        [0.1473, 0.2033, 0.1996, 0.1500, 0.1160, 0.1839]])\n"
+     ]
+    }
+   ],
+   "source": [
+    "row_sums = masked_simple.sum(dim=1, keepdim=True)\n",
+    "masked_simple_norm = masked_simple / row_sums\n",
+    "print(masked_simple_norm)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "512e7cf4-dc0e-4cec-948e-c7a3c4eb6877",
    "metadata": {},
    "source": [
-    "- So, instead, we take a different approach, masking elements with negative infinity before they enter the softmax function:"
+    "- While we are technically done with coding the causal attention mechanism now, let's briefly look at a more efficient approach to achieve the same as above.\n",
+    "- So, instead of zeroing out attention weights above the diagonal and renormalizing the results, we can mask the unnormalized attention scores above the diagonal with negative infinity before they enter the softmax function:"
    ]
   },
   {