|
|
@@ -1633,7 +1633,7 @@
|
|
|
"\n",
|
|
|
" attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n",
|
|
|
" attn_scores.masked_fill_( # New, _ ops are in-place\n",
|
|
|
- " self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) \n",
|
|
|
+ "            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)  # slice the mask to `:num_tokens` because a batch may contain fewer tokens than the supported context_size\n",
|
|
|
" attn_weights = torch.softmax(\n",
|
|
|
" attn_scores / keys.shape[-1]**0.5, dim=-1\n",
|
|
|
" )\n",
|
|
|
@@ -2027,7 +2027,7 @@
|
|
|
"name": "python",
|
|
|
"nbconvert_exporter": "python",
|
|
|
"pygments_lexer": "ipython3",
|
|
|
- "version": "3.10.6"
|
|
|
+ "version": "3.11.4"
|
|
|
}
|
|
|
},
|
|
|
"nbformat": 4,
|