gpt_instruction_finetuning.py

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# A minimal instruction finetuning file based on the code in chapter 7

from functools import partial
from importlib.metadata import version
import json
import os
import re
import time
import urllib.request

import matplotlib.pyplot as plt
import tiktoken
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Import from local files in this folder
from gpt_download import download_and_load_gpt2
from previous_chapters import (
    calc_loss_loader,
    generate,
    GPTModel,
    load_weights_into_gpt,
    text_to_token_ids,
    train_model_simple,
    token_ids_to_text
)


class InstructionDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data

        # Pre-tokenize texts
        self.encoded_texts = []
        for entry in data:
            instruction_plus_input = format_input(entry)
            response_text = f"\n\n### Response:\n{entry['output']}"
            full_text = instruction_plus_input + response_text
            self.encoded_texts.append(
                tokenizer.encode(full_text)
            )

    def __getitem__(self, index):
        return self.encoded_texts[index]

    def __len__(self):
        return len(self.data)
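
# Note: __getitem__ returns a plain Python list of token IDs whose length
# varies per example; batching and padding are deferred to custom_collate_fn
# below.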


def custom_collate_fn(
    batch,
    pad_token_id=50256,
    ignore_index=-100,
    allowed_max_length=None,
    device="cpu"
):
    # Find the longest sequence in the batch
    batch_max_length = max(len(item)+1 for item in batch)

    # Pad and prepare inputs and targets
    inputs_lst, targets_lst = [], []

    for item in batch:
        new_item = item.copy()
        # Add an <|endoftext|> token
        new_item += [pad_token_id]
        # Pad sequences to batch_max_length
        padded = new_item + [pad_token_id] * (batch_max_length - len(new_item))
        inputs = torch.tensor(padded[:-1])  # Truncate the last token for inputs
        targets = torch.tensor(padded[1:])  # Shift +1 to the right for targets

        # Replace all but the first padding token in targets by ignore_index
        mask = targets == pad_token_id
        indices = torch.nonzero(mask).squeeze()
        if indices.numel() > 1:
            targets[indices[1:]] = ignore_index

        # Optionally truncate to maximum sequence length
        if allowed_max_length is not None:
            inputs = inputs[:allowed_max_length]
            targets = targets[:allowed_max_length]

        inputs_lst.append(inputs)
        targets_lst.append(targets)

    # Convert list of inputs and targets to tensors and transfer to target device
    inputs_tensor = torch.stack(inputs_lst).to(device)
    targets_tensor = torch.stack(targets_lst).to(device)

    return inputs_tensor, targets_tensor
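
# A minimal sketch of what custom_collate_fn produces (illustrative token IDs,
# not from the original script). For a batch of two ragged token-ID lists,
# inputs are right-padded with 50256 (<|endoftext|>) and targets are the inputs
# shifted by one position, with every pad after the first replaced by -100 so
# that PyTorch's cross-entropy loss (ignore_index=-100) skips them:
#
#     batch = [[1, 2, 3], [4, 5]]
#     inputs, targets = custom_collate_fn(batch)
#     # inputs:  tensor([[1, 2,     3],
#     #                  [4, 5, 50256]])
#     # targets: tensor([[2,     3, 50256],
#     #                  [5, 50256,  -100]])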


def download_and_load_file(file_path, url):
    if not os.path.exists(file_path):
        with urllib.request.urlopen(url) as response:
            text_data = response.read().decode("utf-8")
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)

    with open(file_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    return data
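
# The downloaded JSON file is a list of dicts, each with "instruction",
# "input", and "output" string fields (the "input" field may be empty),
# matching what InstructionDataset above and format_input() below expect.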


def format_input(entry):
    instruction_text = (
        f"Below is an instruction that describes a task. "
        f"Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    )

    input_text = f"\n\n### Input:\n{entry['input']}" if entry["input"] else ""

    return instruction_text + input_text
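
# For illustration (a hypothetical entry, not taken from the dataset file), an
# entry such as
#     {"instruction": "Identify the verb.", "input": "The dog barked."}
# is rendered by format_input() as the Alpaca-style prompt:
#
#     Below is an instruction that describes a task. Write a response that appropriately completes the request.
#
#     ### Instruction:
#     Identify the verb.
#
#     ### Input:
#     The dog barked.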


def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
    fig, ax1 = plt.subplots(figsize=(12, 6))

    # Plot training and validation loss against epochs
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")

    # Create a second x-axis for tokens seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(tokens_seen, train_losses, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Tokens seen")

    fig.tight_layout()  # Adjust layout to make room
    plot_name = "loss-plot-standalone.pdf"
    print(f"Plot saved as {plot_name}")
    plt.savefig(plot_name)
    # plt.show()


def main(test_mode=False):
    #######################################
    # Print package versions
    #######################################
    print()
    pkgs = [
        "matplotlib",  # Plotting library
        "tiktoken",    # Tokenizer
        "torch",       # Deep learning library
        "tqdm",        # Progress bar
        "tensorflow",  # For OpenAI's pretrained weights
    ]
    for p in pkgs:
        print(f"{p} version: {version(p)}")
    print(50*"-")

    #######################################
    # Download and prepare dataset
    #######################################
    file_path = "instruction-data.json"
    url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch07/01_main-chapter-code/instruction-data.json"
    data = download_and_load_file(file_path, url)

    train_portion = int(len(data) * 0.85)  # 85% for training
    test_portion = int(len(data) * 0.1)    # 10% for testing

    train_data = data[:train_portion]
    test_data = data[train_portion:train_portion + test_portion]
    val_data = data[train_portion + test_portion:]  # Remaining 5% for validation

    # Use a very small subset for testing purposes
    if test_mode:
        train_data = train_data[:10]
        val_data = val_data[:10]
        test_data = test_data[:10]

    print("Training set length:", len(train_data))
    print("Validation set length:", len(val_data))
    print("Test set length:", len(test_data))
    print(50*"-")

    tokenizer = tiktoken.get_encoding("gpt2")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)
    print(50*"-")

    customized_collate_fn = partial(custom_collate_fn, device=device, allowed_max_length=1024)

    num_workers = 0
    batch_size = 8

    torch.manual_seed(123)

    train_dataset = InstructionDataset(train_data, tokenizer)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        collate_fn=customized_collate_fn,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers
    )

    val_dataset = InstructionDataset(val_data, tokenizer)
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        collate_fn=customized_collate_fn,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers
    )
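
    # A quick sanity check one might add here (a sketch, not in the original
    # script): padding is applied per batch, so the sequence dimension varies
    # between batches while inputs and targets always share the same shape.
    #
    #     inputs, targets = next(iter(train_loader))
    #     print(inputs.shape, targets.shape)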

    #######################################
    # Load pretrained model
    #######################################

    # Small GPT model for testing purposes
    if test_mode:
        BASE_CONFIG = {
            "vocab_size": 50257,
            "context_length": 120,
            "drop_rate": 0.0,
            "qkv_bias": False,
            "emb_dim": 12,
            "n_layers": 1,
            "n_heads": 2
        }
        model = GPTModel(BASE_CONFIG)
        model.eval()
        device = "cpu"
        CHOOSE_MODEL = "Small test model"

    # Code as it is used in the main chapter
    else:
        BASE_CONFIG = {
            "vocab_size": 50257,     # Vocabulary size
            "context_length": 1024,  # Context length
            "drop_rate": 0.0,        # Dropout rate
            "qkv_bias": True         # Query-key-value bias
        }
        model_configs = {
            "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
            "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
            "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
            "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
        }

        CHOOSE_MODEL = "gpt2-medium (355M)"

        BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

        model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
        settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")

        model = GPTModel(BASE_CONFIG)
        load_weights_into_gpt(model, params)
        model.eval()

    model.to(device)

    print("Loaded model:", CHOOSE_MODEL)
    print(50*"-")

    #######################################
    # Finetuning the model
    #######################################
    print("Initial losses")
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=5)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=5)

    print(" Training loss:", train_loss)
    print(" Validation loss:", val_loss)

    start_time = time.time()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.00005, weight_decay=0.1)

    num_epochs = 2

    torch.manual_seed(123)
    train_losses, val_losses, tokens_seen = train_model_simple(
        model, train_loader, val_loader, optimizer, device,
        num_epochs=num_epochs, eval_freq=5, eval_iter=5,
        start_context=format_input(val_data[0]), tokenizer=tokenizer
    )

    end_time = time.time()
    execution_time_minutes = (end_time - start_time) / 60
    print(f"Training completed in {execution_time_minutes:.2f} minutes.")

    epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
    plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)
    print(50*"-")

    #######################################
    # Saving results
    #######################################
    print("Generating responses")
    for i, entry in tqdm(enumerate(test_data), total=len(test_data)):
        input_text = format_input(entry)

        token_ids = generate(
            model=model,
            idx=text_to_token_ids(input_text, tokenizer).to(device),
            max_new_tokens=256,
            context_size=BASE_CONFIG["context_length"],
            eos_id=50256
        )
        generated_text = token_ids_to_text(token_ids, tokenizer)

        # Keep only the newly generated part and strip the response header
        response_text = generated_text[len(input_text):].replace("### Response:", "").strip()
        test_data[i]["model_response"] = response_text

    test_data_path = "instruction-data-with-response-standalone.json"
    with open(test_data_path, "w") as file:
        json.dump(test_data, file, indent=4)  # "indent" for pretty-printing
    print(f"Responses saved as {test_data_path}")

    file_name = f"{re.sub(r'[ ()]', '', CHOOSE_MODEL)}-sft-standalone.pth"
    torch.save(model.state_dict(), file_name)
    print(f"Model saved as {file_name}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Finetune a GPT model for instruction following"
    )
    parser.add_argument(
        "--test_mode",
        default=False,
        action="store_true",
        help=("This flag runs the model in test mode for internal testing purposes. "
              "Otherwise, it runs the model as it is used in the chapter (recommended).")
    )
    args = parser.parse_args()

    main(args.test_mode)
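
# Example usage (assuming gpt_download.py and previous_chapters.py from the
# book's repository sit in the same folder):
#
#     python gpt_instruction_finetuning.py              # full chapter-style run
#     python gpt_instruction_finetuning.py --test_mode  # quick smoke test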