@@ -20,8 +20,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "numpy version: 1.26.0\n",
- "matplotlib version: 3.8.2\n",
+ "numpy version: 1.25.2\n",
+ "matplotlib version: 3.7.2\n",
"tiktoken version: 0.5.1\n",
"torch version: 2.2.1\n"
]
@@ -916,8 +916,11 @@
}
],
"source": [
- "print(\"Characters:\", len(text_data))\n",
- "print(\"Tokens:\", len(tokenizer.encode(text_data)))"
+ "total_char = len(text_data)\n",
+ "total_tokens = len(tokenizer.encode(text_data))\n",
+ "\n",
+ "print(\"Characters:\", total_char)\n",
+ "print(\"Tokens:\", total_tokens)"
]
},
{
@@ -962,7 +965,6 @@
"train_ratio = 0.90\n",
"split_idx = int(train_ratio * len(text_data))\n",
"\n",
- "\n",
"torch.manual_seed(123)\n",
"\n",
"train_loader = create_dataloader_v1(\n",
@@ -984,6 +986,26 @@
")"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "f37b3eb0-854e-4895-9898-fa7d1e67566e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sanity check\n",
+ "\n",
+ "if total_tokens * train_ratio < GPT_CONFIG_124M[\"ctx_len\"]:\n",
+ "    print(\"Not enough tokens for the training loader. \"\n",
+ "          \"Try to lower the `GPT_CONFIG_124M['ctx_len']` or \"\n",
+ "          \"increase the `train_ratio`\")\n",
+ "\n",
+ "if total_tokens * (1 - train_ratio) < GPT_CONFIG_124M[\"ctx_len\"]:\n",
+ "    print(\"Not enough tokens for the validation loader. \"\n",
+ "          \"Try to lower the `GPT_CONFIG_124M['ctx_len']` or \"\n",
+ "          \"decrease the `train_ratio`\")"
+ ]
+ },
{
"cell_type": "markdown",
"id": "e7ac3296-a4d1-4303-9ac5-376518960c33",
@@ -1003,7 +1025,7 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 24,
"id": "ca0116d0-d229-472c-9fbf-ebc229331c3e",
"metadata": {},
"outputs": [
@@ -1047,7 +1069,7 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 25,
"id": "eb860488-5453-41d7-9870-23b723f742a0",
"metadata": {
"colab": {
@@ -1092,7 +1114,7 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 26,
"id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc",
"metadata": {
"id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc"
@@ -1133,7 +1155,7 @@
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 27,
"id": "56f5b0c9-1065-4d67-98b9-010e42fc1e2a",
"metadata": {},
"outputs": [
@@ -1195,7 +1217,7 @@
},
"outputs": [],
"source": [
- "def train_model_simple(model, train_loader, val_loader, optimizer, device, n_epochs,\n",
+ "def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,\n",
"                       eval_freq, eval_iter, start_context):\n",
"    # Initialize lists to track losses and tokens seen\n",
"    train_losses, val_losses, track_tokens_seen = [], [], []\n",
@@ -1203,7 +1225,7 @@
"    global_step = -1\n",
"\n",
"    # Main training loop\n",
- "    for epoch in range(n_epochs):\n",
+ "    for epoch in range(num_epochs):\n",
"        model.train() # Set model to training mode\n",
"        \n",
"        for input_batch, target_batch in train_loader:\n",
@@ -1246,8 +1268,10 @@
"    context_size = model.pos_emb.weight.shape[0]\n",
"    encoded = text_to_token_ids(start_context, tokenizer).to(device)\n",
"    with torch.no_grad():\n",
- "        token_ids = generate_text_simple(model=model, idx=encoded,\n",
- "                                         max_new_tokens=50, context_size=context_size)\n",
+ "        token_ids = generate_text_simple(\n",
+ "            model=model, idx=encoded,\n",
+ "            max_new_tokens=50, context_size=context_size\n",
+ "        )\n",
"    decoded_text = token_ids_to_text(token_ids, tokenizer)\n",
"    print(decoded_text.replace(\"\\n\", \" \")) # Compact print format\n",
"    model.train()"
@@ -1314,10 +1338,10 @@
"model.to(device)\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.1)\n",
"\n",
- "n_epochs = 10\n",
+ "num_epochs = 10\n",
"train_losses, val_losses, tokens_seen = train_model_simple(\n",
"    model, train_loader, val_loader, optimizer, device,\n",
- "    n_epochs=n_epochs, eval_freq=5, eval_iter=1,\n",
+ "    num_epochs=num_epochs, eval_freq=5, eval_iter=1,\n",
"    start_context=\"Every effort moves you\",\n",
")"
]
@@ -1368,7 +1392,7 @@
"    plt.show()\n",
"\n",
"\n",
- "epochs_tensor = torch.linspace(0, n_epochs, len(train_losses))\n",
+ "epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))\n",
"plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)"
]
},
@@ -1959,7 +1983,7 @@
"metadata": {},
"outputs": [],
"source": [
- "torch.save(model.state_dict(), 'model.pth')"
+ "torch.save(model.state_dict(), \"model.pth\")"
]
},
{
@@ -1978,7 +2002,7 @@
"outputs": [],
"source": [
"model = GPTModel(GPT_CONFIG_124M)\n",
- "model.load_state_dict(torch.load('model.pth'))\n",
+ "model.load_state_dict(torch.load(\"model.pth\"))\n",
"model.eval();"
]
},
@@ -1999,10 +2023,10 @@
"outputs": [],
"source": [
"torch.save({\n",
- "    'model_state_dict': model.state_dict(),\n",
- "    'optimizer_state_dict': optimizer.state_dict(),\n",
+ "    \"model_state_dict\": model.state_dict(),\n",
+ "    \"optimizer_state_dict\": optimizer.state_dict(),\n",
"    }, \n",
- "    'model_and_optimizer.pth'\n",
+ "    \"model_and_optimizer.pth\"\n",
")"
]
},
@@ -2013,9 +2037,9 @@
"metadata": {},
"outputs": [],
"source": [
- "checkpoint = torch.load('model_and_optimizer.pth')\n",
- "model.load_state_dict(checkpoint['model_state_dict'])\n",
- "optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
+ "checkpoint = torch.load(\"model_and_optimizer.pth\")\n",
+ "model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
+ "optimizer.load_state_dict(checkpoint[\"optimizer_state_dict\"])\n",
"model.train();"
]
},
@@ -2474,7 +2498,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.12"
+ "version": "3.11.4"
}
},
"nbformat": 4,