# gpt_class_finetune.py

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
# This is a summary file containing the main takeaways from chapter 6.

import urllib.request
import zipfile
import os
from pathlib import Path
import time

import matplotlib.pyplot as plt
import pandas as pd
import tiktoken
import torch
from torch.utils.data import Dataset, DataLoader

from gpt_download import download_and_load_gpt2
from previous_chapters import GPTModel, load_weights_into_gpt


def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path):
    if data_file_path.exists():
        print(f"{data_file_path} already exists. Skipping download and extraction.")
        return

    # Downloading the file
    with urllib.request.urlopen(url) as response:
        with open(zip_path, "wb") as out_file:
            out_file.write(response.read())

    # Unzipping the file
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)

    # Add .tsv file extension
    original_file_path = Path(extracted_path) / "SMSSpamCollection"
    os.rename(original_file_path, data_file_path)
    print(f"File downloaded and saved as {data_file_path}")


def create_balanced_dataset(df):
    # Count the instances of "spam"
    num_spam = df[df["Label"] == "spam"].shape[0]
    # Randomly sample "ham" instances to match the number of "spam" instances
    ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)
    # Combine ham "subset" with "spam"
    balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])
    return balanced_df


def random_split(df, train_frac, validation_frac):
    # Shuffle the entire DataFrame
    df = df.sample(frac=1, random_state=123).reset_index(drop=True)

    # Calculate split indices
    train_end = int(len(df) * train_frac)
    validation_end = train_end + int(len(df) * validation_frac)

    # Split the DataFrame
    train_df = df[:train_end]
    validation_df = df[train_end:validation_end]
    test_df = df[validation_end:]

    return train_df, validation_df, test_df


class SpamDataset(Dataset):
    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
        self.data = pd.read_csv(csv_file)

        # Pre-tokenize texts
        self.encoded_texts = [
            tokenizer.encode(text) for text in self.data["Text"]
        ]

        if max_length is None:
            self.max_length = self._longest_encoded_length()
        else:
            self.max_length = max_length
            # Truncate sequences if they are longer than max_length
            self.encoded_texts = [
                encoded_text[:self.max_length]
                for encoded_text in self.encoded_texts
            ]

        # Pad sequences to the longest sequence
        self.encoded_texts = [
            encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
            for encoded_text in self.encoded_texts
        ]

    def __getitem__(self, index):
        encoded = self.encoded_texts[index]
        label = self.data.iloc[index]["Label"]
        return (
            torch.tensor(encoded, dtype=torch.long),
            torch.tensor(label, dtype=torch.long)
        )

    def __len__(self):
        return len(self.data)

    def _longest_encoded_length(self):
        max_length = 0
        for encoded_text in self.encoded_texts:
            encoded_length = len(encoded_text)
            if encoded_length > max_length:
                max_length = encoded_length
        return max_length
        # Note: A more pythonic version to implement this method
        # is the following, which is also used in the next chapter:
        # return max(len(encoded_text) for encoded_text in self.encoded_texts)
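
# A rough usage sketch for SpamDataset (not part of the original chapter code; it assumes
# a "train.csv" with "Text" and 0/1 "Label" columns, as produced in the __main__ block below):
#
#   tokenizer = tiktoken.get_encoding("gpt2")
#   ds = SpamDataset("train.csv", tokenizer)   # pads every text to the longest encoding
#   input_ids, label = ds[0]                   # both returned as torch.long tensors
#   input_ids.shape == (ds.max_length,)        # right-padded with 50256 (<|endoftext|>)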


def calc_accuracy_loader(data_loader, model, device, num_batches=None):
    model.eval()
    correct_predictions, num_examples = 0, 0

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            input_batch, target_batch = input_batch.to(device), target_batch.to(device)

            with torch.no_grad():
                logits = model(input_batch)[:, -1, :]  # Logits of last output token
            predicted_labels = torch.argmax(logits, dim=-1)

            num_examples += predicted_labels.shape[0]
            correct_predictions += (predicted_labels == target_batch).sum().item()
        else:
            break
    return correct_predictions / num_examples


def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)[:, -1, :]  # Logits of last output token
    loss = torch.nn.functional.cross_entropy(logits, target_batch)
    return loss


def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches


def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss


def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                            eval_freq, eval_iter):
    # Initialize lists to track losses and examples seen
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    examples_seen, global_step = 0, -1

    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()  # Reset loss gradients from previous batch iteration
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()  # Calculate loss gradients
            optimizer.step()  # Update model weights using loss gradients
            examples_seen += input_batch.shape[0]  # New: track examples instead of tokens
            global_step += 1

            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # Calculate accuracy after each epoch
        train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
        val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
        print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
        print(f"Validation accuracy: {val_accuracy*100:.2f}%")
        train_accs.append(train_accuracy)
        val_accs.append(val_accuracy)

    return train_losses, val_losses, train_accs, val_accs, examples_seen


def plot_values(epochs_seen, examples_seen, train_values, val_values, label="loss"):
    fig, ax1 = plt.subplots(figsize=(5, 3))

    # Plot training and validation values against epochs
    ax1.plot(epochs_seen, train_values, label=f"Training {label}")
    ax1.plot(epochs_seen, val_values, linestyle="-.", label=f"Validation {label}")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel(label.capitalize())
    ax1.legend()

    # Create a second x-axis for examples seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(examples_seen, train_values, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Examples seen")

    fig.tight_layout()  # Adjust layout to make room
    plt.savefig(f"{label}-plot.pdf")
    # plt.show()


if __name__ == "__main__":

    import argparse

    parser = argparse.ArgumentParser(
        description="Finetune a GPT model for classification"
    )
    parser.add_argument(
        "--test_mode",
        default=False,
        action="store_true",
        help=("This flag runs the model in test mode for internal testing purposes. "
              "Otherwise, it runs the model as it is used in the chapter (recommended).")
    )
    args = parser.parse_args()

    ########################################
    # Download and prepare dataset
    ########################################

    url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
    zip_path = "sms_spam_collection.zip"
    extracted_path = "sms_spam_collection"
    data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"

    try:
        download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)
    except (urllib.error.HTTPError, urllib.error.URLError, TimeoutError) as e:
        print(f"Primary URL failed: {e}. Trying backup URL...")
        url = "https://f001.backblazeb2.com/file/LLMs-from-scratch/sms%2Bspam%2Bcollection.zip"
        download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)

    df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
    balanced_df = create_balanced_dataset(df)
    balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})

    train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)
    train_df.to_csv("train.csv", index=None)
    validation_df.to_csv("validation.csv", index=None)
    test_df.to_csv("test.csv", index=None)

    ########################################
    # Create data loaders
    ########################################
    tokenizer = tiktoken.get_encoding("gpt2")

    train_dataset = SpamDataset(
        csv_file="train.csv",
        max_length=None,
        tokenizer=tokenizer
    )
    val_dataset = SpamDataset(
        csv_file="validation.csv",
        max_length=train_dataset.max_length,
        tokenizer=tokenizer
    )
    test_dataset = SpamDataset(
        csv_file="test.csv",
        max_length=train_dataset.max_length,
        tokenizer=tokenizer
    )

    num_workers = 0
    batch_size = 8

    torch.manual_seed(123)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=False,
    )
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=False,
    )

    ########################################
    # Load pretrained model
    ########################################

    # Small GPT model for testing purposes
    if args.test_mode:
        BASE_CONFIG = {
            "vocab_size": 50257,
            "context_length": 120,
            "drop_rate": 0.0,
            "qkv_bias": False,
            "emb_dim": 12,
            "n_layers": 1,
            "n_heads": 2
        }
        model = GPTModel(BASE_CONFIG)
        model.eval()
        device = "cpu"

    # Code as it is used in the main chapter
    else:
        CHOOSE_MODEL = "gpt2-small (124M)"
        INPUT_PROMPT = "Every effort moves"

        BASE_CONFIG = {
            "vocab_size": 50257,     # Vocabulary size
            "context_length": 1024,  # Context length
            "drop_rate": 0.0,        # Dropout rate
            "qkv_bias": True         # Query-key-value bias
        }

        model_configs = {
            "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
            "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
            "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
            "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
        }

        BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

        assert train_dataset.max_length <= BASE_CONFIG["context_length"], (
            f"Dataset length {train_dataset.max_length} exceeds model's context "
            f"length {BASE_CONFIG['context_length']}. Reinitialize data sets with "
            f"`max_length={BASE_CONFIG['context_length']}`"
        )

        model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
        settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")

        model = GPTModel(BASE_CONFIG)
        load_weights_into_gpt(model, params)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ########################################
    # Modify pretrained model
    ########################################

    for param in model.parameters():
        param.requires_grad = False

    torch.manual_seed(123)

    num_classes = 2
    model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes)
    model.to(device)

    for param in model.trf_blocks[-1].parameters():
        param.requires_grad = True

    for param in model.final_norm.parameters():
        param.requires_grad = True
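
    # Note: with the settings above, only the new 2-class output head, the final
    # transformer block (model.trf_blocks[-1]), and the final LayerNorm remain
    # trainable; all other pretrained GPT-2 weights stay frozen during finetuning.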

    ########################################
    # Finetune modified model
    ########################################

    start_time = time.time()
    torch.manual_seed(123)

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
    num_epochs = 5

    train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
        model, train_loader, val_loader, optimizer, device,
        num_epochs=num_epochs, eval_freq=50, eval_iter=5,
    )

    end_time = time.time()
    execution_time_minutes = (end_time - start_time) / 60
    print(f"Training completed in {execution_time_minutes:.2f} minutes.")

    ########################################
    # Plot results
    ########################################

    # loss plot
    epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
    examples_seen_tensor = torch.linspace(0, examples_seen, len(train_losses))
    plot_values(epochs_tensor, examples_seen_tensor, train_losses, val_losses)

    # accuracy plot
    epochs_tensor = torch.linspace(0, num_epochs, len(train_accs))
    examples_seen_tensor = torch.linspace(0, examples_seen, len(train_accs))
    plot_values(epochs_tensor, examples_seen_tensor, train_accs, val_accs, label="accuracy")
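
# Usage sketch (not part of the original file; assumes `gpt_download.py` and
# `previous_chapters.py` from the chapter 6 code are on the Python path and that
# the SMS spam dataset and GPT-2 weights can be downloaded):
#
#   python gpt_class_finetune.py              # chapter setting: load GPT-2 (124M) and finetune
#   python gpt_class_finetune.py --test_mode  # tiny randomly initialized model for a quick smoke test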