Browse Source

Fix an incorrect input dimension

Kostyantyn Borysenko 1 year ago
parent
commit
76cdf5e299
1 changed files with 2 additions and 2 deletions
  1. 2 2
      ch03/01_main-chapter-code/multihead-attention.ipynb

+ 2 - 2
ch03/01_main-chapter-code/multihead-attention.ipynb

@@ -228,7 +228,7 @@
     "            [CausalSelfAttention(d_in, d_out, context_length, dropout, qkv_bias) \n",
     "            [CausalSelfAttention(d_in, d_out, context_length, dropout, qkv_bias) \n",
     "             for _ in range(num_heads)]\n",
     "             for _ in range(num_heads)]\n",
     "        )\n",
     "        )\n",
-    "        self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)\n",
+    "        self.out_proj = nn.Linear(d_in*num_heads, d_out*num_heads)\n",
     "\n",
     "\n",
     "    def forward(self, x):\n",
     "    def forward(self, x):\n",
     "        context_vec = torch.cat([head(x) for head in self.heads], dim=-1)\n",
     "        context_vec = torch.cat([head(x) for head in self.heads], dim=-1)\n",
@@ -383,7 +383,7 @@
    "name": "python",
    "name": "python",
    "nbconvert_exporter": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "pygments_lexer": "ipython3",
-   "version": "3.11.4"
+   "version": "3.12.3"
   }
   }
  },
  },
  "nbformat": 4,
  "nbformat": 4,