| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533 |
- # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
- # Source for "Build a Large Language Model From Scratch"
- # - https://www.manning.com/books/build-a-large-language-model-from-scratch
- # Code: https://github.com/rasbt/LLMs-from-scratch
- import os
- import time
- import urllib.request
- import matplotlib.pyplot as plt
- import torch
- import torch.nn as nn
- from torch.utils.data import Dataset, DataLoader
- import tiktoken
- #####################################
- # Chapter 2
- #####################################
class GPTDatasetV1(Dataset):
    """Next-token-prediction dataset cut from one long text.

    The text is tokenized once. A window of ``max_length`` tokens then slides
    over the token ids in steps of ``stride``. Each window gives an input
    chunk, and a target chunk that is the same window shifted by one token.

    Args:
        txt: Raw text to tokenize.
        tokenizer: Object whose ``encode(text, allowed_special=...)`` returns
            a list of token ids.
        max_length: Tokens per input (and per target) chunk.
        stride: Distance between consecutive window starts; a value smaller
            than ``max_length`` makes the windows overlap.

    A text with at most ``max_length`` tokens produces an empty dataset.
    """

    def __init__(self, txt, tokenizer, max_length, stride):
        ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})

        # The upper bound leaves room for the one-token shift of the target.
        starts = range(0, len(ids) - max_length, stride)
        self.input_ids = [torch.tensor(ids[s:s + max_length]) for s in starts]
        self.target_ids = [torch.tensor(ids[s + 1:s + 1 + max_length]) for s in starts]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]
def create_dataloader_v1(txt, batch_size=4, max_length=256,
                         stride=128, shuffle=True, drop_last=True, num_workers=0):
    """Build a DataLoader of (input, target) token chunks from ``txt``.

    The text is tokenized with tiktoken's GPT-2 encoding and windowed by
    GPTDatasetV1 using ``max_length`` and ``stride``. The remaining arguments
    are passed straight to ``torch.utils.data.DataLoader``.
    """
    gpt2_tokenizer = tiktoken.get_encoding("gpt2")
    windows = GPTDatasetV1(txt, gpt2_tokenizer, max_length, stride)
    return DataLoader(
        windows,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
    )
- #####################################
- # Chapter 3
- #####################################
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    Inputs are projected to queries, keys and values of width ``d_out``.
    These are split into ``num_heads`` heads of width ``d_out // num_heads``
    and attended under a causal mask (a position never sees later
    positions). The heads are then merged and passed through an output
    projection.

    Args:
        d_in: Input feature size.
        d_out: Output feature size; must be divisible by ``num_heads``.
        context_length: Largest sequence length the causal mask covers.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query/key/value projections have a bias.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by n_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Creation order is kept (query, key, value, out) so seeded
        # initialization and parameter order stay the same.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the "future" positions to hide.
        future = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer("mask", future)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, num_tokens, d_in); returns
        (batch, num_tokens, d_out). num_tokens must not exceed context_length."""
        batch, seq_len, _ = x.shape

        def to_heads(t):
            # (batch, seq_len, d_out) -> (batch, num_heads, seq_len, head_dim)
            return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        keys = to_heads(self.W_key(x))
        queries = to_heads(self.W_query(x))
        values = to_heads(self.W_value(x))

        # (batch, num_heads, seq_len, seq_len) raw scores, future masked out.
        scores = queries @ keys.transpose(2, 3)
        hidden = self.mask.bool()[:seq_len, :seq_len]
        scores.masked_fill_(hidden, -torch.inf)

        weights = self.dropout(torch.softmax(scores / keys.shape[-1] ** 0.5, dim=-1))

        # Back to (batch, seq_len, num_heads * head_dim) before projecting.
        merged = (weights @ values).transpose(1, 2).reshape(batch, seq_len, self.d_out)
        return self.out_proj(merged)
- #####################################
- # Chapter 4
- #####################################
class LayerNorm(nn.Module):
    """Normalize over the last dimension, then apply a learnable scale/shift.

    Uses the biased (population) variance and a fixed epsilon of 1e-5.
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        variance = x.var(dim=-1, keepdim=True, unbiased=False)
        normalized = centered / torch.sqrt(variance + self.eps)
        return self.scale * normalized + self.shift
class GELU(nn.Module):
    """Tanh approximation of the GELU activation.

    GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
    """

    # sqrt(2/pi) as a plain Python float, computed once at class creation.
    # The previous code built a new CPU tensor on every forward call. A
    # Python scalar adapts to the input's device and dtype and keeps full
    # precision for float64 inputs.
    _SQRT_2_OVER_PI = (2.0 / torch.pi) ** 0.5

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            self._SQRT_2_OVER_PI * (x + 0.044715 * torch.pow(x, 3))
        ))
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4 * emb_dim, apply GELU, project back.

    Args:
        cfg: Config mapping; only ``cfg["emb_dim"]`` is read.
    """

    def __init__(self, cfg):
        super().__init__()
        width = cfg["emb_dim"]
        hidden = 4 * width
        self.layers = nn.Sequential(
            nn.Linear(width, hidden),
            GELU(),
            nn.Linear(hidden, width),
        )

    def forward(self, x):
        return self.layers(x)
class TransformerBlock(nn.Module):
    """Pre-norm transformer block.

    It has two residual sublayers, each computed as
    ``x + dropout(sublayer(norm(x)))``: causal multi-head attention, then a
    feed-forward MLP.

    Args:
        cfg: Config mapping providing ``emb_dim``, ``context_length``,
            ``n_heads``, ``drop_rate`` and ``qkv_bias``.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        self.att = MultiHeadAttention(
            d_in=emb_dim,
            d_out=emb_dim,
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"],
        )
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(emb_dim)
        self.norm2 = LayerNorm(emb_dim)
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Attention sublayer with residual connection.
        x = x + self.drop_shortcut(self.att(self.norm1(x)))
        # Feed-forward sublayer with residual connection.
        x = x + self.drop_shortcut(self.ff(self.norm2(x)))
        return x
class GPTModel(nn.Module):
    """Decoder-only GPT language model.

    The pipeline is: token embeddings plus learned position embeddings,
    dropout, ``n_layers`` TransformerBlocks, a final LayerNorm, and a
    bias-free linear head over the vocabulary.

    Args:
        cfg: Config mapping providing ``vocab_size``, ``context_length``,
            ``emb_dim``, ``drop_rate``, ``n_layers`` and the keys
            TransformerBlock reads.
    """

    def __init__(self, cfg):
        super().__init__()
        vocab_size, emb_dim = cfg["vocab_size"], cfg["emb_dim"]
        self.tok_emb = nn.Embedding(vocab_size, emb_dim)
        self.pos_emb = nn.Embedding(cfg["context_length"], emb_dim)
        self.drop_emb = nn.Dropout(cfg["drop_rate"])
        self.trf_blocks = nn.Sequential(
            *(TransformerBlock(cfg) for _ in range(cfg["n_layers"]))
        )
        self.final_norm = LayerNorm(emb_dim)
        self.out_head = nn.Linear(emb_dim, vocab_size, bias=False)

    def forward(self, in_idx):
        """Map ids of shape (batch, seq_len) to (batch, seq_len, vocab_size)
        logits. seq_len may not exceed the configured context_length."""
        _, seq_len = in_idx.shape
        positions = torch.arange(seq_len, device=in_idx.device)
        hidden = self.drop_emb(self.tok_emb(in_idx) + self.pos_emb(positions))
        hidden = self.final_norm(self.trf_blocks(hidden))
        return self.out_head(hidden)
def generate_text_simple(model, idx, max_new_tokens, context_size):
    """Greedily append ``max_new_tokens`` token ids to ``idx``.

    Args:
        model: Callable mapping (batch, n_tokens) ids to
            (batch, n_tokens, vocab_size) logits.
        idx: Token-id tensor of shape (batch, n_tokens).
        max_new_tokens: How many tokens to generate.
        context_size: Only the trailing ``context_size`` tokens are fed to
            the model at each step.

    Returns:
        Tensor of shape (batch, n_tokens + max_new_tokens).
    """
    for _ in range(max_new_tokens):
        window = idx[:, -context_size:]
        with torch.no_grad():
            last_logits = model(window)[:, -1, :]
        # Greedy decoding: pick the highest-scoring id per batch row.
        next_id = last_logits.argmax(dim=-1, keepdim=True)
        idx = torch.cat([idx, next_id], dim=1)
    return idx
- #####################################
- # Chapter 5
- #####################################
def text_to_token_ids(text, tokenizer):
    """Encode ``text`` into a token-id tensor with a leading batch axis of 1."""
    ids = tokenizer.encode(text)
    return torch.tensor(ids).unsqueeze(dim=0)
def token_ids_to_text(token_ids, tokenizer):
    """Drop a leading batch axis of size 1 (if present) and decode to text."""
    ids = token_ids.squeeze(0).tolist()
    return tokenizer.decode(ids)
def calc_loss_batch(input_batch, target_batch, model, device):
    """Mean next-token cross-entropy for one batch.

    Both batches are moved to ``device``. The model's logits are flattened
    over their first two axes and scored against the flattened targets.
    """
    inputs = input_batch.to(device)
    targets = target_batch.to(device)
    logits = model(inputs)
    return torch.nn.functional.cross_entropy(
        logits.flatten(start_dim=0, end_dim=1), targets.flatten()
    )
def calc_loss_loader(data_loader, model, device, num_batches=None):
    """Average ``calc_loss_batch`` over the leading batches of a loader.

    Args:
        data_loader: Sized iterable yielding (input_batch, target_batch).
        model: Forwarded to calc_loss_batch.
        device: Forwarded to calc_loss_batch.
        num_batches: Number of leading batches to average. ``None`` means
            all batches; values above ``len(data_loader)`` are clamped.

    Returns:
        Mean loss as a Python float. Returns ``float("nan")`` when there is
        nothing to average: an empty loader, or ``num_batches <= 0``.
    """
    if len(data_loader) == 0:
        return float("nan")
    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    if num_batches <= 0:
        # Guard against division by zero (num_batches == 0) and against a
        # silent -0.0 result (negative num_batches).
        return float("nan")

    total_loss = 0.0
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i >= num_batches:
            break
        loss = calc_loss_batch(input_batch, target_batch, model, device)
        total_loss += loss.item()
    return total_loss / num_batches
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    """Estimate training and validation loss on up to ``eval_iter`` batches.

    The model is put in eval mode and gradients are disabled while the
    losses are computed. Afterwards the model is switched back to training
    mode unconditionally.

    Returns:
        (train_loss, val_loss), each as returned by calc_loss_loader.
    """
    model.eval()
    with torch.no_grad():
        losses = tuple(
            calc_loss_loader(loader, model, device, num_batches=eval_iter)
            for loader in (train_loader, val_loader)
        )
    model.train()
    return losses
def generate_and_print_sample(model, tokenizer, device, start_context):
    """Print a 50-token greedy continuation of ``start_context`` on one line.

    The context window size is the number of rows in ``model.pos_emb``.
    The model is switched back to training mode before returning.
    """
    model.eval()
    max_context = model.pos_emb.weight.shape[0]
    prompt_ids = text_to_token_ids(start_context, tokenizer).to(device)
    with torch.no_grad():
        generated = generate_text_simple(
            model=model,
            idx=prompt_ids,
            max_new_tokens=50,
            context_size=max_context,
        )
    sample = token_ids_to_text(generated, tokenizer)
    print(sample.replace("\n", " "))  # keep each sample on a single line
    model.train()
class _IntervalTimer:
    """Measure elapsed seconds between ``start()`` and ``stop()``.

    On CUDA it uses timing events plus a device synchronize, so queued
    kernels are counted. Otherwise it uses the monotonic host clock
    ``time.perf_counter``.
    """

    def __init__(self, use_cuda):
        self.use_cuda = use_cuda
        if use_cuda:
            self._start_event = torch.cuda.Event(enable_timing=True)
            self._end_event = torch.cuda.Event(enable_timing=True)
        self._t0 = 0.0

    def start(self):
        """Begin a measurement, after any pending CUDA work has finished."""
        if self.use_cuda:
            torch.cuda.synchronize()
            self._start_event.record()
        else:
            self._t0 = time.perf_counter()

    def stop(self):
        """Return the seconds elapsed since the last ``start()``."""
        if self.use_cuda:
            self._end_event.record()
            torch.cuda.synchronize()  # wait for all CUDA ops to complete
            return self._start_event.elapsed_time(self._end_event) / 1000  # ms -> s
        return time.perf_counter() - self._t0


def train_model_simple_with_timing(model, train_loader, val_loader, optimizer, device,
                                   num_epochs, eval_freq, eval_iter, start_context, tokenizer):
    """Train ``model`` while reporting losses and training throughput.

    Every ``eval_freq`` optimizer steps (starting at step 0), train and
    validation losses are estimated on ``eval_iter`` batches. Each report
    prints them with:
      - the tokens/sec of the interval since the previous report, and
      - a running average tokens/sec that excludes the first, warm-up,
        interval.
    After every epoch a text sample is printed. On CUDA, memory usage is
    printed as well.

    Throughput counts only training-step time. The clock is paused while
    evaluating, sampling and printing, so that overhead does not lower the
    reported tokens/sec.

    Args:
        device: torch.device batches are moved to. Its ``type`` selects
            CUDA-event timing versus host-clock timing.
        eval_freq: Evaluate every this many optimizer steps.
        eval_iter: Batches per loader used for each loss estimate.
        start_context: Prompt used for the per-epoch text sample.
        (Remaining arguments are the usual model/loaders/optimizer/tokenizer.)

    Returns:
        (train_losses, val_losses, track_tokens), one entry per evaluation.
        ``track_tokens`` holds the cumulative number of input tokens seen.
    """
    train_losses, val_losses, track_tokens = [], [], []
    total_tokens, global_step, last_tokens = 0, -1, 0

    # Running totals for the average tokens/sec (first interval excluded).
    cumulative_tokens, cumulative_time = 0.0, 0.0

    use_cuda = device.type == "cuda"
    timer = _IntervalTimer(use_cuda)
    # Training seconds already banked for the current interval. The timer is
    # paused around per-epoch sampling, which can fall mid-interval.
    interval_time = 0.0
    timer.start()

    for epoch in range(num_epochs):
        model.train()
        for inp_batch, tgt_batch in train_loader:
            optimizer.zero_grad()
            global_step += 1

            # Forward and backward pass
            loss = calc_loss_batch(inp_batch, tgt_batch, model, device)
            loss.backward()
            optimizer.step()

            total_tokens += inp_batch.numel()

            if global_step % eval_freq == 0:
                elapsed = interval_time + timer.stop()
                interval_time = 0.0

                tokens_interval = total_tokens - last_tokens
                last_tokens = total_tokens
                tps = tokens_interval / elapsed if elapsed > 0 else 0  # Tokens per second

                # global_step == 0 is the first (warm-up) interval; keep it out
                # of the running average.
                if global_step:
                    cumulative_tokens += tokens_interval
                    cumulative_time += elapsed
                avg_tps = cumulative_tokens / cumulative_time if cumulative_time > 0 else 0

                train_loss, val_loss = evaluate_model(model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens.append(total_tokens)

                print(f"Ep {epoch+1}, Step {global_step:06d}, "
                      f"Train: {train_loss:.3f}, Val: {val_loss:.3f}, "
                      f"Step tok/sec: {round(tps)}, Avg tok/sec: {round(avg_tps)}")

                # Restart only after evaluation so its cost is not attributed
                # to the next training interval.
                timer.start()

        # Pause the clock for the per-epoch sample and memory report.
        interval_time += timer.stop()

        generate_and_print_sample(model, tokenizer, device, start_context)

        # Memory stats for the device being trained on. `device` must not be
        # rebound here: the next epoch still moves batches to it.
        if use_cuda:
            allocated = torch.cuda.memory_allocated(device) / 1024**3  # Convert to GB
            reserved = torch.cuda.memory_reserved(device) / 1024**3  # Convert to GB
            print(f"\nAllocated memory: {allocated:.4f} GB")
            print(f"Reserved memory: {reserved:.4f} GB\n")

        timer.start()

    return train_losses, val_losses, track_tokens
def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
    """Plot training and validation loss against epochs, with a second
    x-axis showing tokens seen.

    The figure is left as matplotlib's current figure, so the caller can
    save or show it.
    """
    fig, loss_ax = plt.subplots()

    loss_ax.plot(epochs_seen, train_losses, label="Training loss")
    loss_ax.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
    loss_ax.set_xlabel("Epochs")
    loss_ax.set_ylabel("Loss")
    loss_ax.legend(loc="upper right")

    # Twin x-axis sharing the y-axis. An invisible line gives it the
    # tokens-seen tick range.
    tokens_ax = loss_ax.twiny()
    tokens_ax.plot(tokens_seen, train_losses, alpha=0)
    tokens_ax.set_xlabel("Tokens seen")

    fig.tight_layout()
- #####################################
- # Main function calls
- #####################################
def main(gpt_config, settings):
    """Fetch Middlemarch (cached locally) and train a GPT model on it.

    Args:
        gpt_config: Model config for GPTModel. Its ``context_length`` is
            also used as the dataloader window length and stride.
        settings: Mapping with ``learning_rate``, ``weight_decay``,
            ``batch_size`` and ``num_epochs``.

    Returns:
        (train_losses, val_losses, tokens_seen, model)
    """
    torch.manual_seed(123)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"PyTorch version: {torch.__version__}")
    print(f"Using {device}")
    if torch.cuda.is_available():
        print(f"CUDA version: {torch.version.cuda}")
    print()

    # ---- Data: download once, then reuse the local copy ----
    file_path = "middlemarch.txt"
    url = "https://www.gutenberg.org/cache/epub/145/pg145.txt"
    if os.path.exists(file_path):
        with open(file_path, "r", encoding="utf-8") as file:
            text_data = file.read()
    else:
        with urllib.request.urlopen(url) as response:
            text_data = response.read().decode('utf-8')
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)

    # ---- Model and optimizer ----
    model = GPTModel(gpt_config)
    model.to(device)  # nn.Module.to moves parameters in place; no reassignment needed
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=settings["learning_rate"],
        weight_decay=settings["weight_decay"],
    )

    # ---- Dataloaders: 90% / 10% character split, non-overlapping windows ----
    split_idx = int(0.90 * len(text_data))
    window = gpt_config["context_length"]
    loader_kwargs = dict(
        batch_size=settings["batch_size"],
        max_length=window,
        stride=window,
        num_workers=4,
    )
    train_loader = create_dataloader_v1(
        text_data[:split_idx], drop_last=True, shuffle=True, **loader_kwargs
    )
    val_loader = create_dataloader_v1(
        text_data[split_idx:], drop_last=False, shuffle=False, **loader_kwargs
    )

    # ---- Training ----
    tokenizer = tiktoken.get_encoding("gpt2")
    train_losses, val_losses, tokens_seen = train_model_simple_with_timing(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        device=device,
        num_epochs=settings["num_epochs"],
        eval_freq=15,
        eval_iter=1,
        start_context="Every effort moves you",
        tokenizer=tokenizer,
    )
    return train_losses, val_losses, tokens_seen, model
if __name__ == "__main__":
    GPT_CONFIG_124M = {
        "vocab_size": 50257,     # vocabulary size
        "context_length": 1024,  # input tokens per training example
        "emb_dim": 768,          # embedding dimension
        "n_heads": 12,           # attention heads per block
        "n_layers": 12,          # number of transformer blocks
        "drop_rate": 0.1,        # dropout probability
        "qkv_bias": False,       # bias in the query/key/value projections
    }

    OTHER_SETTINGS = {
        "learning_rate": 5e-4,
        "num_epochs": 15,
        "batch_size": 8,
        "weight_decay": 0.1,
    }

    # ---- Train ----
    train_losses, val_losses, tokens_seen, model = main(GPT_CONFIG_124M, OTHER_SETTINGS)

    # ---- Plot: spread the recorded losses evenly across the epoch range ----
    epochs_tensor = torch.linspace(0, OTHER_SETTINGS["num_epochs"], len(train_losses))
    plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)
    plt.savefig("loss.pdf")

    # To save the trained weights and load them back later:
    # torch.save(model.state_dict(), "model.pth")
    # model = GPTModel(GPT_CONFIG_124M)
    # model.load_state_dict(torch.load("model.pth", weights_only=True))
|