|
@@ -548,7 +548,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 17,
|
|
|
|
|
|
|
+ "execution_count": 16,
|
|
|
"id": "57c3143b-e860-4d3b-a22a-de22b547a6a9",
|
|
"id": "57c3143b-e860-4d3b-a22a-de22b547a6a9",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
@@ -558,7 +558,7 @@
|
|
|
"1161"
|
|
"1161"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
- "execution_count": 17,
|
|
|
|
|
|
|
+ "execution_count": 16,
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"output_type": "execute_result"
|
|
"output_type": "execute_result"
|
|
|
}
|
|
}
|
|
@@ -569,7 +569,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 19,
|
|
|
|
|
|
|
+ "execution_count": 17,
|
|
|
"id": "50e51bb1-ae05-4aa8-a9ff-455b65ed1959",
|
|
"id": "50e51bb1-ae05-4aa8-a9ff-455b65ed1959",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
@@ -600,7 +600,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 17,
|
|
|
|
|
|
|
+ "execution_count": 18,
|
|
|
"id": "948861c5-3f30-4712-a234-725f20d26f68",
|
|
"id": "948861c5-3f30-4712-a234-725f20d26f68",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
"outputs": [],
|
|
@@ -636,32 +636,68 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 18,
|
|
|
|
|
|
|
+ "execution_count": 19,
|
|
|
"id": "effcef79-e0a5-4f4a-a43a-31dd94b9250a",
|
|
"id": "effcef79-e0a5-4f4a-a43a-31dd94b9250a",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
|
|
+ "outputs": [
|
|
|
|
|
+ {
|
|
|
|
|
+ "name": "stdout",
|
|
|
|
|
+ "output_type": "stream",
|
|
|
|
|
+ "text": [
|
|
|
|
|
+ "Hello, do you like tea? <|endoftext|> In the sunlit terraces of the palace.\n"
|
|
|
|
|
+ ]
|
|
|
|
|
+ }
|
|
|
|
|
+ ],
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "tokenizer = SimpleTokenizerV2(vocab)\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "text1 = \"Hello, do you like tea?\"\n",
|
|
|
|
|
+ "text2 = \"In the sunlit terraces of the palace.\"\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "text = \" <|endoftext|> \".join((text1, text2))\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "print(text)"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "code",
|
|
|
|
|
+ "execution_count": 20,
|
|
|
|
|
+ "id": "ddfe7346-398d-4bf8-99f1-5b071244ce95",
|
|
|
|
|
+ "metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
|
{
|
|
{
|
|
|
"data": {
|
|
"data": {
|
|
|
"text/plain": [
|
|
"text/plain": [
|
|
|
- "[1, 7, 364, 1157, 644, 1002, 12, 0, 59, 1015, 983, 1011, 740, 1015, 1, 9]"
|
|
|
|
|
|
|
+ "[1160,\n",
|
|
|
|
|
+ " 5,\n",
|
|
|
|
|
+ " 362,\n",
|
|
|
|
|
+ " 1155,\n",
|
|
|
|
|
+ " 642,\n",
|
|
|
|
|
+ " 1000,\n",
|
|
|
|
|
+ " 10,\n",
|
|
|
|
|
+ " 1159,\n",
|
|
|
|
|
+ " 57,\n",
|
|
|
|
|
+ " 1013,\n",
|
|
|
|
|
+ " 981,\n",
|
|
|
|
|
+ " 1009,\n",
|
|
|
|
|
+ " 738,\n",
|
|
|
|
|
+ " 1013,\n",
|
|
|
|
|
+ " 1160,\n",
|
|
|
|
|
+ " 7]"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
- "execution_count": 18,
|
|
|
|
|
|
|
+ "execution_count": 20,
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"output_type": "execute_result"
|
|
"output_type": "execute_result"
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
|
"source": [
|
|
"source": [
|
|
|
- "tokenizer = SimpleTokenizerV2(vocab)\n",
|
|
|
|
|
- "\n",
|
|
|
|
|
- "text = \"Hello, do you like tea? <|endoftext|> In the sunlit terraces of the palace.\"\n",
|
|
|
|
|
- "\n",
|
|
|
|
|
"tokenizer.encode(text)"
|
|
"tokenizer.encode(text)"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 19,
|
|
|
|
|
|
|
+ "execution_count": 21,
|
|
|
"id": "0c350ff6-2734-4e84-9ec7-d578baa4ae1b",
|
|
"id": "0c350ff6-2734-4e84-9ec7-d578baa4ae1b",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
@@ -671,7 +707,7 @@
|
|
|
"'<|unk|>, do you like tea? <|endoftext|> In the sunlit terraces of the <|unk|>.'"
|
|
"'<|unk|>, do you like tea? <|endoftext|> In the sunlit terraces of the <|unk|>.'"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
- "execution_count": 19,
|
|
|
|
|
|
|
+ "execution_count": 21,
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"output_type": "execute_result"
|
|
"output_type": "execute_result"
|
|
|
}
|
|
}
|
|
@@ -703,7 +739,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 20,
|
|
|
|
|
|
|
+ "execution_count": 22,
|
|
|
"id": "ede1d41f-934b-4bf4-8184-54394a257a94",
|
|
"id": "ede1d41f-934b-4bf4-8184-54394a257a94",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
"outputs": [],
|
|
@@ -713,7 +749,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 21,
|
|
|
|
|
|
|
+ "execution_count": 23,
|
|
|
"id": "48967a77-7d17-42bf-9e92-fc619d63a59e",
|
|
"id": "48967a77-7d17-42bf-9e92-fc619d63a59e",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
@@ -734,7 +770,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 22,
|
|
|
|
|
|
|
+ "execution_count": 24,
|
|
|
"id": "6ad3312f-a5f7-4efc-9d7d-8ea09d7b5128",
|
|
"id": "6ad3312f-a5f7-4efc-9d7d-8ea09d7b5128",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
"outputs": [],
|
|
@@ -744,7 +780,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 23,
|
|
|
|
|
|
|
+ "execution_count": 25,
|
|
|
"id": "5ff2cd85-7cfb-4325-b390-219938589428",
|
|
"id": "5ff2cd85-7cfb-4325-b390-219938589428",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
@@ -766,7 +802,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 24,
|
|
|
|
|
|
|
+ "execution_count": 26,
|
|
|
"id": "d26a48bb-f82e-41a8-a955-a1c9cf9d50ab",
|
|
"id": "d26a48bb-f82e-41a8-a955-a1c9cf9d50ab",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
@@ -784,6 +820,76 @@
|
|
|
"print(strings)"
|
|
"print(strings)"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
|
+ "id": "f63d62ab-4b80-489c-8041-e4052fe29969",
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "- Experiments with unknown words:"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "code",
|
|
|
|
|
+ "execution_count": 27,
|
|
|
|
|
+ "id": "ce25cf25-a2bb-44d2-bac1-cb566f433f98",
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "outputs": [
|
|
|
|
|
+ {
|
|
|
|
|
+ "name": "stdout",
|
|
|
|
|
+ "output_type": "stream",
|
|
|
|
|
+ "text": [
|
|
|
|
|
+ "[33901, 86, 343, 86, 220, 959]\n"
|
|
|
|
|
+ ]
|
|
|
|
|
+ }
|
|
|
|
|
+ ],
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "integers = tokenizer.encode(\"Akwirw ier\")\n",
|
|
|
|
|
+ "print(integers)"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "code",
|
|
|
|
|
+ "execution_count": 28,
|
|
|
|
|
+ "id": "3e224f96-41d0-4074-ac6e-f7db2490f806",
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "outputs": [
|
|
|
|
|
+ {
|
|
|
|
|
+ "name": "stdout",
|
|
|
|
|
+ "output_type": "stream",
|
|
|
|
|
+ "text": [
|
|
|
|
|
+ "33901 -> Ak\n",
|
|
|
|
|
+ "86 -> w\n",
|
|
|
|
|
+ "343 -> ir\n",
|
|
|
|
|
+ "86 -> w\n",
|
|
|
|
|
+ "220 -> \n",
|
|
|
|
|
+ "959 -> ier\n"
|
|
|
|
|
+ ]
|
|
|
|
|
+ }
|
|
|
|
|
+ ],
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "for i in integers:\n",
|
|
|
|
|
+ " print(f\"{i} -> {tokenizer.decode([i])}\")"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "code",
|
|
|
|
|
+ "execution_count": 29,
|
|
|
|
|
+ "id": "766bcf29-64bf-47ca-9b65-4ae8e607d580",
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "outputs": [
|
|
|
|
|
+ {
|
|
|
|
|
+ "name": "stdout",
|
|
|
|
|
+ "output_type": "stream",
|
|
|
|
|
+ "text": [
|
|
|
|
|
+ "Akwirw ier\n"
|
|
|
|
|
+ ]
|
|
|
|
|
+ }
|
|
|
|
|
+ ],
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "strings = tokenizer.decode(integers)\n",
|
|
|
|
|
+ "print(strings)"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
{
|
|
{
|
|
|
"cell_type": "markdown",
|
|
"cell_type": "markdown",
|
|
|
"id": "abbd7c0d-70f8-4386-a114-907e96c950b0",
|
|
"id": "abbd7c0d-70f8-4386-a114-907e96c950b0",
|
|
@@ -794,7 +900,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 25,
|
|
|
|
|
|
|
+ "execution_count": 30,
|
|
|
"id": "848d5ade-fd1f-46c3-9e31-1426e315c71b",
|
|
"id": "848d5ade-fd1f-46c3-9e31-1426e315c71b",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
@@ -807,10 +913,10 @@
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
|
"source": [
|
|
"source": [
|
|
|
- "with open('the-verdict.txt', 'r', encoding='utf-8') as f:\n",
|
|
|
|
|
|
|
+ "with open(\"the-verdict.txt\", \"r\", encoding=\"utf-8\") as f:\n",
|
|
|
" raw_text = f.read()\n",
|
|
" raw_text = f.read()\n",
|
|
|
"\n",
|
|
"\n",
|
|
|
- "enc_text = tokenizer.encode(raw_text, allowed_special={\"<|endoftext|>\"})\n",
|
|
|
|
|
|
|
+ "enc_text = tokenizer.encode(raw_text)\n",
|
|
|
"print(len(enc_text))"
|
|
"print(len(enc_text))"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
@@ -825,7 +931,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 26,
|
|
|
|
|
|
|
+ "execution_count": 31,
|
|
|
"id": "e84424a7-646d-45b6-99e3-80d15fb761f2",
|
|
"id": "e84424a7-646d-45b6-99e3-80d15fb761f2",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
"outputs": [],
|
|
@@ -835,7 +941,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 27,
|
|
|
|
|
|
|
+ "execution_count": 32,
|
|
|
"id": "dfbff852-a92f-48c8-a46d-143a0f109f40",
|
|
"id": "dfbff852-a92f-48c8-a46d-143a0f109f40",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
@@ -868,7 +974,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 28,
|
|
|
|
|
|
|
+ "execution_count": 33,
|
|
|
"id": "d97b031e-ed55-409d-95f2-aeb38c6fe366",
|
|
"id": "d97b031e-ed55-409d-95f2-aeb38c6fe366",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
@@ -878,12 +984,13 @@
|
|
|
"text": [
|
|
"text": [
|
|
|
"[290] ----> 4920\n",
|
|
"[290] ----> 4920\n",
|
|
|
"[290, 4920] ----> 2241\n",
|
|
"[290, 4920] ----> 2241\n",
|
|
|
- "[290, 4920, 2241] ----> 287\n"
|
|
|
|
|
|
|
+ "[290, 4920, 2241] ----> 287\n",
|
|
|
|
|
+ "[290, 4920, 2241, 287] ----> 257\n"
|
|
|
]
|
|
]
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
|
"source": [
|
|
"source": [
|
|
|
- "for i in range(1, context_size):\n",
|
|
|
|
|
|
|
+ "for i in range(1, context_size+1):\n",
|
|
|
" context = enc_sample[:i]\n",
|
|
" context = enc_sample[:i]\n",
|
|
|
" desired = enc_sample[i]\n",
|
|
" desired = enc_sample[i]\n",
|
|
|
"\n",
|
|
"\n",
|
|
@@ -892,7 +999,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 29,
|
|
|
|
|
|
|
+ "execution_count": 34,
|
|
|
"id": "f57bd746-dcbf-4433-8e24-ee213a8c34a1",
|
|
"id": "f57bd746-dcbf-4433-8e24-ee213a8c34a1",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
@@ -902,12 +1009,13 @@
|
|
|
"text": [
|
|
"text": [
|
|
|
" and ----> established\n",
|
|
" and ----> established\n",
|
|
|
" and established ----> himself\n",
|
|
" and established ----> himself\n",
|
|
|
- " and established himself ----> in\n"
|
|
|
|
|
|
|
+ " and established himself ----> in\n",
|
|
|
|
|
+ " and established himself in ----> a\n"
|
|
|
]
|
|
]
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
|
"source": [
|
|
"source": [
|
|
|
- "for i in range(1, context_size):\n",
|
|
|
|
|
|
|
+ "for i in range(1, context_size+1):\n",
|
|
|
" context = enc_sample[:i]\n",
|
|
" context = enc_sample[:i]\n",
|
|
|
" desired = enc_sample[i]\n",
|
|
" desired = enc_sample[i]\n",
|
|
|
"\n",
|
|
"\n",
|
|
@@ -933,7 +1041,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 30,
|
|
|
|
|
|
|
+ "execution_count": 35,
|
|
|
"id": "e1770134-e7f3-4725-a679-e04c3be48cac",
|
|
"id": "e1770134-e7f3-4725-a679-e04c3be48cac",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
@@ -941,7 +1049,7 @@
|
|
|
"name": "stdout",
|
|
"name": "stdout",
|
|
|
"output_type": "stream",
|
|
"output_type": "stream",
|
|
|
"text": [
|
|
"text": [
|
|
|
- "PyTorch version: 2.0.1\n"
|
|
|
|
|
|
|
+ "PyTorch version: 2.1.0\n"
|
|
|
]
|
|
]
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
@@ -960,7 +1068,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 31,
|
|
|
|
|
|
|
+ "execution_count": 36,
|
|
|
"id": "74b41073-4c9f-46e2-a1bd-d38e4122b375",
|
|
"id": "74b41073-4c9f-46e2-a1bd-d38e4122b375",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
"outputs": [],
|
|
@@ -993,7 +1101,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 32,
|
|
|
|
|
|
|
+ "execution_count": 37,
|
|
|
"id": "5eb30ebe-97b3-43c5-9ff1-a97d621b3c4e",
|
|
"id": "5eb30ebe-97b3-43c5-9ff1-a97d621b3c4e",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
"outputs": [],
|
|
@@ -1021,18 +1129,18 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 33,
|
|
|
|
|
|
|
+ "execution_count": 38,
|
|
|
"id": "df31d96c-6bfd-4564-a956-6192242d7579",
|
|
"id": "df31d96c-6bfd-4564-a956-6192242d7579",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
"outputs": [],
|
|
|
"source": [
|
|
"source": [
|
|
|
- "with open('the-verdict.txt', 'r', encoding='utf-8') as f:\n",
|
|
|
|
|
|
|
+ "with open(\"the-verdict.txt\", \"r\", encoding=\"utf-8\") as f:\n",
|
|
|
" raw_text = f.read()"
|
|
" raw_text = f.read()"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 34,
|
|
|
|
|
|
|
+ "execution_count": 39,
|
|
|
"id": "9226d00c-ad9a-4949-a6e4-9afccfc7214f",
|
|
"id": "9226d00c-ad9a-4949-a6e4-9afccfc7214f",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
@@ -1048,13 +1156,13 @@
|
|
|
"dataloader = create_dataloader(raw_text, batch_size=1, max_length=4, stride=1)\n",
|
|
"dataloader = create_dataloader(raw_text, batch_size=1, max_length=4, stride=1)\n",
|
|
|
"\n",
|
|
"\n",
|
|
|
"data_iter = iter(dataloader)\n",
|
|
"data_iter = iter(dataloader)\n",
|
|
|
- "next_batch = next(data_iter)\n",
|
|
|
|
|
- "print(next_batch)"
|
|
|
|
|
|
|
+ "first_batch = next(data_iter)\n",
|
|
|
|
|
+ "print(first_batch)"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 35,
|
|
|
|
|
|
|
+ "execution_count": 40,
|
|
|
"id": "10deb4bc-4de1-4d20-921e-4b1c7a0e1a6d",
|
|
"id": "10deb4bc-4de1-4d20-921e-4b1c7a0e1a6d",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
@@ -1067,8 +1175,8 @@
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
|
"source": [
|
|
"source": [
|
|
|
- "next_batch = next(data_iter)\n",
|
|
|
|
|
- "print(next_batch)"
|
|
|
|
|
|
|
+ "second_batch = next(data_iter)\n",
|
|
|
|
|
+ "print(second_batch)"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
@@ -1077,12 +1185,12 @@
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"source": [
|
|
"source": [
|
|
|
"- We can also create batched outputs\n",
|
|
"- We can also create batched outputs\n",
|
|
|
- "- Note that we increase the stride here so that we don't have overlaps between the batches, which could lead to increased overfitting"
|
|
|
|
|
|
|
+ "- Note that we increase the stride here so that we don't have overlaps between the batches, since more overlap could lead to increased overfitting"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 36,
|
|
|
|
|
|
|
+ "execution_count": 41,
|
|
|
"id": "1916e7a6-f03d-4f09-91a6-d0bdbac5a58c",
|
|
"id": "1916e7a6-f03d-4f09-91a6-d0bdbac5a58c",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
@@ -1149,7 +1257,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 37,
|
|
|
|
|
|
|
+ "execution_count": 42,
|
|
|
"id": "15a6304c-9474-4470-b85d-3991a49fa653",
|
|
"id": "15a6304c-9474-4470-b85d-3991a49fa653",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
"outputs": [],
|
|
@@ -1167,7 +1275,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 38,
|
|
|
|
|
|
|
+ "execution_count": 43,
|
|
|
"id": "93cb2cee-9aa6-4bb8-8977-c65661d16eda",
|
|
"id": "93cb2cee-9aa6-4bb8-8977-c65661d16eda",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
"outputs": [],
|
|
@@ -1189,29 +1297,26 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 39,
|
|
|
|
|
|
|
+ "execution_count": 44,
|
|
|
"id": "a686eb61-e737-4351-8f1c-222913d47468",
|
|
"id": "a686eb61-e737-4351-8f1c-222913d47468",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
|
{
|
|
{
|
|
|
- "data": {
|
|
|
|
|
- "text/plain": [
|
|
|
|
|
- "Parameter containing:\n",
|
|
|
|
|
- "tensor([[ 0.3374, -0.1778, -0.1690],\n",
|
|
|
|
|
- " [ 0.9178, 1.5810, 1.3010],\n",
|
|
|
|
|
- " [ 1.2753, -0.2010, -0.1606],\n",
|
|
|
|
|
- " [-0.4015, 0.9666, -1.1481],\n",
|
|
|
|
|
- " [-1.1589, 0.3255, -0.6315],\n",
|
|
|
|
|
- " [-2.8400, -0.7849, -1.4096]], requires_grad=True)"
|
|
|
|
|
- ]
|
|
|
|
|
- },
|
|
|
|
|
- "execution_count": 39,
|
|
|
|
|
- "metadata": {},
|
|
|
|
|
- "output_type": "execute_result"
|
|
|
|
|
|
|
+ "name": "stdout",
|
|
|
|
|
+ "output_type": "stream",
|
|
|
|
|
+ "text": [
|
|
|
|
|
+ "Parameter containing:\n",
|
|
|
|
|
+ "tensor([[ 0.3374, -0.1778, -0.1690],\n",
|
|
|
|
|
+ " [ 0.9178, 1.5810, 1.3010],\n",
|
|
|
|
|
+ " [ 1.2753, -0.2010, -0.1606],\n",
|
|
|
|
|
+ " [-0.4015, 0.9666, -1.1481],\n",
|
|
|
|
|
+ " [-1.1589, 0.3255, -0.6315],\n",
|
|
|
|
|
+ " [-2.8400, -0.7849, -1.4096]], requires_grad=True)\n"
|
|
|
|
|
+ ]
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
|
"source": [
|
|
"source": [
|
|
|
- "embedding_layer.weight"
|
|
|
|
|
|
|
+ "print(embedding_layer.weight)"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
@@ -1233,23 +1338,20 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 40,
|
|
|
|
|
|
|
+ "execution_count": 45,
|
|
|
"id": "e43600ba-f287-4746-8ddf-d0f71a9023ca",
|
|
"id": "e43600ba-f287-4746-8ddf-d0f71a9023ca",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
|
{
|
|
{
|
|
|
- "data": {
|
|
|
|
|
- "text/plain": [
|
|
|
|
|
- "tensor([[-0.4015, 0.9666, -1.1481]], grad_fn=<EmbeddingBackward0>)"
|
|
|
|
|
- ]
|
|
|
|
|
- },
|
|
|
|
|
- "execution_count": 40,
|
|
|
|
|
- "metadata": {},
|
|
|
|
|
- "output_type": "execute_result"
|
|
|
|
|
|
|
+ "name": "stdout",
|
|
|
|
|
+ "output_type": "stream",
|
|
|
|
|
+ "text": [
|
|
|
|
|
+ "tensor([[-0.4015, 0.9666, -1.1481]], grad_fn=<EmbeddingBackward0>)\n"
|
|
|
|
|
+ ]
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
|
"source": [
|
|
"source": [
|
|
|
- "embedding_layer(torch.tensor([3]))"
|
|
|
|
|
|
|
+ "print(embedding_layer(torch.tensor([3])))"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
@@ -1263,47 +1365,23 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 41,
|
|
|
|
|
|
|
+ "execution_count": 46,
|
|
|
"id": "50280ead-0363-44c8-8c35-bb885d92c8b7",
|
|
"id": "50280ead-0363-44c8-8c35-bb885d92c8b7",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
|
{
|
|
{
|
|
|
- "data": {
|
|
|
|
|
- "text/plain": [
|
|
|
|
|
- "tensor([[-2.8400, -0.7849, -1.4096],\n",
|
|
|
|
|
- " [ 0.9178, 1.5810, 1.3010],\n",
|
|
|
|
|
- " [-0.4015, 0.9666, -1.1481],\n",
|
|
|
|
|
- " [ 1.2753, -0.2010, -0.1606]], grad_fn=<EmbeddingBackward0>)"
|
|
|
|
|
- ]
|
|
|
|
|
- },
|
|
|
|
|
- "execution_count": 41,
|
|
|
|
|
- "metadata": {},
|
|
|
|
|
- "output_type": "execute_result"
|
|
|
|
|
|
|
+ "name": "stdout",
|
|
|
|
|
+ "output_type": "stream",
|
|
|
|
|
+ "text": [
|
|
|
|
|
+ "tensor([[-2.8400, -0.7849, -1.4096],\n",
|
|
|
|
|
+ " [ 0.9178, 1.5810, 1.3010],\n",
|
|
|
|
|
+ " [-0.4015, 0.9666, -1.1481],\n",
|
|
|
|
|
+ " [ 1.2753, -0.2010, -0.1606]], grad_fn=<EmbeddingBackward0>)\n"
|
|
|
|
|
+ ]
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
|
"source": [
|
|
"source": [
|
|
|
- "embedding_layer(input_ids)"
|
|
|
|
|
- ]
|
|
|
|
|
- },
|
|
|
|
|
- {
|
|
|
|
|
- "cell_type": "markdown",
|
|
|
|
|
- "id": "53f452c4-5fcb-4528-8fda-fd1a16f26bc7",
|
|
|
|
|
- "metadata": {},
|
|
|
|
|
- "source": [
|
|
|
|
|
- "- The BytePair encoder has a vocabulary size of 50,257:"
|
|
|
|
|
- ]
|
|
|
|
|
- },
|
|
|
|
|
- {
|
|
|
|
|
- "cell_type": "code",
|
|
|
|
|
- "execution_count": 42,
|
|
|
|
|
- "id": "91c1f77f-cb0c-4f72-a258-ec9bab2bc755",
|
|
|
|
|
- "metadata": {},
|
|
|
|
|
- "outputs": [],
|
|
|
|
|
- "source": [
|
|
|
|
|
- "vocab_size = 50257\n",
|
|
|
|
|
- "output_dim = 256\n",
|
|
|
|
|
- "\n",
|
|
|
|
|
- "token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)"
|
|
|
|
|
|
|
+ "print(embedding_layer(input_ids))"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
@@ -1319,12 +1397,13 @@
|
|
|
"id": "7f187f87-c1f8-4c2e-8050-350bbb972f55",
|
|
"id": "7f187f87-c1f8-4c2e-8050-350bbb972f55",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"source": [
|
|
"source": [
|
|
|
|
|
+ "- The BytePair encoder has a vocabulary size of 50,257:\n",
|
|
|
"- Suppose we want to encode the input tokens into a 256-dimensional vector representation:"
|
|
"- Suppose we want to encode the input tokens into a 256-dimensional vector representation:"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 43,
|
|
|
|
|
|
|
+ "execution_count": 48,
|
|
|
"id": "0b9e344d-03a6-4f2c-b723-67b6a20c5041",
|
|
"id": "0b9e344d-03a6-4f2c-b723-67b6a20c5041",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
"outputs": [],
|
|
@@ -1340,42 +1419,70 @@
|
|
|
"id": "a2654722-24e4-4b0d-a43c-436a461eb70b",
|
|
"id": "a2654722-24e4-4b0d-a43c-436a461eb70b",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"source": [
|
|
"source": [
|
|
|
- "- If we sample data from teh dataloader, we embed the tokens in each batch into a 256-dimensional vector\n",
|
|
|
|
|
|
|
+ "- If we sample data from the dataloader, we embed the tokens in each batch into a 256-dimensional vector\n",
|
|
|
"- If we have a batch size of 8 with 4 tokens each, this results in a 8 x 4 x 256 tensor:"
|
|
"- If we have a batch size of 8 with 4 tokens each, this results in a 8 x 4 x 256 tensor:"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 44,
|
|
|
|
|
|
|
+ "execution_count": 49,
|
|
|
"id": "ad56a263-3d2e-4d91-98bf-d0b68d3c7fc3",
|
|
"id": "ad56a263-3d2e-4d91-98bf-d0b68d3c7fc3",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
"outputs": [],
|
|
|
"source": [
|
|
"source": [
|
|
|
- "dataloader = create_dataloader(raw_text, batch_size=8, max_length=4, stride=5)\n",
|
|
|
|
|
|
|
+ "max_length = 4\n",
|
|
|
|
|
+ "dataloader = create_dataloader(raw_text, batch_size=8, max_length=max_length, stride=5)\n",
|
|
|
"data_iter = iter(dataloader)\n",
|
|
"data_iter = iter(dataloader)\n",
|
|
|
"inputs, targets = next(data_iter)"
|
|
"inputs, targets = next(data_iter)"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 45,
|
|
|
|
|
|
|
+ "execution_count": 50,
|
|
|
|
|
+ "id": "84416b60-3707-4370-bcbc-da0b62f2b64d",
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "outputs": [
|
|
|
|
|
+ {
|
|
|
|
|
+ "name": "stdout",
|
|
|
|
|
+ "output_type": "stream",
|
|
|
|
|
+ "text": [
|
|
|
|
|
+ "Token IDs:\n",
|
|
|
|
|
+ " tensor([[ 40, 367, 2885, 1464],\n",
|
|
|
|
|
+ " [ 3619, 402, 271, 10899],\n",
|
|
|
|
|
+ " [ 257, 7026, 15632, 438],\n",
|
|
|
|
|
+ " [ 257, 922, 5891, 1576],\n",
|
|
|
|
|
+ " [ 568, 340, 373, 645],\n",
|
|
|
|
|
+ " [ 5975, 284, 502, 284],\n",
|
|
|
|
|
+ " [ 326, 11, 287, 262],\n",
|
|
|
|
|
+ " [ 286, 465, 13476, 11]])\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "Inputs shape:\n",
|
|
|
|
|
+ " torch.Size([8, 4])\n"
|
|
|
|
|
+ ]
|
|
|
|
|
+ }
|
|
|
|
|
+ ],
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "print(\"Token IDs:\\n\", inputs)\n",
|
|
|
|
|
+ "print(\"\\nInputs shape:\\n\", inputs.shape)"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "code",
|
|
|
|
|
+ "execution_count": 51,
|
|
|
"id": "7766ec38-30d0-4128-8c31-f49f063c43d1",
|
|
"id": "7766ec38-30d0-4128-8c31-f49f063c43d1",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
|
{
|
|
{
|
|
|
- "data": {
|
|
|
|
|
- "text/plain": [
|
|
|
|
|
- "torch.Size([8, 4, 256])"
|
|
|
|
|
- ]
|
|
|
|
|
- },
|
|
|
|
|
- "execution_count": 45,
|
|
|
|
|
- "metadata": {},
|
|
|
|
|
- "output_type": "execute_result"
|
|
|
|
|
|
|
+ "name": "stdout",
|
|
|
|
|
+ "output_type": "stream",
|
|
|
|
|
+ "text": [
|
|
|
|
|
+ "torch.Size([8, 4, 256])\n"
|
|
|
|
|
+ ]
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
|
"source": [
|
|
"source": [
|
|
|
"token_embeddings = token_embedding_layer(inputs)\n",
|
|
"token_embeddings = token_embedding_layer(inputs)\n",
|
|
|
- "token_embeddings.shape"
|
|
|
|
|
|
|
+ "print(token_embeddings.shape)"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
@@ -1383,12 +1490,12 @@
|
|
|
"id": "fe2ae164-6f19-4e32-b9e5-76950fcf1c9f",
|
|
"id": "fe2ae164-6f19-4e32-b9e5-76950fcf1c9f",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"source": [
|
|
"source": [
|
|
|
- "- GPT2 uses absolute position embeddings, so we just create another embedding layer:"
|
|
|
|
|
|
|
+ "- GPT-2 uses absolute position embeddings, so we just create another embedding layer:"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 46,
|
|
|
|
|
|
|
+ "execution_count": 52,
|
|
|
"id": "cc048e20-7ac8-417e-81f5-8fe6f9a4fe07",
|
|
"id": "cc048e20-7ac8-417e-81f5-8fe6f9a4fe07",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
"outputs": [],
|
|
@@ -1398,24 +1505,21 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 47,
|
|
|
|
|
|
|
+ "execution_count": 53,
|
|
|
"id": "c369a1e7-d566-4b53-b398-d6adafb44105",
|
|
"id": "c369a1e7-d566-4b53-b398-d6adafb44105",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
|
{
|
|
{
|
|
|
- "data": {
|
|
|
|
|
- "text/plain": [
|
|
|
|
|
- "torch.Size([8, 4, 256])"
|
|
|
|
|
- ]
|
|
|
|
|
- },
|
|
|
|
|
- "execution_count": 47,
|
|
|
|
|
- "metadata": {},
|
|
|
|
|
- "output_type": "execute_result"
|
|
|
|
|
|
|
+ "name": "stdout",
|
|
|
|
|
+ "output_type": "stream",
|
|
|
|
|
+ "text": [
|
|
|
|
|
+ "torch.Size([4, 256])\n"
|
|
|
|
|
+ ]
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
|
"source": [
|
|
"source": [
|
|
|
- "pos_embeddings = pos_embedding_layer(inputs)\n",
|
|
|
|
|
- "pos_embeddings.shape"
|
|
|
|
|
|
|
+ "pos_embeddings = pos_embedding_layer(torch.arange(max_length))\n",
|
|
|
|
|
+ "print(pos_embeddings.shape)"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
@@ -1428,25 +1532,38 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 48,
|
|
|
|
|
|
|
+ "execution_count": 54,
|
|
|
"id": "b22fab89-526e-43c8-9035-5b7018e34288",
|
|
"id": "b22fab89-526e-43c8-9035-5b7018e34288",
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
|
{
|
|
{
|
|
|
- "data": {
|
|
|
|
|
- "text/plain": [
|
|
|
|
|
- "torch.Size([8, 4, 256])"
|
|
|
|
|
- ]
|
|
|
|
|
- },
|
|
|
|
|
- "execution_count": 48,
|
|
|
|
|
- "metadata": {},
|
|
|
|
|
- "output_type": "execute_result"
|
|
|
|
|
|
|
+ "name": "stdout",
|
|
|
|
|
+ "output_type": "stream",
|
|
|
|
|
+ "text": [
|
|
|
|
|
+ "torch.Size([8, 4, 256])\n"
|
|
|
|
|
+ ]
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
|
"source": [
|
|
"source": [
|
|
|
"input_embeddings = token_embeddings + pos_embeddings\n",
|
|
"input_embeddings = token_embeddings + pos_embeddings\n",
|
|
|
- "input_embeddings.shape"
|
|
|
|
|
|
|
+ "print(input_embeddings.shape)"
|
|
|
]
|
|
]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "code",
|
|
|
|
|
+ "execution_count": null,
|
|
|
|
|
+ "id": "a6b71f61-57f4-496b-bf48-9097c591f54c",
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "outputs": [],
|
|
|
|
|
+ "source": []
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "code",
|
|
|
|
|
+ "execution_count": null,
|
|
|
|
|
+ "id": "c2894bbd-6cf5-4bfa-80ad-a23b5d1a45f4",
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "outputs": [],
|
|
|
|
|
+ "source": []
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
|
"metadata": {
|
|
"metadata": {
|
|
@@ -1465,7 +1582,7 @@
|
|
|
"name": "python",
|
|
"name": "python",
|
|
|
"nbconvert_exporter": "python",
|
|
"nbconvert_exporter": "python",
|
|
|
"pygments_lexer": "ipython3",
|
|
"pygments_lexer": "ipython3",
|
|
|
- "version": "3.10.6"
|
|
|
|
|
|
|
+ "version": "3.10.12"
|
|
|
}
|
|
}
|
|
|
},
|
|
},
|
|
|
"nbformat": 4,
|
|
"nbformat": 4,
|