@@ -4,7 +4,7 @@
 # Code: https://github.com/rasbt/LLMs-from-scratch
 
 from llms_from_scratch.ch02 import create_dataloader_v1
-from llms_from_scratch.ch04 import GPTModel
+from llms_from_scratch.ch04 import GPTModel, GPTModelFast
 from llms_from_scratch.ch05 import train_model_simple
 
 import os
@@ -16,60 +16,47 @@ import torch
 from torch.utils.data import Subset, DataLoader
 
 
-@pytest.mark.parametrize("file_name", ["the-verdict.txt"])
-def test_train_simple(tmp_path, file_name):
+GPT_CONFIG_124M = {
+    "vocab_size": 50257,
+    "context_length": 256,  # Shortened for test speed
+    "emb_dim": 768,
+    "n_heads": 12,
+    "n_layers": 12,
+    "drop_rate": 0.1,
+    "qkv_bias": False
+}
 
-    GPT_CONFIG_124M = {
-        "vocab_size": 50257,  # Vocabulary size
-        "context_length": 256,  # Shortened context length (orig: 1024)
-        "emb_dim": 768,  # Embedding dimension
-        "n_heads": 12,  # Number of attention heads
-        "n_layers": 12,  # Number of layers
-        "drop_rate": 0.1,  # Dropout rate
-        "qkv_bias": False  # Query-key-value bias
-    }
+OTHER_SETTINGS = {
+    "learning_rate": 5e-4,
+    "num_epochs": 2,
+    "batch_size": 1,
+    "weight_decay": 0.1
+}
 
-    OTHER_SETTINGS = {
-        "learning_rate": 5e-4,
-        "num_epochs": 2,
-        "batch_size": 1,
-        "weight_decay": 0.1
-    }
 
+@pytest.mark.parametrize("ModelClass", [GPTModel, GPTModelFast])
+def test_train_simple(tmp_path, ModelClass):
     torch.manual_seed(123)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     ##############################
     # Download data if necessary
     ##############################
-
     file_path = tmp_path / "the-verdict.txt"
     url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt"
 
     if not os.path.exists(file_path):
         with urllib.request.urlopen(url) as response:
-            text_data = response.read().decode('utf-8')
-        with open(file_path, "w", encoding="utf-8") as file:
-            file.write(text_data)
+            text_data = response.read().decode("utf-8")
+        with open(file_path, "w", encoding="utf-8") as f:
+            f.write(text_data)
     else:
-        with open(file_path, "r", encoding="utf-8") as file:
-            text_data = file.read()
-
-    ##############################
-    # Initialize model
-    ##############################
-
-    model = GPTModel(GPT_CONFIG_124M)
-    model.to(device)  # no assignment model = model.to(device) necessary for nn.Module classes
-    optimizer = torch.optim.AdamW(
-        model.parameters(), lr=OTHER_SETTINGS["learning_rate"], weight_decay=OTHER_SETTINGS["weight_decay"]
-    )
+        with open(file_path, "r", encoding="utf-8") as f:
+            text_data = f.read()
 
     ##############################
     # Set up dataloaders
     ##############################
-
-    # Train/validation ratio
     train_ratio = 0.90
     split_idx = int(train_ratio * len(text_data))
 
@@ -93,16 +80,25 @@ def test_train_simple(tmp_path, file_name):
         num_workers=0
     )
 
+    # Limit to 1 batch for speed
+    train_subset = Subset(train_loader.dataset, range(1))
+    one_batch_train_loader = DataLoader(train_subset, batch_size=1)
+    val_subset = Subset(val_loader.dataset, range(1))
+    one_batch_val_loader = DataLoader(val_subset, batch_size=1)
+
     ##############################
     # Train model
     ##############################
+    model = ModelClass(GPT_CONFIG_124M)
+    model.to(device)
 
-    tokenizer = tiktoken.get_encoding("gpt2")
+    optimizer = torch.optim.AdamW(
+        model.parameters(),
+        lr=OTHER_SETTINGS["learning_rate"],
+        weight_decay=OTHER_SETTINGS["weight_decay"]
+    )
 
-    train_subset = Subset(train_loader.dataset, range(1))
-    one_batch_train_loader = DataLoader(train_subset, batch_size=1)
-    val_subset = Subset(val_loader.dataset, range(1))
-    one_batch_val_loader = DataLoader(val_subset, batch_size=1)
+    tokenizer = tiktoken.get_encoding("gpt2")
 
     train_losses, val_losses, tokens_seen = train_model_simple(
         model, one_batch_train_loader, one_batch_val_loader, optimizer, device,