|
|
@@ -764,7 +764,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 16,
|
|
|
+ "execution_count": 23,
|
|
|
"id": "654fde37-b2a9-4a20-a8d3-0206c056e2ff",
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
@@ -785,43 +785,6 @@
|
|
|
" text_data = file.read()"
|
|
|
]
|
|
|
},
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 18,
|
|
|
- "id": "0959c855-f860-4358-8b98-bc654f047578",
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "from previous_chapters import create_dataloader_v1\n",
|
|
|
- "\n",
|
|
|
- "# Train/validation ratio\n",
|
|
|
- "train_ratio = 0.90\n",
|
|
|
- "split_idx = int(train_ratio * len(text_data))\n",
|
|
|
- "train_data = text_data[:split_idx]\n",
|
|
|
- "val_data = text_data[split_idx:]\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "torch.manual_seed(123)\n",
|
|
|
- "\n",
|
|
|
- "train_loader = create_dataloader_v1(\n",
|
|
|
- " train_data,\n",
|
|
|
- " batch_size=2,\n",
|
|
|
- " max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
|
|
|
- " stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
|
|
|
- " drop_last=True,\n",
|
|
|
- " shuffle=True\n",
|
|
|
- ")\n",
|
|
|
- "\n",
|
|
|
- "val_loader = create_dataloader_v1(\n",
|
|
|
- " val_data,\n",
|
|
|
- " batch_size=2,\n",
|
|
|
- " max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
|
|
|
- " stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
|
|
|
- " drop_last=False,\n",
|
|
|
- " shuffle=False\n",
|
|
|
- ")"
|
|
|
- ]
|
|
|
- },
|
|
|
{
|
|
|
"cell_type": "markdown",
|
|
|
"id": "379330f1-80f4-4e34-8724-41d892b04cee",
|
|
|
@@ -832,7 +795,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 19,
|
|
|
+ "execution_count": 24,
|
|
|
"id": "6kgJbe4ehI4q",
|
|
|
"metadata": {
|
|
|
"colab": {
|
|
|
@@ -858,7 +821,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 20,
|
|
|
+ "execution_count": 25,
|
|
|
"id": "j2XPde_ThM_e",
|
|
|
"metadata": {
|
|
|
"colab": {
|
|
|
@@ -884,7 +847,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 21,
|
|
|
+ "execution_count": 27,
|
|
|
"id": "6b46a952-d50a-4837-af09-4095698f7fd1",
|
|
|
"metadata": {
|
|
|
"colab": {
|
|
|
@@ -940,22 +903,24 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 22,
|
|
|
- "id": "fd0e963b-d282-4c97-b004-8772f4b1bd8f",
|
|
|
+ "execution_count": 35,
|
|
|
+ "id": "0959c855-f860-4358-8b98-bc654f047578",
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
"from previous_chapters import create_dataloader_v1\n",
|
|
|
"\n",
|
|
|
- "\n",
|
|
|
"# Train/validation ratio\n",
|
|
|
"train_ratio = 0.90\n",
|
|
|
"split_idx = int(train_ratio * len(text_data))\n",
|
|
|
+ "train_data = text_data[:split_idx]\n",
|
|
|
+ "val_data = text_data[split_idx:]\n",
|
|
|
+ "\n",
|
|
|
"\n",
|
|
|
"torch.manual_seed(123)\n",
|
|
|
"\n",
|
|
|
"train_loader = create_dataloader_v1(\n",
|
|
|
- " text_data[:split_idx],\n",
|
|
|
+ " train_data,\n",
|
|
|
" batch_size=2,\n",
|
|
|
" max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
|
|
|
" stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
|
|
|
@@ -964,7 +929,7 @@
|
|
|
")\n",
|
|
|
"\n",
|
|
|
"val_loader = create_dataloader_v1(\n",
|
|
|
- " text_data[split_idx:],\n",
|
|
|
+ " val_data,\n",
|
|
|
" batch_size=2,\n",
|
|
|
" max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
|
|
|
" stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
|
|
|
@@ -975,7 +940,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 23,
|
|
|
+ "execution_count": 36,
|
|
|
"id": "f37b3eb0-854e-4895-9898-fa7d1e67566e",
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
@@ -1012,7 +977,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 24,
|
|
|
+ "execution_count": 37,
|
|
|
"id": "ca0116d0-d229-472c-9fbf-ebc229331c3e",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -1056,7 +1021,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 25,
|
|
|
+ "execution_count": 38,
|
|
|
"id": "eb860488-5453-41d7-9870-23b723f742a0",
|
|
|
"metadata": {
|
|
|
"colab": {
|
|
|
@@ -1101,7 +1066,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 26,
|
|
|
+ "execution_count": 49,
|
|
|
"id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc",
|
|
|
"metadata": {
|
|
|
"id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc"
|
|
|
@@ -1118,17 +1083,16 @@
|
|
|
"\n",
|
|
|
"\n",
|
|
|
"def calc_loss_loader(data_loader, model, device, num_batches=None):\n",
|
|
|
- " total_loss, batches_seen = 0., 0.\n",
|
|
|
+ " total_loss = 0.\n",
|
|
|
" if num_batches is None:\n",
|
|
|
" num_batches = len(data_loader)\n",
|
|
|
" for i, (input_batch, target_batch) in enumerate(data_loader):\n",
|
|
|
" if i < num_batches:\n",
|
|
|
" loss = calc_loss_batch(input_batch, target_batch, model, device)\n",
|
|
|
" total_loss += loss.item()\n",
|
|
|
- " batches_seen += 1\n",
|
|
|
" else:\n",
|
|
|
" break\n",
|
|
|
- " return total_loss / batches_seen"
|
|
|
+ " return total_loss / num_batches"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
@@ -1142,7 +1106,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 27,
|
|
|
+ "execution_count": 55,
|
|
|
"id": "56f5b0c9-1065-4d67-98b9-010e42fc1e2a",
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
@@ -1150,7 +1114,7 @@
|
|
|
"name": "stdout",
|
|
|
"output_type": "stream",
|
|
|
"text": [
|
|
|
- "Training loss: 11.002305030822754\n",
|
|
|
+ "Training loss: 10.98758347829183\n",
|
|
|
"Validation loss: 10.98110580444336\n"
|
|
|
]
|
|
|
}
|
|
|
@@ -1159,6 +1123,8 @@
|
|
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
|
|
"model.to(device) # no assignment model = model.to(device) necessary for nn.Module classes\n",
|
|
|
"\n",
|
|
|
+ "\n",
|
|
|
+ "torch.manual_seed(123) # For reproducibility due to the shuffling in the data loader\n",
|
|
|
"train_loss = calc_loss_loader(train_loader, model, device, num_batches=1)\n",
|
|
|
"val_loss = calc_loss_loader(val_loader, model, device, num_batches=1)\n",
|
|
|
"\n",
|