
small readability updates

rasbt 1 year ago
parent commit a7b4880179

+ 14 - 13
ch03/01_main-chapter-code/ch03.ipynb

@@ -819,7 +819,7 @@
    "metadata": {},
    "metadata": {},
    "source": [
    "source": [
     "- Next, in **step 3**, we compute the attention weights (normalized attention scores that sum up to 1) using the softmax function we used earlier.\n",
     "- Next, in **step 3**, we compute the attention weights (normalized attention scores that sum up to 1) using the softmax function we used earlier.\n",
-    "- The difference to earlier is that we now scale the attention scores by dividing them by the square root of the embedding dimension, $\\sqrt{d}$ (i.e., `d_out**0.5`):"
+    "- The difference to earlier is that we now scale the attention scores by dividing them by the square root of the embedding dimension, $\\sqrt{d_k}$ (i.e., `d_k**0.5`):"
    ]
   },
   {
@@ -837,7 +837,8 @@
     }
    ],
    "source": [
-    "attn_weights_2 = torch.softmax(attn_scores_2 / d_out**0.5, dim=-1)\n",
+    "d_k = keys.shape[1]\n",
+    "attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)\n",
     "print(attn_weights_2)"
     "print(attn_weights_2)"
    ]
    ]
   },
   },
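
For reference, a minimal standalone sketch of the scaled dot-product step that this hunk and the preceding markdown cell describe; the toy tensors below are made up for illustration, and only the names keys, attn_scores_2, d_k, and attn_weights_2 follow the notebook:

    import torch

    torch.manual_seed(123)
    keys = torch.rand(6, 2)       # toy key matrix: 6 tokens, d_k = 2
    query_2 = torch.rand(2)       # toy query vector for the second input token

    attn_scores_2 = query_2 @ keys.T             # unnormalized attention scores (omega)
    d_k = keys.shape[-1]                         # key dimension, here 2
    attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)

    print(attn_weights_2)         # scaled attention weights
    print(attn_weights_2.sum())   # tensor(1.), i.e., the weights sum to 1
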
@@ -929,7 +930,7 @@
     "        values = x @ self.W_value\n",
     "        values = x @ self.W_value\n",
     "        \n",
     "        \n",
     "        attn_scores = queries @ keys.T # omega\n",
     "        attn_scores = queries @ keys.T # omega\n",
-    "        attn_weights = torch.softmax(attn_scores / self.d_out**0.5, dim=-1)\n",
+    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
     "\n",
     "\n",
     "        context_vec = attn_weights @ values\n",
     "        context_vec = attn_weights @ values\n",
     "        return context_vec\n",
     "        return context_vec\n",
@@ -983,7 +984,7 @@
     "        values = self.W_value(x)\n",
     "        values = self.W_value(x)\n",
     "        \n",
     "        \n",
     "        attn_scores = queries @ keys.T\n",
     "        attn_scores = queries @ keys.T\n",
-    "        attn_weights = torch.softmax(attn_scores / self.d_out**0.5, dim=1)\n",
+    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
     "\n",
     "\n",
     "        context_vec = attn_weights @ values\n",
     "        context_vec = attn_weights @ values\n",
     "        return context_vec\n",
     "        return context_vec\n",
@@ -1064,7 +1065,7 @@
     }
    ],
    "source": [
-    "attn_weights = torch.softmax(attn_scores / d_out**0.5, dim=1)\n",
+    "attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
     "print(attn_weights)"
     "print(attn_weights)"
    ]
    ]
   },
   },
@@ -1236,7 +1237,7 @@
     }
    ],
    "source": [
-    "attn_weights = torch.softmax(masked / d_out**0.5, dim=1)\n",
+    "attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)\n",
     "print(attn_weights)"
     "print(attn_weights)"
    ]
    ]
   },
   },
@@ -1406,15 +1407,15 @@
     "        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
     "        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
     "\n",
     "\n",
     "    def forward(self, x):\n",
     "    def forward(self, x):\n",
-    "        b, n_tokens, d_in = x.shape # New batch dimension b\n",
+    "        b, num_tokens, d_in = x.shape # New batch dimension b\n",
     "        keys = self.W_key(x)\n",
     "        keys = self.W_key(x)\n",
     "        queries = self.W_query(x)\n",
     "        queries = self.W_query(x)\n",
     "        values = self.W_value(x)\n",
     "        values = self.W_value(x)\n",
     "\n",
     "\n",
     "        attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n",
     "        attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n",
     "        attn_scores.masked_fill_(  # New, _ ops are in-place\n",
     "        attn_scores.masked_fill_(  # New, _ ops are in-place\n",
-    "            self.mask.bool()[:n_tokens, :n_tokens], -torch.inf) \n",
-    "        attn_weights = torch.softmax(attn_scores / self.d_out**0.5, dim=1)\n",
+    "            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) \n",
+    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
     "        attn_weights = self.dropout(attn_weights) # New\n",
     "        attn_weights = self.dropout(attn_weights) # New\n",
     "\n",
     "\n",
     "        context_vec = attn_weights @ values\n",
     "        context_vec = attn_weights @ values\n",
@@ -1475,7 +1476,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 42,
+   "execution_count": 35,
    "id": "b9a66e11-7105-4bb4-be84-041f1a1f3bd2",
    "id": "b9a66e11-7105-4bb4-be84-041f1a1f3bd2",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -1658,7 +1659,7 @@
     "        # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
     "        # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
     "        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head\n",
     "        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head\n",
     "        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens].unsqueeze(0).unsqueeze(0), -torch.inf)\n",
     "        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens].unsqueeze(0).unsqueeze(0), -torch.inf)\n",
-    "        attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)\n",
+    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
     "        attn_weights = self.dropout(attn_weights)\n",
     "        attn_weights = self.dropout(attn_weights)\n",
     "\n",
     "\n",
     "        context_vec = (attn_weights @ values).transpose(1, 2) # Shape: (b, num_tokens, n_heads, head_dim)\n",
     "        context_vec = (attn_weights @ values).transpose(1, 2) # Shape: (b, num_tokens, n_heads, head_dim)\n",
@@ -1784,7 +1785,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 45,
+   "execution_count": 40,
    "id": "08c2a3fd-e674-4d69-9ef4-ea94b788e937",
    "id": "08c2a3fd-e674-4d69-9ef4-ea94b788e937",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -1794,7 +1795,7 @@
        "2360064"
        "2360064"
       ]
       ]
      },
      },
-     "execution_count": 45,
+     "execution_count": 40,
      "metadata": {},
      "metadata": {},
      "output_type": "execute_result"
      "output_type": "execute_result"
     }
     }

+ 2 - 2
ch03/01_main-chapter-code/exercise-solutions.ipynb

@@ -61,7 +61,7 @@
     "        values = x @ self.W_value\n",
     "        values = x @ self.W_value\n",
     "        \n",
     "        \n",
     "        attn_scores = queries @ keys.T # omega\n",
     "        attn_scores = queries @ keys.T # omega\n",
-    "        attn_weights = torch.softmax(attn_scores / self.d_out**0.5, dim=-1)\n",
+    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
     "\n",
     "\n",
     "        context_vec = attn_weights @ values\n",
     "        context_vec = attn_weights @ values\n",
     "        return context_vec\n",
     "        return context_vec\n",
@@ -92,7 +92,7 @@
     "        values = self.W_value(x)\n",
     "        values = self.W_value(x)\n",
     "        \n",
     "        \n",
     "        attn_scores = queries @ keys.T\n",
     "        attn_scores = queries @ keys.T\n",
-    "        attn_weights = torch.softmax(attn_scores / self.d_out**0.5, dim=1)\n",
+    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
     "\n",
     "\n",
     "        context_vec = attn_weights @ values\n",
     "        context_vec = attn_weights @ values\n",
     "        return context_vec\n",
     "        return context_vec\n",

+ 10 - 10
ch03/01_main-chapter-code/multihead-attention.ipynb

@@ -28,7 +28,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 1,
    "id": "0ed4b7db-3b47-4fd3-a4a6-5f4ed5dd166e",
    "id": "0ed4b7db-3b47-4fd3-a4a6-5f4ed5dd166e",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
@@ -95,7 +95,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 2,
    "id": "664397bc-6daa-4b88-90aa-e8fc1fbd5846",
    "id": "664397bc-6daa-4b88-90aa-e8fc1fbd5846",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
@@ -113,7 +113,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 3,
    "id": "d3664332-e6bb-447e-8b96-203aafde8b24",
    "id": "d3664332-e6bb-447e-8b96-203aafde8b24",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -147,7 +147,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 4,
    "id": "a44e682d-1c3c-445d-85fa-b142f89f8503",
    "id": "a44e682d-1c3c-445d-85fa-b142f89f8503",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
@@ -174,7 +174,7 @@
     "        attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n",
     "        attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n",
     "        attn_scores.masked_fill_(  # New, _ ops are in-place\n",
     "        attn_scores.masked_fill_(  # New, _ ops are in-place\n",
     "            self.mask.bool()[:n_tokens, :n_tokens], -torch.inf) \n",
     "            self.mask.bool()[:n_tokens, :n_tokens], -torch.inf) \n",
-    "        attn_weights = torch.softmax(attn_scores / self.d_out**0.5, dim=1)\n",
+    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
     "        attn_weights = self.dropout(attn_weights) # New\n",
     "        attn_weights = self.dropout(attn_weights) # New\n",
     "\n",
     "\n",
     "        context_vec = attn_weights @ values\n",
     "        context_vec = attn_weights @ values\n",
@@ -197,7 +197,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 5,
    "id": "7898551e-f582-48ac-9f66-3632abe2a93f",
    "id": "7898551e-f582-48ac-9f66-3632abe2a93f",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -236,7 +236,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 6,
    "id": "2773c09d-c136-4372-a2be-04b58d292842",
    "id": "2773c09d-c136-4372-a2be-04b58d292842",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
@@ -271,7 +271,7 @@
     "        # Compute scaled dot-product attention for each head\n",
     "        # Compute scaled dot-product attention for each head\n",
     "        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head\n",
     "        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head\n",
     "        attn_scores.masked_fill_(self.mask.bool()[:n_tokens, :n_tokens].unsqueeze(0).unsqueeze(0), -torch.inf)\n",
     "        attn_scores.masked_fill_(self.mask.bool()[:n_tokens, :n_tokens].unsqueeze(0).unsqueeze(0), -torch.inf)\n",
-    "        attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)\n",
+    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
     "        attn_weights = self.dropout(attn_weights)\n",
     "        attn_weights = self.dropout(attn_weights)\n",
     "        context_vec = (attn_weights @ values).transpose(1, 2) # Shape: (b, T, n_heads, head_dim)\n",
     "        context_vec = (attn_weights @ values).transpose(1, 2) # Shape: (b, T, n_heads, head_dim)\n",
     "        \n",
     "        \n",
@@ -284,7 +284,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 7,
    "id": "779fdd04-0152-4308-af08-840800a7f395",
    "id": "779fdd04-0152-4308-af08-840800a7f395",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -328,7 +328,7 @@
    "name": "python",
    "name": "python",
    "nbconvert_exporter": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "pygments_lexer": "ipython3",
-   "version": "3.11.4"
+   "version": "3.10.12"
   }
  },
  "nbformat": 4,