# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

import matplotlib.pyplot as plt
import os
import torch
import urllib.request
import tiktoken

# Import from local files
from previous_chapters import GPTModel, create_dataloader_v1, generate_text_simple


def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())
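
# Illustrative note on the two helpers above (shapes are assumptions based on the
# GPT-2 BPE tokenizer): text_to_token_ids("Every effort moves you", tokenizer)
# yields a tensor of shape (1, num_tokens), and token_ids_to_text reverses the
# conversion back to the original text.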


def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss
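
# Note: logits has shape (batch_size, num_tokens, vocab_size); flatten(0, 1) merges the
# batch and token dimensions into (batch_size * num_tokens, vocab_size), matching the
# flattened (batch_size * num_tokens,) target IDs that cross_entropy expects.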


def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches
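
# Note: calc_loss_loader returns the average loss over the evaluated batches; capping
# num_batches keeps the periodic evaluation cheap instead of iterating over the full loader.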


def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss
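
# Note: model.eval() disables dropout and torch.no_grad() skips gradient tracking, so
# the evaluation is deterministic and memory-efficient; model.train() restores training mode.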


def generate_and_print_sample(model, tokenizer, device, start_context):
    model.eval()
    context_size = model.pos_emb.weight.shape[0]
    encoded = text_to_token_ids(start_context, tokenizer).to(device)
    with torch.no_grad():
        token_ids = generate_text_simple(
            model=model, idx=encoded,
            max_new_tokens=50, context_size=context_size
        )
    decoded_text = token_ids_to_text(token_ids, tokenizer)
    print(decoded_text.replace("\n", " "))  # Compact print format
    model.train()
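
# Note: generate_text_simple (imported from previous_chapters) is assumed to perform
# plain greedy decoding, so the printed sample is deterministic for a given model state.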


def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                       eval_freq, eval_iter, start_context, tokenizer):
    # Initialize lists to track losses and tokens seen
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen = 0
    global_step = -1

    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()  # Reset loss gradients from previous batch iteration
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()  # Calculate loss gradients
            optimizer.step()  # Update model weights using loss gradients
            tokens_seen += input_batch.numel()
            global_step += 1

            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # Print a sample text after each epoch
        generate_and_print_sample(
            model, tokenizer, device, start_context
        )

    return train_losses, val_losses, track_tokens_seen
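
# Note: train_model_simple is intentionally minimal; it evaluates every eval_freq steps on
# at most eval_iter batches and omits extras such as learning-rate scheduling or gradient clipping.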


def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
    fig, ax1 = plt.subplots()

    # Plot training and validation loss against epochs
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")

    # Create a second x-axis for tokens seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(tokens_seen, train_losses, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Tokens seen")

    fig.tight_layout()  # Adjust layout to make room
    # plt.show()


def main(gpt_config, settings):
    torch.manual_seed(123)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ##############################
    # Download data if necessary
    ##############################

    file_path = "the-verdict.txt"
    url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt"

    if not os.path.exists(file_path):
        with urllib.request.urlopen(url) as response:
            text_data = response.read().decode('utf-8')
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)
    else:
        with open(file_path, "r", encoding="utf-8") as file:
            text_data = file.read()

    ##############################
    # Initialize model
    ##############################

    model = GPTModel(gpt_config)
    model.to(device)  # no reassignment (model = model.to(device)) needed for nn.Module instances
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=settings["learning_rate"], weight_decay=settings["weight_decay"]
    )

    ##############################
    # Set up dataloaders
    ##############################

    # Train/validation ratio
    train_ratio = 0.90
    split_idx = int(train_ratio * len(text_data))

    train_loader = create_dataloader_v1(
        text_data[:split_idx],
        batch_size=settings["batch_size"],
        max_length=gpt_config["context_length"],
        stride=gpt_config["context_length"],
        drop_last=True,
        shuffle=True,
        num_workers=0
    )

    val_loader = create_dataloader_v1(
        text_data[split_idx:],
        batch_size=settings["batch_size"],
        max_length=gpt_config["context_length"],
        stride=gpt_config["context_length"],
        drop_last=False,
        shuffle=False,
        num_workers=0
    )
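
    # Note: a stride equal to max_length produces non-overlapping input windows, and
    # drop_last=True skips a smaller final batch that could otherwise add noisy updates.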

    ##############################
    # Train model
    ##############################

    tokenizer = tiktoken.get_encoding("gpt2")
    train_losses, val_losses, tokens_seen = train_model_simple(
        model, train_loader, val_loader, optimizer, device,
        num_epochs=settings["num_epochs"], eval_freq=5, eval_iter=1,
        start_context="Every effort moves you", tokenizer=tokenizer
    )

    return train_losses, val_losses, tokens_seen, model


if __name__ == "__main__":

    GPT_CONFIG_124M = {
        "vocab_size": 50257,     # Vocabulary size
        "context_length": 256,   # Shortened context length (orig: 1024)
        "emb_dim": 768,          # Embedding dimension
        "n_heads": 12,           # Number of attention heads
        "n_layers": 12,          # Number of layers
        "drop_rate": 0.1,        # Dropout rate
        "qkv_bias": False        # Query-key-value bias
    }

    OTHER_SETTINGS = {
        "learning_rate": 5e-4,
        "num_epochs": 10,
        "batch_size": 2,
        "weight_decay": 0.1
    }

    ###########################
    # Initiate training
    ###########################

    train_losses, val_losses, tokens_seen, model = main(GPT_CONFIG_124M, OTHER_SETTINGS)

    ###########################
    # After training
    ###########################

    # Plot results
    epochs_tensor = torch.linspace(0, OTHER_SETTINGS["num_epochs"], len(train_losses))
    plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)
    plt.savefig("loss.pdf")

    # Save and load model
    torch.save(model.state_dict(), "model.pth")
    model = GPTModel(GPT_CONFIG_124M)
    model.load_state_dict(torch.load("model.pth", weights_only=True))
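
    # Optional follow-up: switch the restored model to evaluation mode so dropout
    # is disabled before it is used for inference.
    model.eval()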