# gpt_class_finetune.py

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

# This is a summary file containing the main takeaways from chapter 6.

import urllib.request
import zipfile
import os
from pathlib import Path
import time

import matplotlib.pyplot as plt
import pandas as pd
import tiktoken
import torch
from torch.utils.data import Dataset, DataLoader

from gpt_download import download_and_load_gpt2
from previous_chapters import GPTModel, load_weights_into_gpt


def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path, test_mode=False):
    if data_file_path.exists():
        print(f"{data_file_path} already exists. Skipping download and extraction.")
        return

    if test_mode:  # Try multiple times since CI sometimes has connectivity issues
        max_retries = 5
        delay = 5  # delay between retries in seconds
        for attempt in range(max_retries):
            try:
                # Downloading the file
                with urllib.request.urlopen(url, timeout=10) as response:
                    with open(zip_path, "wb") as out_file:
                        out_file.write(response.read())
                break  # if download is successful, break out of the loop
            except urllib.error.URLError as e:
                print(f"Attempt {attempt + 1} failed: {e}")
                if attempt < max_retries - 1:
                    time.sleep(delay)  # wait before retrying
                else:
                    print("Failed to download file after several attempts.")
                    return  # exit if all retries fail

    else:  # Code as it appears in the chapter
        # Downloading the file
        with urllib.request.urlopen(url) as response:
            with open(zip_path, "wb") as out_file:
                out_file.write(response.read())

    # Unzipping the file
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)

    # Add .tsv file extension
    original_file_path = Path(extracted_path) / "SMSSpamCollection"
    os.rename(original_file_path, data_file_path)
    print(f"File downloaded and saved as {data_file_path}")


def create_balanced_dataset(df):
    # Count the instances of "spam"
    num_spam = df[df["Label"] == "spam"].shape[0]

    # Randomly sample "ham" instances to match the number of "spam" instances
    ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)

    # Combine ham "subset" with "spam"
    balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])

    return balanced_df


def random_split(df, train_frac, validation_frac):
    # Shuffle the entire DataFrame
    df = df.sample(frac=1, random_state=123).reset_index(drop=True)

    # Calculate split indices
    train_end = int(len(df) * train_frac)
    validation_end = train_end + int(len(df) * validation_frac)

    # Split the DataFrame
    train_df = df[:train_end]
    validation_df = df[train_end:validation_end]
    test_df = df[validation_end:]

    return train_df, validation_df, test_df


class SpamDataset(Dataset):
    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
        self.data = pd.read_csv(csv_file)

        # Pre-tokenize texts
        self.encoded_texts = [
            tokenizer.encode(text) for text in self.data["Text"]
        ]

        if max_length is None:
            self.max_length = self._longest_encoded_length()
        else:
            self.max_length = max_length
            # Truncate sequences if they are longer than max_length
            self.encoded_texts = [
                encoded_text[:self.max_length]
                for encoded_text in self.encoded_texts
            ]

        # Pad sequences to the longest sequence
        self.encoded_texts = [
            encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
            for encoded_text in self.encoded_texts
        ]

    def __getitem__(self, index):
        encoded = self.encoded_texts[index]
        label = self.data.iloc[index]["Label"]
        return (
            torch.tensor(encoded, dtype=torch.long),
            torch.tensor(label, dtype=torch.long)
        )

    def __len__(self):
        return len(self.data)

    def _longest_encoded_length(self):
        max_length = 0
        for encoded_text in self.encoded_texts:
            encoded_length = len(encoded_text)
            if encoded_length > max_length:
                max_length = encoded_length
        return max_length


def calc_accuracy_loader(data_loader, model, device, num_batches=None):
    model.eval()
    correct_predictions, num_examples = 0, 0

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            input_batch, target_batch = input_batch.to(device), target_batch.to(device)

            with torch.no_grad():
                logits = model(input_batch)[:, -1, :]  # Logits of last output token
            predicted_labels = torch.argmax(logits, dim=-1)

            num_examples += predicted_labels.shape[0]
            correct_predictions += (predicted_labels == target_batch).sum().item()
        else:
            break
    return correct_predictions / num_examples


def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)[:, -1, :]  # Logits of last output token
    loss = torch.nn.functional.cross_entropy(logits, target_batch)
    return loss


def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches


def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss


def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                            eval_freq, eval_iter, tokenizer):
    # Initialize lists to track losses and examples seen
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    examples_seen, global_step = 0, -1

    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()  # Reset loss gradients from previous batch iteration
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()  # Calculate loss gradients
            optimizer.step()  # Update model weights using loss gradients
            examples_seen += input_batch.shape[0]  # New: track examples instead of tokens
            global_step += 1

            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # Calculate accuracy after each epoch
        train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
        val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
        print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
        print(f"Validation accuracy: {val_accuracy*100:.2f}%")
        train_accs.append(train_accuracy)
        val_accs.append(val_accuracy)

    return train_losses, val_losses, train_accs, val_accs, examples_seen


def plot_values(epochs_seen, examples_seen, train_values, val_values, label="loss"):
    fig, ax1 = plt.subplots(figsize=(5, 3))

    # Plot training and validation values against epochs
    ax1.plot(epochs_seen, train_values, label=f"Training {label}")
    ax1.plot(epochs_seen, val_values, linestyle="-.", label=f"Validation {label}")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel(label.capitalize())
    ax1.legend()

    # Create a second x-axis for examples seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(examples_seen, train_values, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Examples seen")

    fig.tight_layout()  # Adjust layout to make room
    plt.savefig(f"{label}-plot.pdf")
    # plt.show()


if __name__ == "__main__":

    import argparse

    parser = argparse.ArgumentParser(
        description="Finetune a GPT model for classification"
    )
    parser.add_argument(
        "--test_mode",
        default=False,
        action="store_true",
        help=("This flag runs the model in test mode for internal testing purposes. "
              "Otherwise, it runs the model as it is used in the chapter (recommended).")
    )
    args = parser.parse_args()

    ########################################
    # Download and prepare dataset
    ########################################

    url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
    zip_path = "sms_spam_collection.zip"
    extracted_path = "sms_spam_collection"
    data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"

    download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path, test_mode=args.test_mode)

    df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
    balanced_df = create_balanced_dataset(df)
    balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})

    train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)
    train_df.to_csv("train.csv", index=None)
    validation_df.to_csv("validation.csv", index=None)
    test_df.to_csv("test.csv", index=None)

    ########################################
    # Create data loaders
    ########################################
    tokenizer = tiktoken.get_encoding("gpt2")

    train_dataset = SpamDataset(
        csv_file="train.csv",
        max_length=None,
        tokenizer=tokenizer
    )

    val_dataset = SpamDataset(
        csv_file="validation.csv",
        max_length=train_dataset.max_length,
        tokenizer=tokenizer
    )

    test_dataset = SpamDataset(
        csv_file="test.csv",
        max_length=train_dataset.max_length,
        tokenizer=tokenizer
    )

    num_workers = 0
    batch_size = 8

    torch.manual_seed(123)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )

    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=False,
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=False,
    )
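
    # Optional sanity check (not part of the original chapter script): a minimal
    # sketch that inspects one batch to confirm the expected shapes. Each input
    # batch should be (batch_size, train_dataset.max_length) and each target
    # batch (batch_size,). Safe to remove.
    sample_inputs, sample_targets = next(iter(train_loader))
    print("Input batch shape:", sample_inputs.shape)
    print("Target batch shape:", sample_targets.shape)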

    ########################################
    # Load pretrained model
    ########################################

    # Small GPT model for testing purposes
    if args.test_mode:
        BASE_CONFIG = {
            "vocab_size": 50257,
            "context_length": 120,
            "drop_rate": 0.0,
            "qkv_bias": False,
            "emb_dim": 12,
            "n_layers": 1,
            "n_heads": 2
        }
        model = GPTModel(BASE_CONFIG)
        model.eval()
        device = "cpu"

    # Code as it is used in the main chapter
    else:
        CHOOSE_MODEL = "gpt2-small (124M)"
        INPUT_PROMPT = "Every effort moves"

        BASE_CONFIG = {
            "vocab_size": 50257,     # Vocabulary size
            "context_length": 1024,  # Context length
            "drop_rate": 0.0,        # Dropout rate
            "qkv_bias": True         # Query-key-value bias
        }

        model_configs = {
            "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
            "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
            "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
            "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
        }

        BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

        assert train_dataset.max_length <= BASE_CONFIG["context_length"], (
            f"Dataset length {train_dataset.max_length} exceeds model's context "
            f"length {BASE_CONFIG['context_length']}. Reinitialize data sets with "
            f"`max_length={BASE_CONFIG['context_length']}`"
        )

        model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
        settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")

        model = GPTModel(BASE_CONFIG)
        load_weights_into_gpt(model, params)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ########################################
    # Modify pretrained model
    ########################################

    for param in model.parameters():
        param.requires_grad = False

    torch.manual_seed(123)

    num_classes = 2
    model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes)
    model.to(device)

    for param in model.trf_blocks[-1].parameters():
        param.requires_grad = True

    for param in model.final_norm.parameters():
        param.requires_grad = True
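
    # Optional sanity check (not part of the original chapter script): report how
    # many parameters remain trainable after freezing. Only the new out_head, the
    # last transformer block, and the final LayerNorm should require gradients.
    # A minimal sketch for inspection; safe to remove.
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {trainable_params:,} of {total_params:,}")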

    ########################################
    # Finetune modified model
    ########################################

    start_time = time.time()
    torch.manual_seed(123)

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)

    num_epochs = 5
    train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
        model, train_loader, val_loader, optimizer, device,
        num_epochs=num_epochs, eval_freq=50, eval_iter=5,
        tokenizer=tokenizer
    )

    end_time = time.time()
    execution_time_minutes = (end_time - start_time) / 60
    print(f"Training completed in {execution_time_minutes:.2f} minutes.")

    ########################################
    # Plot results
    ########################################

    # loss plot
    epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
    examples_seen_tensor = torch.linspace(0, examples_seen, len(train_losses))
    plot_values(epochs_tensor, examples_seen_tensor, train_losses, val_losses)

    # accuracy plot
    epochs_tensor = torch.linspace(0, num_epochs, len(train_accs))
    examples_seen_tensor = torch.linspace(0, examples_seen, len(train_accs))
    plot_values(epochs_tensor, examples_seen_tensor, train_accs, val_accs, label="accuracy")
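
    ########################################
    # Optional follow-up (not part of the original chapter script)
    ########################################

    # A minimal sketch, assuming the label mapping used above (0 = "ham", 1 = "spam"):
    # evaluate the finetuned classifier on the held-out test set and classify one
    # made-up example message. The message text below is purely illustrative.
    test_accuracy = calc_accuracy_loader(test_loader, model, device)
    print(f"Test accuracy: {test_accuracy*100:.2f}%")

    example_text = "You are a winner! Claim your free prize by replying to this message."
    encoded = tokenizer.encode(example_text)[:train_dataset.max_length]
    encoded += [50256] * (train_dataset.max_length - len(encoded))  # pad with the GPT-2 end-of-text token
    input_ids = torch.tensor(encoded, dtype=torch.long, device=device).unsqueeze(0)

    model.eval()
    with torch.no_grad():
        logits = model(input_ids)[:, -1, :]  # logits of the last output token
    predicted_label = torch.argmax(logits, dim=-1).item()
    print("Example prediction:", "spam" if predicted_label == 1 else "not spam")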