# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# Code to run the exercises; see exercise-solutions.ipynb for more information
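#
# Example usage (assuming this file is saved as exercise_experiments.py):
#   python exercise_experiments.py --exercise_solution baseline
# Other --exercise_solution options: mask_instructions, alpaca_52k, phi3_prompt, lora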

from functools import partial
from importlib.metadata import version
import json
import math
import os
import re
import time
import urllib.request

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import tiktoken
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Import from local files in this folder
from gpt_download import download_and_load_gpt2
from previous_chapters import (
    calc_loss_loader,
    generate,
    GPTModel,
    load_weights_into_gpt,
    text_to_token_ids,
    token_ids_to_text,
    train_model_simple,
)


class InstructionDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data

        # Pre-tokenize texts
        self.encoded_texts = []
        for entry in data:
            instruction_plus_input = format_input(entry)
            response_text = f"\n\n### Response:\n{entry['output']}"
            full_text = instruction_plus_input + response_text
            self.encoded_texts.append(
                tokenizer.encode(full_text)
            )

    def __getitem__(self, index):
        return self.encoded_texts[index]

    def __len__(self):
        return len(self.data)


class InstructionDatasetWithMasking(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data

        # New: Separate list for instruction lengths
        self.instruction_lengths = []
        self.encoded_texts = []

        for entry in data:
            instruction_plus_input = format_input(entry)
            response_text = f"\n\n### Response:\n{entry['output']}"
            full_text = instruction_plus_input + response_text
            self.encoded_texts.append(
                tokenizer.encode(full_text)
            )
            # New: collect instruction lengths
            instruction_length = len(tokenizer.encode(instruction_plus_input))
            self.instruction_lengths.append(instruction_length)

    def __getitem__(self, index):
        # New: return both instruction lengths and texts separately
        return self.instruction_lengths[index], self.encoded_texts[index]

    def __len__(self):
        return len(self.data)
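
# Note: unlike InstructionDataset above, each item of this dataset is an
# (instruction_length, token_ids) tuple rather than a plain token list;
# custom_collate_with_masking_fn below uses the stored length to exclude
# the instruction portion from the loss.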


class InstructionDatasetPhi(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data

        # Pre-tokenize texts
        self.encoded_texts = []
        for entry in data:
            ###################################################################
            # NEW: Use `format_input_phi` and adjust the response text template
            instruction_plus_input = format_input_phi(entry)
            response_text = f"\n<|assistant|>:\n{entry['output']}"
            ###################################################################
            full_text = instruction_plus_input + response_text
            self.encoded_texts.append(
                tokenizer.encode(full_text)
            )

    def __getitem__(self, index):
        return self.encoded_texts[index]

    def __len__(self):
        return len(self.data)


class LinearWithLoRA(torch.nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )

    def forward(self, x):
        return self.linear(x) + self.lora(x)


class LoRALayer(torch.nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        self.A = torch.nn.Parameter(torch.empty(in_dim, rank))
        torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))  # similar to standard weight initialization
        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        x = self.alpha * (x @ self.A @ self.B)
        return x
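
# The combined LinearWithLoRA forward pass above computes
#
#   h = self.linear(x) + alpha * (x @ A @ B)
#
# Because B is initialized to zeros, the LoRA branch contributes nothing at
# the start of training, so the model initially behaves exactly like the
# pretrained model. (Note: this implementation scales by alpha directly;
# the original LoRA paper uses an alpha/rank scaling factor.)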


def replace_linear_with_lora(model, rank, alpha):
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Linear):
            # Replace the Linear layer with LinearWithLoRA
            setattr(model, name, LinearWithLoRA(module, rank, alpha))
        else:
            # Recursively apply the same function to child modules
            replace_linear_with_lora(module, rank, alpha)
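
# Minimal usage sketch (mirrors the steps performed in main() below):
#
#   for param in model.parameters():
#       param.requires_grad = False  # freeze all pretrained weights
#   replace_linear_with_lora(model, rank=16, alpha=16)
#   # afterwards, only the LoRA A and B matrices are trainable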


def custom_collate_fn(
    batch,
    pad_token_id=50256,
    ignore_index=-100,
    allowed_max_length=None,
    device="cpu"
):
    # Find the longest sequence in the batch
    batch_max_length = max(len(item)+1 for item in batch)

    # Pad and prepare inputs and targets
    inputs_lst, targets_lst = [], []

    for item in batch:
        new_item = item.copy()
        # Add an <|endoftext|> token
        new_item += [pad_token_id]
        # Pad sequences to batch_max_length
        padded = new_item + [pad_token_id] * (batch_max_length - len(new_item))
        inputs = torch.tensor(padded[:-1])  # Truncate the last token for inputs
        targets = torch.tensor(padded[1:])  # Shift +1 to the right for targets

        # New: Replace all but the first padding token in targets by ignore_index
        mask = targets == pad_token_id
        indices = torch.nonzero(mask).squeeze()
        if indices.numel() > 1:
            targets[indices[1:]] = ignore_index

        # New: Optionally truncate to maximum sequence length
        if allowed_max_length is not None:
            inputs = inputs[:allowed_max_length]
            targets = targets[:allowed_max_length]

        inputs_lst.append(inputs)
        targets_lst.append(targets)

    # Convert lists of inputs and targets to tensors and transfer to the target device
    inputs_tensor = torch.stack(inputs_lst).to(device)
    targets_tensor = torch.stack(targets_lst).to(device)

    return inputs_tensor, targets_tensor
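
# Worked example (illustrative values, using the defaults above): the batch
# [[1, 2, 3], [4, 5]] is padded to the longest length plus one and becomes
#
#   inputs  = [[1, 2, 3],     [4, 5, 50256]]
#   targets = [[2, 3, 50256], [5, 50256, -100]]
#
# i.e., exactly one <|endoftext|> token per sequence remains a prediction
# target, while the remaining padding positions are set to ignore_index so
# the cross entropy loss skips them.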


def custom_collate_with_masking_fn(
    batch,
    pad_token_id=50256,
    ignore_index=-100,
    allowed_max_length=None,
    device="cpu"
):
    # Find the longest sequence in the batch
    batch_max_length = max(len(item)+1 for instruction_length, item in batch)  # New: batch entries are now tuples

    # Pad and prepare inputs and targets
    inputs_lst, targets_lst = [], []

    for instruction_length, item in batch:  # New: batch entries are now tuples
        new_item = item.copy()
        # Add an <|endoftext|> token
        new_item += [pad_token_id]
        # Pad sequences to batch_max_length
        padded = new_item + [pad_token_id] * (batch_max_length - len(new_item))
        inputs = torch.tensor(padded[:-1])  # Truncate the last token for inputs
        targets = torch.tensor(padded[1:])  # Shift +1 to the right for targets

        # Replace all but the first padding token in targets by ignore_index
        mask = targets == pad_token_id
        indices = torch.nonzero(mask).squeeze()
        if indices.numel() > 1:
            targets[indices[1:]] = ignore_index

        # New: Mask all input and instruction tokens in the targets
        targets[:instruction_length-1] = ignore_index

        # Optionally truncate to maximum sequence length
        if allowed_max_length is not None:
            inputs = inputs[:allowed_max_length]
            targets = targets[:allowed_max_length]

        inputs_lst.append(inputs)
        targets_lst.append(targets)

    # Convert lists of inputs and targets to tensors and transfer to the target device
    inputs_tensor = torch.stack(inputs_lst).to(device)
    targets_tensor = torch.stack(targets_lst).to(device)

    return inputs_tensor, targets_tensor
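
# Worked example (illustrative values): a single batch entry
# (instruction_length=2, [1, 2, 3, 4]) is collated into
#
#   inputs  = [1, 2, 3, 4]
#   targets = [-100, 3, 4, 50256]
#
# i.e., in addition to the padding handling above, the target positions that
# correspond to the instruction tokens are set to ignore_index, so the loss
# is computed only over the response tokens.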


def download_and_load_file(file_path, url):
    if not os.path.exists(file_path):
        with urllib.request.urlopen(url) as response:
            text_data = response.read().decode("utf-8")
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)
    else:
        with open(file_path, "r", encoding="utf-8") as file:
            text_data = file.read()

    data = json.loads(text_data)
    return data


def format_input_phi(entry):
    instruction_text = (
        f"<|user|>\n{entry['instruction']}"
    )

    input_text = f"\n{entry['input']}" if entry["input"] else ""
    return instruction_text + input_text
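
# Example (illustrative entry): for
# {"instruction": "Convert the text to uppercase.", "input": "hello", "output": "HELLO"}
# format_input_phi returns
#
#   <|user|>
#   Convert the text to uppercase.
#   hello
#
# and InstructionDatasetPhi appends "\n<|assistant|>:\nHELLO" as the response.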


def format_input(entry):
    instruction_text = (
        f"Below is an instruction that describes a task. "
        f"Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    )

    input_text = f"\n\n### Input:\n{entry['input']}" if entry["input"] else ""
    return instruction_text + input_text
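
# Example (same illustrative entry as above): format_input returns
#
#   Below is an instruction that describes a task. Write a response that appropriately completes the request.
#
#   ### Instruction:
#   Convert the text to uppercase.
#
#   ### Input:
#   hello
#
# and the datasets append "\n\n### Response:\nHELLO" before tokenization.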


def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses, plot_name):
    fig, ax1 = plt.subplots(figsize=(12, 6))

    # Plot training and validation loss against epochs
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")
    ax1.xaxis.set_major_locator(MaxNLocator(integer=True))  # only show integer labels on x-axis

    # Create a second x-axis for tokens seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(tokens_seen, train_losses, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Tokens seen")

    fig.tight_layout()  # Adjust layout to make room
    plt.savefig(plot_name)
    print(f"Plot saved as {plot_name}")
    # plt.show()


def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False, lora=False):
    #######################################
    # Print package versions
    #######################################
    print()
    pkgs = [
        "matplotlib",  # Plotting library
        "tiktoken",    # Tokenizer
        "torch",       # Deep learning library
        "tqdm",        # Progress bar
        "tensorflow",  # For OpenAI's pretrained weights
    ]
    for p in pkgs:
        print(f"{p} version: {version(p)}")
    print(50*"-")

    #######################################
    # Download and prepare dataset
    #######################################
    file_path = "instruction-data.json"

    if alpaca52k:
        url = "https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json"
    else:
        url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch07/01_main-chapter-code/instruction-data.json"
    data = download_and_load_file(file_path, url)

    train_portion = int(len(data) * 0.85)  # 85% for training
    test_portion = int(len(data) * 0.1)    # 10% for testing

    train_data = data[:train_portion]
    test_data = data[train_portion:train_portion + test_portion]
    val_data = data[train_portion + test_portion:]  # remaining ~5% for validation

    print("Training set length:", len(train_data))
    print("Validation set length:", len(val_data))
    print("Test set length:", len(test_data))
    print(50*"-")

    tokenizer = tiktoken.get_encoding("gpt2")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)
    print(50*"-")
    if alpaca52k:
        allowed_max_length = 512
    else:
        allowed_max_length = 1024

    if mask_instructions and phi3_prompt:
        raise ValueError("Simultaneous support for instruction masking and the Phi-3 prompt template has not been implemented yet.")

    if mask_instructions:
        customized_collate_fn = partial(custom_collate_with_masking_fn, device=device, allowed_max_length=allowed_max_length)
        CustomDataset = InstructionDatasetWithMasking
    elif phi3_prompt:
        customized_collate_fn = partial(custom_collate_fn, device=device, allowed_max_length=allowed_max_length)
        CustomDataset = InstructionDatasetPhi
    else:
        customized_collate_fn = partial(custom_collate_fn, device=device, allowed_max_length=allowed_max_length)
        CustomDataset = InstructionDataset

    num_workers = 0

    if alpaca52k:
        batch_size = 4
    else:
        batch_size = 8

    torch.manual_seed(123)

    train_dataset = CustomDataset(train_data, tokenizer)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        collate_fn=customized_collate_fn,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers
    )

    val_dataset = CustomDataset(val_data, tokenizer)
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        collate_fn=customized_collate_fn,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers
    )

    #######################################
    # Load pretrained model
    #######################################
    BASE_CONFIG = {
        "vocab_size": 50257,     # Vocabulary size
        "context_length": 1024,  # Context length
        "drop_rate": 0.0,        # Dropout rate
        "qkv_bias": True         # Query-key-value bias
    }

    model_configs = {
        "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
        "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
        "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
        "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
    }

    CHOOSE_MODEL = "gpt2-medium (355M)"

    BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

    model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
    settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")

    model = GPTModel(BASE_CONFIG)
    load_weights_into_gpt(model, params)
    model.eval()
    model.to(device)
    print("Loaded model:", CHOOSE_MODEL)
    print(50*"-")

    if lora:
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Total trainable parameters before: {total_params:,}")

        for param in model.parameters():
            param.requires_grad = False

        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Total trainable parameters after: {total_params:,}")
        replace_linear_with_lora(model, rank=16, alpha=16)

        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Total trainable LoRA parameters: {total_params:,}")
        model.to(device)

    #######################################
    # Finetuning the model
    #######################################
    print("Initial losses")
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=5)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=5)

    print("   Training loss:", train_loss)
    print("   Validation loss:", val_loss)

    start_time = time.time()

    num_epochs = 2
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.00005, weight_decay=0.1)

    torch.manual_seed(123)

    start_context = format_input_phi(val_data[0]) if phi3_prompt else format_input(val_data[0])

    train_losses, val_losses, tokens_seen = train_model_simple(
        model, train_loader, val_loader, optimizer, device,
        num_epochs=num_epochs, eval_freq=5, eval_iter=5,
        start_context=start_context, tokenizer=tokenizer
    )

    end_time = time.time()
    execution_time_minutes = (end_time - start_time) / 60
    print(f"Training completed in {execution_time_minutes:.2f} minutes.")

    epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))

    plot_name = "loss-plot.pdf"
    if mask_instructions:
        plot_name = plot_name.replace(".pdf", "-mask-instructions.pdf")
    if alpaca52k:
        plot_name = plot_name.replace(".pdf", "-alpaca52k.pdf")
    if phi3_prompt:
        plot_name = plot_name.replace(".pdf", "-phi3-prompt.pdf")
    if lora:
        plot_name = plot_name.replace(".pdf", "-lora.pdf")
    if not any([mask_instructions, alpaca52k, phi3_prompt, lora]):
        plot_name = plot_name.replace(".pdf", "-baseline.pdf")

    plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses, plot_name)
    print(50*"-")

    #######################################
    # Saving results
    #######################################
    print("Generating responses")
    for i, entry in tqdm(enumerate(test_data), total=len(test_data)):
        input_text = format_input_phi(entry) if phi3_prompt else format_input(entry)

        token_ids = generate(
            model=model,
            idx=text_to_token_ids(input_text, tokenizer).to(device),
            max_new_tokens=256,
            context_size=BASE_CONFIG["context_length"],
            eos_id=50256
        )
        generated_text = token_ids_to_text(token_ids, tokenizer)

        if phi3_prompt:
            response_text = generated_text[len(input_text):].replace("<|assistant|>:", "").strip()
        else:
            response_text = generated_text[len(input_text):].replace("### Response:", "").strip()

        test_data[i]["model_response"] = response_text

    test_data_path = "instruction-data-with-response.json"
    file_name = f"{re.sub(r'[ ()]', '', CHOOSE_MODEL)}-sft.pth"

    if mask_instructions:
        test_data_path = test_data_path.replace(".json", "-mask-instructions.json")
        file_name = file_name.replace(".pth", "-mask-instructions.pth")
    if alpaca52k:
        test_data_path = test_data_path.replace(".json", "-alpaca52k.json")
        file_name = file_name.replace(".pth", "-alpaca52k.pth")
    if phi3_prompt:
        test_data_path = test_data_path.replace(".json", "-phi3-prompt.json")
        file_name = file_name.replace(".pth", "-phi3-prompt.pth")
    if lora:
        test_data_path = test_data_path.replace(".json", "-lora.json")
        file_name = file_name.replace(".pth", "-lora.pth")
    if not any([mask_instructions, alpaca52k, phi3_prompt, lora]):
        test_data_path = test_data_path.replace(".json", "-baseline.json")
        file_name = file_name.replace(".pth", "-baseline.pth")

    with open(test_data_path, "w") as file:
        json.dump(test_data, file, indent=4)  # "indent" for pretty-printing
    print(f"Responses saved as {test_data_path}")

    torch.save(model.state_dict(), file_name)
    print(f"Model saved as {file_name}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Instruction finetune a GPT model"
    )
    options = {"baseline", "mask_instructions", "alpaca_52k", "phi3_prompt", "lora"}
    parser.add_argument(
        "--exercise_solution",
        type=str,
        default="baseline",
        help=(
            f"Which experiment to run. Options: {options}."
        )
    )
    args = parser.parse_args()

    if args.exercise_solution == "baseline":
        main()
    elif args.exercise_solution == "mask_instructions":
        main(mask_instructions=True)
    elif args.exercise_solution == "alpaca_52k":
        main(alpaca52k=True)
    elif args.exercise_solution == "phi3_prompt":
        main(phi3_prompt=True)
    elif args.exercise_solution == "lora":
        main(lora=True)
    else:
        raise ValueError(f"{args.exercise_solution} is not a valid --exercise_solution option. Options: {options}")