Browse Source

also add simple wrapper

rasbt 1 year ago
parent
commit
b6fe1a37b3

+ 1 - 1
ch03/01_main-chapter-code/ch03.ipynb

@@ -1865,7 +1865,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.12"
+   "version": "3.10.6"
   }
  },
  "nbformat": 4,

+ 41 - 0
ch03/02_bonus_efficient-multihead-attention/ch03.py

@@ -2,6 +2,47 @@ import torch
 import torch.nn as nn
 
 
+class CausalAttention(nn.Module):
+
+    def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):
+        super().__init__()
+        self.d_out = d_out
+        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
+        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
+        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
+        self.dropout = nn.Dropout(dropout) # New
+        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New
+
+    def forward(self, x):
+        b, num_tokens, d_in = x.shape # New batch dimension b
+        keys = self.W_key(x)
+        queries = self.W_query(x)
+        values = self.W_value(x)
+
+        attn_scores = queries @ keys.transpose(1, 2) # Changed transpose
+        attn_scores.masked_fill_(  # New, _ ops are in-place
+            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) 
+        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
+        attn_weights = self.dropout(attn_weights) # New
+
+        context_vec = attn_weights @ values
+        return context_vec
+    
+    
+class MultiHeadAttentionWrapper(nn.Module):
+
+    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
+        super().__init__()
+        self.heads = nn.ModuleList(
+            [CausalAttention(d_in, d_out, block_size, dropout, qkv_bias) 
+             for _ in range(num_heads)]
+        )
+
+    def forward(self, x):
+        return torch.cat([head(x) for head in self.heads], dim=-1)
+
+
+
 class MultiHeadAttention(nn.Module):
     def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
         super().__init__()

+ 62 - 14
ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb

@@ -13,7 +13,7 @@
    "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6",
    "metadata": {},
    "source": [
-    "## Multi-head attention implementation from chapter 3"
+    "## Multi-head attention implementations from chapter 3"
    ]
   },
   {
@@ -36,6 +36,36 @@
   {
    "cell_type": "code",
    "execution_count": 2,
+   "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "torch.Size([8, 1024, 9216])\n"
+     ]
+    }
+   ],
+   "source": [
+    "from ch03 import MultiHeadAttentionWrapper as Ch03_MHA_1\n",
+    "\n",
+    "mha_ch03_1 = Ch03_MHA_1(\n",
+    "    d_in=embed_dim,\n",
+    "    d_out=embed_dim,\n",
+    "    block_size=context_len,\n",
+    "    dropout=0.0,\n",
+    "    num_heads=12,\n",
+    "    qkv_bias=False\n",
+    ")\n",
+    "\n",
+    "out = mha_ch03_1(embeddings)\n",
+    "print(out.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
    "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
    "metadata": {},
    "outputs": [
@@ -48,9 +78,9 @@
     }
    ],
    "source": [
-    "from ch03 import MultiHeadAttention as Ch03_MHA\n",
+    "from ch03 import MultiHeadAttention as Ch03_MHA_2\n",
     "\n",
-    "mha_ch03 = Ch03_MHA(\n",
+    "mha_ch03_2 = Ch03_MHA_2(\n",
     "    d_in=embed_dim,\n",
     "    d_out=embed_dim,\n",
     "    block_size=context_len,\n",
@@ -59,7 +89,7 @@
     "    qkv_bias=False\n",
     ")\n",
     "\n",
-    "out = mha_ch03(embeddings)\n",
+    "out = mha_ch03_2(embeddings)\n",
     "print(out.shape)"
    ]
   },
@@ -89,7 +119,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 4,
    "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
    "metadata": {},
    "outputs": [
@@ -192,7 +222,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 5,
    "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5",
    "metadata": {},
    "outputs": [],
@@ -243,7 +273,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 6,
    "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
    "metadata": {},
    "outputs": [
@@ -279,7 +309,25 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 7,
+   "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "879 ms ± 4.01 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
+     ]
+    }
+   ],
+   "source": [
+    "%timeit mha_ch03_1(embeddings)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
    "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
    "metadata": {},
    "outputs": [
@@ -287,17 +335,17 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "253 ms ± 9.85 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
+      "259 ms ± 7.91 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
      ]
     }
    ],
    "source": [
-    "%timeit mha_ch03(embeddings)"
+    "%timeit mha_ch03_2(embeddings)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 9,
    "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
    "metadata": {},
    "outputs": [
@@ -305,7 +353,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "309 ms ± 26.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
+      "290 ms ± 2.58 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
      ]
     }
    ],
@@ -315,7 +363,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 10,
    "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
    "metadata": {},
    "outputs": [
@@ -323,7 +371,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "90.4 ms ± 719 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
+      "91.5 ms ± 1.04 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
      ]
     }
    ],