Răsfoiți Sursa

make consistent with the latest production version

rasbt 1 an în urmă
părinte
comite
3b57b6d8c4
1 a modificat fișierele cu 18 adăugiri și 8 ștergeri
  1. 18 8
      ch03/01_main-chapter-code/ch03.ipynb

+ 18 - 8
ch03/01_main-chapter-code/ch03.ipynb

@@ -1066,7 +1066,6 @@
     "\n",
     "    def __init__(self, d_in, d_out):\n",
     "        super().__init__()\n",
-    "        self.d_out = d_out\n",
     "        self.W_query = nn.Parameter(torch.rand(d_in, d_out))\n",
     "        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))\n",
     "        self.W_value = nn.Parameter(torch.rand(d_in, d_out))\n",
@@ -1077,7 +1076,9 @@
     "        values = x @ self.W_value\n",
     "        \n",
     "        attn_scores = queries @ keys.T # omega\n",
-    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
+    "        attn_weights = torch.softmax(\n",
+    "            attn_scores / keys.shape[-1]**0.5, dim=-1\n",
+    "        )\n",
     "\n",
     "        context_vec = attn_weights @ values\n",
     "        return context_vec\n",
@@ -1128,7 +1129,6 @@
     "\n",
     "    def __init__(self, d_in, d_out, qkv_bias=False):\n",
     "        super().__init__()\n",
-    "        self.d_out = d_out\n",
     "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
     "        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
     "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
@@ -1598,7 +1598,8 @@
    "source": [
     "class CausalAttention(nn.Module):\n",
     "\n",
-    "    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):\n",
+    "    def __init__(self, d_in, d_out, context_length,\n",
+    "                 dropout, qkv_bias=False):\n",
     "        super().__init__()\n",
     "        self.d_out = d_out\n",
     "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
@@ -1616,7 +1617,9 @@
     "        attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n",
     "        attn_scores.masked_fill_(  # New, _ ops are in-place\n",
     "            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) \n",
-    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
+    "        attn_weights = torch.softmax(\n",
+    "            attn_scores / keys.shape[-1]**0.5, dim=-1\n",
+    "        )\n",
     "        attn_weights = self.dropout(attn_weights) # New\n",
     "\n",
     "        context_vec = attn_weights @ values\n",
@@ -1728,7 +1731,9 @@
     "\n",
     "context_length = batch.shape[1] # This is the number of tokens\n",
     "d_in, d_out = 3, 2\n",
-    "mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)\n",
+    "mha = MultiHeadAttentionWrapper(\n",
+    "    d_in, d_out, context_length, 0.0, num_heads=2\n",
+    ")\n",
     "\n",
     "context_vecs = mha(batch)\n",
     "\n",
@@ -1794,7 +1799,8 @@
     "class MultiHeadAttention(nn.Module):\n",
     "    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
     "        super().__init__()\n",
-    "        assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
+    "        assert (d_out % num_heads == 0), \\\n",
+    "            \"d_out must be divisible by num_heads\"\n",
     "\n",
     "        self.d_out = d_out\n",
     "        self.num_heads = num_heads\n",
@@ -1805,7 +1811,11 @@
     "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
     "        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs\n",
     "        self.dropout = nn.Dropout(dropout)\n",
-    "        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
+    "        self.register_buffer(\n",
+    "            \"mask\",\n",
+    "            torch.triu(torch.ones(context_length, context_length),\n",
+    "                       diagonal=1)\n",
+    "        )\n",
     "\n",
     "    def forward(self, x):\n",
     "        b, num_tokens, d_in = x.shape\n",