|
|
@@ -159,7 +159,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 1,
|
|
|
+ "execution_count": 2,
|
|
|
"id": "22b9556a-aaf8-4ab4-a5b4-973372b0b2c3",
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
@@ -187,7 +187,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 2,
|
|
|
+ "execution_count": 3,
|
|
|
"id": "6fb5b2f8-dd2c-4a6d-94ef-a0e9ad163951",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -219,7 +219,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 3,
|
|
|
+ "execution_count": 4,
|
|
|
"id": "9842f39b-1654-410e-88bf-d1b899bf0241",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -253,7 +253,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 4,
|
|
|
+ "execution_count": 5,
|
|
|
"id": "e3ccc99c-33ce-4f11-b7f2-353cf1cbdaba",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -284,7 +284,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 5,
|
|
|
+ "execution_count": 6,
|
|
|
"id": "07b2e58d-a6ed-49f0-a1cd-2463e8d53a20",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -318,7 +318,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 6,
|
|
|
+ "execution_count": 7,
|
|
|
"id": "2d99cac4-45ea-46b3-b3c1-e000ad16e158",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -348,7 +348,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 7,
|
|
|
+ "execution_count": 8,
|
|
|
"id": "8fcb96f0-14e5-4973-a50e-79ea7c6af99f",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -407,7 +407,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 8,
|
|
|
+ "execution_count": 9,
|
|
|
"id": "04004be8-07a1-468b-ab33-32e16a551b45",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -444,7 +444,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 9,
|
|
|
+ "execution_count": 10,
|
|
|
"id": "2cea69d0-9a47-45da-8d5a-47ceef2df673",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -476,7 +476,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 10,
|
|
|
+ "execution_count": 11,
|
|
|
"id": "fa4ef062-de81-47ee-8415-bfe1708c81b8",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -508,7 +508,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 11,
|
|
|
+ "execution_count": 12,
|
|
|
"id": "112b492c-fb6f-4e6d-8df5-518ae83363d5",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -538,7 +538,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 12,
|
|
|
+ "execution_count": 13,
|
|
|
"id": "ba8eafcf-f7f7-4989-b8dc-61b50c4f81dc",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -570,7 +570,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 13,
|
|
|
+ "execution_count": 14,
|
|
|
"id": "2570eb7d-aee1-457a-a61e-7544478219fa",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -649,7 +649,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 14,
|
|
|
+ "execution_count": 15,
|
|
|
"id": "8250fdc6-6cd6-4c5b-b9c0-8c643aadb7db",
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
@@ -669,7 +669,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 15,
|
|
|
+ "execution_count": 16,
|
|
|
"id": "bfd7259a-f26c-4cea-b8fc-282b5cae1e00",
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
@@ -691,7 +691,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 16,
|
|
|
+ "execution_count": 17,
|
|
|
"id": "73cedd62-01e1-4196-a575-baecc6095601",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -721,7 +721,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 17,
|
|
|
+ "execution_count": 18,
|
|
|
"id": "8c1c3949-fc08-4d19-a41e-1c235b4e631b",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -760,7 +760,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 18,
|
|
|
+ "execution_count": 19,
|
|
|
"id": "64cbc253-a182-4490-a765-246979ea0a28",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -788,7 +788,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 19,
|
|
|
+ "execution_count": 20,
|
|
|
"id": "b14e44b5-d170-40f9-8847-8990804af26d",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -824,7 +824,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 20,
|
|
|
+ "execution_count": 21,
|
|
|
"id": "146f5587-c845-4e30-9894-c7ed3a248153",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -859,7 +859,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 21,
|
|
|
+ "execution_count": 22,
|
|
|
"id": "e138f033-fa7e-4e3a-8764-b53a96b26397",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -894,7 +894,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 22,
|
|
|
+ "execution_count": 23,
|
|
|
"id": "51590326-cdbe-4e62-93b1-17df71c11ee4",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -950,7 +950,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 23,
|
|
|
+ "execution_count": 24,
|
|
|
"id": "73f411e3-e231-464a-89fe-0a9035e5f839",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -1046,7 +1046,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 24,
|
|
|
+ "execution_count": 25,
|
|
|
"id": "1933940d-0fa5-4b17-a3ce-388e5314a1bb",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -1078,7 +1078,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 25,
|
|
|
+ "execution_count": 28,
|
|
|
"id": "43f3d2e3-185b-4184-9f98-edde5e6df746",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -1097,8 +1097,8 @@
|
|
|
],
|
|
|
"source": [
|
|
|
"block_size = attn_scores.shape[0]\n",
|
|
|
- "mask_naive = torch.tril(torch.ones(block_size, block_size))\n",
|
|
|
- "print(mask_naive)"
|
|
|
+ "mask_simple = torch.tril(torch.ones(block_size, block_size))\n",
|
|
|
+ "print(mask_simple)"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
@@ -1111,7 +1111,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 26,
|
|
|
+ "execution_count": 29,
|
|
|
"id": "9f531e2e-f4d2-4fea-a87f-4c132e48b9e7",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -1129,8 +1129,8 @@
|
|
|
}
|
|
|
],
|
|
|
"source": [
|
|
|
- "masked_naive = attn_weights*mask_naive\n",
|
|
|
- "print(masked_naive)"
|
|
|
+ "masked_simple = attn_weights*mask_simple\n",
|
|
|
+ "print(masked_simple)"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
@@ -1141,12 +1141,46 @@
|
|
|
"- However, if the mask were applied after softmax, like above, it would disrupt the probability distribution created by softmax. Softmax ensures that all output values sum to 1. Masking after softmax would require re-normalizing the outputs to sum to 1 again, which complicates the process and might lead to unintended effects."
|
|
|
]
|
|
|
},
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "94db92d7-c397-4e42-bd8a-6a2b3e237e0f",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "- To make sure that the rows sum to 1, we can normalize the attention weights as follows:"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 30,
|
|
|
+ "id": "6d392083-fd81-4f70-9bdf-8db985e673d6",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
|
|
|
+ " [0.4056, 0.5944, 0.0000, 0.0000, 0.0000, 0.0000],\n",
|
|
|
+ " [0.2566, 0.3741, 0.3693, 0.0000, 0.0000, 0.0000],\n",
|
|
|
+ " [0.2176, 0.2823, 0.2796, 0.2205, 0.0000, 0.0000],\n",
|
|
|
+ " [0.1826, 0.2178, 0.2191, 0.1689, 0.2115, 0.0000],\n",
|
|
|
+ " [0.1473, 0.2033, 0.1996, 0.1500, 0.1160, 0.1839]])\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "row_sums = masked_simple.sum(dim=1, keepdim=True)\n",
|
|
|
+ "masked_simple_norm = masked_simple / row_sums\n",
|
|
|
+ "print(masked_simple_norm)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
{
|
|
|
"cell_type": "markdown",
|
|
|
"id": "512e7cf4-dc0e-4cec-948e-c7a3c4eb6877",
|
|
|
"metadata": {},
|
|
|
"source": [
|
|
|
- "- So, instead, we take a different approach, masking elements with negative infinity before they enter the softmax function:"
|
|
|
+ "- While we are technically done with coding the causal attention mechanism now, let's briefly look at a more efficient way to achieve the same result.\n",
|
|
|
+ "- Instead of zeroing out the attention weights above the diagonal and re-normalizing the result, we can mask the unnormalized attention scores above the diagonal with negative infinity before they enter the softmax function:"
|
|
|
]
|
|
|
},
|
|
|
{
|