|
@@ -819,7 +819,7 @@
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"source": [
|
|
"source": [
|
|
|
"- Next, in **step 3**, we compute the attention weights (normalized attention scores that sum up to 1) using the softmax function we used earlier.\n",
|
|
"- Next, in **step 3**, we compute the attention weights (normalized attention scores that sum up to 1) using the softmax function we used earlier.\n",
|
|
|
- "- The difference to earlier is that we now scale the attention scores by dividing them by the square root of the embedding dimension, $\\sqrt{d}$ (i.e., `d_out**0.5`):"
|
|
|
|
|
|
|
+ "- The difference from earlier is that we now scale the attention scores by dividing them by the square root of the embedding dimension of the keys, $\\sqrt{d_k}$ (i.e., `d_k**0.5`):"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
@@ -837,7 +837,8 @@
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
|
"source": [
|
|
"source": [
|
|
|
- "attn_weights_2 = torch.softmax(attn_scores_2 / d_out**0.5, dim=-1)\n",
|
|
|
|
|
|
|
+ "d_k = keys.shape[1]\n",
|
|
|
|
|
+ "attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)\n",
|
|
|
"print(attn_weights_2)"
|
|
"print(attn_weights_2)"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
@@ -929,7 +930,7 @@
|
|
|
" values = x @ self.W_value\n",
|
|
" values = x @ self.W_value\n",
|
|
|
" \n",
|
|
" \n",
|
|
|
" attn_scores = queries @ keys.T # omega\n",
|
|
" attn_scores = queries @ keys.T # omega\n",
|
|
|
- " attn_weights = torch.softmax(attn_scores / self.d_out**0.5, dim=-1)\n",
|
|
|
|
|
|
|
+ " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
|
|
|
"\n",
|
|
"\n",
|
|
|
" context_vec = attn_weights @ values\n",
|
|
" context_vec = attn_weights @ values\n",
|
|
|
" return context_vec\n",
|
|
" return context_vec\n",
|
|
@@ -983,7 +984,7 @@
|
|
|
" values = self.W_value(x)\n",
|
|
" values = self.W_value(x)\n",
|
|
|
" \n",
|
|
" \n",
|
|
|
" attn_scores = queries @ keys.T\n",
|
|
" attn_scores = queries @ keys.T\n",
|
|
|
- " attn_weights = torch.softmax(attn_scores / self.d_out**0.5, dim=1)\n",
|
|
|
|
|
|
|
+ " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
|
|
|
"\n",
|
|
"\n",
|
|
|
" context_vec = attn_weights @ values\n",
|
|
" context_vec = attn_weights @ values\n",
|
|
|
" return context_vec\n",
|
|
" return context_vec\n",
|
|
@@ -1064,7 +1065,7 @@
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
|
"source": [
|
|
"source": [
|
|
|
- "attn_weights = torch.softmax(attn_scores / d_out**0.5, dim=1)\n",
|
|
|
|
|
|
|
+ "attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
|
|
|
"print(attn_weights)"
|
|
"print(attn_weights)"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
@@ -1236,7 +1237,7 @@
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
|
"source": [
|
|
"source": [
|
|
|
- "attn_weights = torch.softmax(masked / d_out**0.5, dim=1)\n",
|
|
|
|
|
|
|
+ "attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)\n",
|
|
|
"print(attn_weights)"
|
|
"print(attn_weights)"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
@@ -1406,15 +1407,15 @@
|
|
|
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
|
|
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
|
|
|
"\n",
|
|
"\n",
|
|
|
" def forward(self, x):\n",
|
|
" def forward(self, x):\n",
|
|
|
- " b, n_tokens, d_in = x.shape # New batch dimension b\n",
|
|
|
|
|
|
|
+ " b, num_tokens, d_in = x.shape # New batch dimension b\n",
|
|
|
" keys = self.W_key(x)\n",
|
|
" keys = self.W_key(x)\n",
|
|
|
" queries = self.W_query(x)\n",
|
|
" queries = self.W_query(x)\n",
|
|
|
" values = self.W_value(x)\n",
|
|
" values = self.W_value(x)\n",
|
|
|
"\n",
|
|
"\n",
|
|
|
" attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n",
|
|
" attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n",
|
|
|
" attn_scores.masked_fill_( # New, _ ops are in-place\n",
|
|
" attn_scores.masked_fill_( # New, _ ops are in-place\n",
|
|
|
- " self.mask.bool()[:n_tokens, :n_tokens], -torch.inf) \n",
|
|
|
|
|
- " attn_weights = torch.softmax(attn_scores / self.d_out**0.5, dim=1)\n",
|
|
|
|
|
|
|
+ " self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) \n",
|
|
|
|
|
+ " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
|
|
|
" attn_weights = self.dropout(attn_weights) # New\n",
|
|
" attn_weights = self.dropout(attn_weights) # New\n",
|
|
|
"\n",
|
|
"\n",
|
|
|
" context_vec = attn_weights @ values\n",
|
|
" context_vec = attn_weights @ values\n",
|
|
@@ -1475,7 +1476,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 42,
|
|
|
|
|
|
|
+ "execution_count": 35,
|
|
|
"id": "b9a66e11-7105-4bb4-be84-041f1a1f3bd2",
|
|
"id": "b9a66e11-7105-4bb4-be84-041f1a1f3bd2",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
@@ -1658,7 +1659,7 @@
|
|
|
" # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
|
|
" # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
|
|
|
" attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
|
|
" attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
|
|
|
" attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens].unsqueeze(0).unsqueeze(0), -torch.inf)\n",
|
|
" attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens].unsqueeze(0).unsqueeze(0), -torch.inf)\n",
|
|
|
- " attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)\n",
|
|
|
|
|
|
|
+ " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
|
|
|
" attn_weights = self.dropout(attn_weights)\n",
|
|
" attn_weights = self.dropout(attn_weights)\n",
|
|
|
"\n",
|
|
"\n",
|
|
|
" context_vec = (attn_weights @ values).transpose(1, 2) # Shape: (b, num_tokens, n_heads, head_dim)\n",
|
|
" context_vec = (attn_weights @ values).transpose(1, 2) # Shape: (b, num_tokens, n_heads, head_dim)\n",
|
|
@@ -1784,7 +1785,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 45,
|
|
|
|
|
|
|
+ "execution_count": 40,
|
|
|
"id": "08c2a3fd-e674-4d69-9ef4-ea94b788e937",
|
|
"id": "08c2a3fd-e674-4d69-9ef4-ea94b788e937",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
@@ -1794,7 +1795,7 @@
|
|
|
"2360064"
|
|
"2360064"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
- "execution_count": 45,
|
|
|
|
|
|
|
+ "execution_count": 40,
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"output_type": "execute_result"
|
|
"output_type": "execute_result"
|
|
|
}
|
|
}
|