# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# A minimal instruction finetuning file based on the code in chapter 7

from functools import partial
from importlib.metadata import version
import json
import os
import re
import time
import urllib.request

import matplotlib.pyplot as plt
import tiktoken
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Import from local files in this folder
from gpt_download import download_and_load_gpt2
from previous_chapters import (
    calc_loss_loader,
    generate,
    GPTModel,
    load_weights_into_gpt,
    text_to_token_ids,
    train_model_simple,
    token_ids_to_text
)


class InstructionDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data

        # Pre-tokenize texts
        self.encoded_texts = []
        for entry in data:
            instruction_plus_input = format_input(entry)
            response_text = f"\n\n### Response:\n{entry['output']}"
            full_text = instruction_plus_input + response_text
            self.encoded_texts.append(
                tokenizer.encode(full_text)
            )

    def __getitem__(self, index):
        return self.encoded_texts[index]

    def __len__(self):
        return len(self.data)

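
# Note: each dataset item is a plain Python list of token IDs for the full
# prompt-plus-response text. Padding to a common length and the input/target
# shift are deferred to custom_collate_fn below, so sequence lengths may vary
# from example to example here.
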
def custom_collate_fn(
    batch,
    pad_token_id=50256,
    ignore_index=-100,
    allowed_max_length=None,
    device="cpu"
):
    # Find the longest sequence in the batch
    batch_max_length = max(len(item)+1 for item in batch)

    # Pad and prepare inputs and targets
    inputs_lst, targets_lst = [], []

    for item in batch:
        new_item = item.copy()
        # Add an <|endoftext|> token
        new_item += [pad_token_id]
        # Pad sequences to max_length
        padded = new_item + [pad_token_id] * (batch_max_length - len(new_item))
        inputs = torch.tensor(padded[:-1])  # Truncate the last token for inputs
        targets = torch.tensor(padded[1:])  # Shift +1 to the right for targets

        # New: Replace all but the first padding tokens in targets by ignore_index
        mask = targets == pad_token_id
        indices = torch.nonzero(mask).squeeze()
        if indices.numel() > 1:
            targets[indices[1:]] = ignore_index

        # New: Optionally truncate to maximum sequence length
        if allowed_max_length is not None:
            inputs = inputs[:allowed_max_length]
            targets = targets[:allowed_max_length]

        inputs_lst.append(inputs)
        targets_lst.append(targets)

    # Convert list of inputs and targets to tensors and transfer to target device
    inputs_tensor = torch.stack(inputs_lst).to(device)
    targets_tensor = torch.stack(targets_lst).to(device)

    return inputs_tensor, targets_tensor

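
# A minimal sketch of what custom_collate_fn produces for a toy batch of
# token-ID lists, assuming the default pad_token_id=50256 and ignore_index=-100:
#
#   batch = [[1, 2, 3], [4, 5]]           # batch_max_length = 4
#   inputs  -> [[1, 2,     3], [4,     5, 50256]]
#   targets -> [[2, 3, 50256], [5, 50256,  -100]]
#
# The first padding token in each target row is kept so the model learns to
# emit <|endoftext|>; any additional padding is set to -100, which
# torch.nn.functional.cross_entropy ignores by default.
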
def download_and_load_file(file_path, url):
    # Download the file only if it is not already cached locally
    if not os.path.exists(file_path):
        with urllib.request.urlopen(url) as response:
            text_data = response.read().decode("utf-8")
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)

    with open(file_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    return data

def format_input(entry):
    instruction_text = (
        f"Below is an instruction that describes a task. "
        f"Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    )

    input_text = f"\n\n### Input:\n{entry['input']}" if entry["input"] else ""

    return instruction_text + input_text

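
# For reference, format_input renders the Alpaca-style prompt shown below
# (the "### Input:" block is omitted when entry["input"] is empty), and
# InstructionDataset appends "\n\n### Response:\n..." to it during training:
#
#   Below is an instruction that describes a task. Write a response that
#   appropriately completes the request.
#
#   ### Instruction:
#   <instruction text>
#
#   ### Input:
#   <input text>
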
def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
    fig, ax1 = plt.subplots(figsize=(12, 6))

    # Plot training and validation loss against epochs
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")

    # Create a second x-axis for tokens seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(tokens_seen, train_losses, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Tokens seen")

    fig.tight_layout()  # Adjust layout to make room

    plot_name = "loss-plot-standalone.pdf"
    print(f"Plot saved as {plot_name}")
    plt.savefig(plot_name)
    # plt.show()


def main(test_mode=False):
    #######################################
    # Print package versions
    #######################################
    print()
    pkgs = [
        "matplotlib",  # Plotting library
        "tiktoken",    # Tokenizer
        "torch",       # Deep learning library
        "tqdm",        # Progress bar
        "tensorflow",  # For OpenAI's pretrained weights
    ]
    for p in pkgs:
        print(f"{p} version: {version(p)}")
    print(50*"-")

    #######################################
    # Download and prepare dataset
    #######################################
    file_path = "instruction-data.json"
    url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch07/01_main-chapter-code/instruction-data.json"
    data = download_and_load_file(file_path, url)

    train_portion = int(len(data) * 0.85)  # 85% for training
    test_portion = int(len(data) * 0.1)    # 10% for testing

    train_data = data[:train_portion]
    test_data = data[train_portion:train_portion + test_portion]
    val_data = data[train_portion + test_portion:]  # Remaining 5% for validation

    # Use very small subset for testing purposes
    if test_mode:
        train_data = train_data[:10]
        val_data = val_data[:10]
        test_data = test_data[:10]

    print("Training set length:", len(train_data))
    print("Validation set length:", len(val_data))
    print("Test set length:", len(test_data))
    print(50*"-")

    tokenizer = tiktoken.get_encoding("gpt2")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)
    print(50*"-")

    customized_collate_fn = partial(custom_collate_fn, device=device, allowed_max_length=1024)

    num_workers = 0
    batch_size = 8

    torch.manual_seed(123)

    train_dataset = InstructionDataset(train_data, tokenizer)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        collate_fn=customized_collate_fn,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers
    )

    val_dataset = InstructionDataset(val_data, tokenizer)
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        collate_fn=customized_collate_fn,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers
    )

    #######################################
    # Load pretrained model
    #######################################

    # Small GPT model for testing purposes
    if test_mode:
        BASE_CONFIG = {
            "vocab_size": 50257,
            "context_length": 120,
            "drop_rate": 0.0,
            "qkv_bias": False,
            "emb_dim": 12,
            "n_layers": 1,
            "n_heads": 2
        }
        model = GPTModel(BASE_CONFIG)
        model.eval()
        device = "cpu"
        CHOOSE_MODEL = "Small test model"

    # Code as it is used in the main chapter
    else:
        BASE_CONFIG = {
            "vocab_size": 50257,     # Vocabulary size
            "context_length": 1024,  # Context length
            "drop_rate": 0.0,        # Dropout rate
            "qkv_bias": True         # Query-key-value bias
        }

        model_configs = {
            "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
            "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
            "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
            "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
        }

        CHOOSE_MODEL = "gpt2-medium (355M)"

        BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

        model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
        settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")

        model = GPTModel(BASE_CONFIG)
        load_weights_into_gpt(model, params)
        model.eval()

    model.to(device)

    print("Loaded model:", CHOOSE_MODEL)
    print(50*"-")

    #######################################
    # Finetuning the model
    #######################################
    print("Initial losses")
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=5)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=5)

    print(" Training loss:", train_loss)
    print(" Validation loss:", val_loss)

    start_time = time.time()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.00005, weight_decay=0.1)

    num_epochs = 2

    torch.manual_seed(123)
    train_losses, val_losses, tokens_seen = train_model_simple(
        model, train_loader, val_loader, optimizer, device,
        num_epochs=num_epochs, eval_freq=5, eval_iter=5,
        start_context=format_input(val_data[0]), tokenizer=tokenizer
    )

    end_time = time.time()
    execution_time_minutes = (end_time - start_time) / 60
    print(f"Training completed in {execution_time_minutes:.2f} minutes.")

    epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
    plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)
    print(50*"-")

    #######################################
    # Saving results
    #######################################
    print("Generating responses")
    for i, entry in tqdm(enumerate(test_data), total=len(test_data)):
        input_text = format_input(entry)

        token_ids = generate(
            model=model,
            idx=text_to_token_ids(input_text, tokenizer).to(device),
            max_new_tokens=256,
            context_size=BASE_CONFIG["context_length"],
            eos_id=50256
        )
        generated_text = token_ids_to_text(token_ids, tokenizer)

        response_text = generated_text[len(input_text):].replace("### Response:", "").strip()
        test_data[i]["model_response"] = response_text

    test_data_path = "instruction-data-with-response-standalone.json"
    with open(test_data_path, "w") as file:
        json.dump(test_data, file, indent=4)  # "indent" for pretty-printing
    print(f"Responses saved as {test_data_path}")

    file_name = f"{re.sub(r'[ ()]', '', CHOOSE_MODEL)}-sft-standalone.pth"
    torch.save(model.state_dict(), file_name)
    print(f"Model saved as {file_name}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Instruction finetune a GPT model"
    )
    parser.add_argument(
        "--test_mode",
        default=False,
        action="store_true",
        help=("This flag runs the model in test mode for internal testing purposes. "
              "Otherwise, it runs the model as it is used in the chapter (recommended).")
    )
    args = parser.parse_args()

    main(args.test_mode)
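
# Example usage, assuming gpt_download.py, previous_chapters.py, and the
# packages listed in main() are available in the current environment:
#
#   python gpt_instruction_finetuning.py              # full chapter run
#   python gpt_instruction_finetuning.py --test_mode  # tiny model and data subsets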