# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

import urllib.request
import zipfile
import os
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch.utils.data import Dataset


def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path):
    if data_file_path.exists():
        print(f"{data_file_path} already exists. Skipping download and extraction.")
        return

    # Downloading the file
    with urllib.request.urlopen(url) as response:
        with open(zip_path, "wb") as out_file:
            out_file.write(response.read())

    # Unzipping the file
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)

    # Add .tsv file extension
    original_file_path = Path(extracted_path) / "SMSSpamCollection"
    os.rename(original_file_path, data_file_path)
    print(f"File downloaded and saved as {data_file_path}")
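

# Illustrative usage sketch (not part of the original listing), kept as comments so
# importing this module stays side-effect free. The URL points to the UCI SMS Spam
# Collection archive used in chapter 6; the local file and folder names are
# assumptions chosen for this example.
#
# url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
# zip_path = Path("sms_spam_collection.zip")
# extracted_path = Path("sms_spam_collection")
# data_file_path = extracted_path / "SMSSpamCollection.tsv"
# download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)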


def create_balanced_dataset(df):
    # Count the instances of "spam"
    num_spam = df[df["Label"] == "spam"].shape[0]

    # Randomly sample "ham" instances to match the number of "spam" instances
    ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)

    # Combine ham "subset" with "spam"
    balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])

    return balanced_df


def random_split(df, train_frac, validation_frac):
    # Shuffle the entire DataFrame
    df = df.sample(frac=1, random_state=123).reset_index(drop=True)

    # Calculate split indices
    train_end = int(len(df) * train_frac)
    validation_end = train_end + int(len(df) * validation_frac)

    # Split the DataFrame
    train_df = df[:train_end]
    validation_df = df[train_end:validation_end]
    test_df = df[validation_end:]

    return train_df, validation_df, test_df
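

# Illustrative sketch of combining the two helpers above to turn the downloaded
# .tsv file into balanced train/validation/test CSV files. The 70%/10% split and
# output file names follow chapter 6; the label mapping (ham -> 0, spam -> 1)
# matches what the SpamDataset class below expects.
#
# df = pd.read_csv(
#     "sms_spam_collection/SMSSpamCollection.tsv", sep="\t",
#     header=None, names=["Label", "Text"]
# )
# balanced_df = create_balanced_dataset(df)
# balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})
# train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)
# train_df.to_csv("train.csv", index=None)
# validation_df.to_csv("validation.csv", index=None)
# test_df.to_csv("test.csv", index=None)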


class SpamDataset(Dataset):
    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
        self.data = pd.read_csv(csv_file)

        # Pre-tokenize texts
        self.encoded_texts = [
            tokenizer.encode(text) for text in self.data["Text"]
        ]

        if max_length is None:
            self.max_length = self._longest_encoded_length()
        else:
            self.max_length = max_length
            # Truncate sequences if they are longer than max_length
            self.encoded_texts = [
                encoded_text[:self.max_length]
                for encoded_text in self.encoded_texts
            ]

        # Pad sequences to self.max_length
        self.encoded_texts = [
            encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
            for encoded_text in self.encoded_texts
        ]

    def __getitem__(self, index):
        encoded = self.encoded_texts[index]
        label = self.data.iloc[index]["Label"]
        return (
            torch.tensor(encoded, dtype=torch.long),
            torch.tensor(label, dtype=torch.long)
        )

    def __len__(self):
        return len(self.data)

    def _longest_encoded_length(self):
        max_length = 0
        for encoded_text in self.encoded_texts:
            encoded_length = len(encoded_text)
            if encoded_length > max_length:
                max_length = encoded_length
        return max_length
        # Note: A more pythonic version to implement this method
        # is the following, which is also used in the next chapter:
        # return max(len(encoded_text) for encoded_text in self.encoded_texts)
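

# Illustrative sketch of building the datasets and data loaders. It assumes the
# tiktoken package is installed and that the train/validation/test CSV files from
# the sketch above exist; the batch size is an example value.
#
# import tiktoken
# from torch.utils.data import DataLoader
#
# tokenizer = tiktoken.get_encoding("gpt2")
# train_dataset = SpamDataset("train.csv", tokenizer=tokenizer, max_length=None)
# val_dataset = SpamDataset("validation.csv", tokenizer=tokenizer, max_length=train_dataset.max_length)
# test_dataset = SpamDataset("test.csv", tokenizer=tokenizer, max_length=train_dataset.max_length)
#
# train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=True)
# val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, drop_last=False)
# test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, drop_last=False)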


def calc_accuracy_loader(data_loader, model, device, num_batches=None):
    model.eval()
    correct_predictions, num_examples = 0, 0

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            input_batch, target_batch = input_batch.to(device), target_batch.to(device)

            with torch.no_grad():
                logits = model(input_batch)[:, -1, :]  # Logits of last output token
            predicted_labels = torch.argmax(logits, dim=-1)

            num_examples += predicted_labels.shape[0]
            correct_predictions += (predicted_labels == target_batch).sum().item()
        else:
            break
    return correct_predictions / num_examples


def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)[:, -1, :]  # Logits of last output token
    loss = torch.nn.functional.cross_entropy(logits, target_batch)
    return loss


def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Reduce the number of batches to match the total number of batches in the data loader
        # if num_batches exceeds the number of batches in the data loader
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches
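

# Illustrative sketch of using the loader-level helpers above to check loss and
# accuracy before fine-tuning. It assumes `model` is a GPT model whose output
# head has been replaced with a 2-class classification head (as set up in
# chapter 6) and that the data loaders from the earlier sketch exist.
#
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)
# with torch.no_grad():
#     train_loss = calc_loss_loader(train_loader, model, device, num_batches=5)
# train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=10)
# print(f"Initial training loss: {train_loss:.3f}, accuracy: {train_accuracy*100:.2f}%")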


def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss


def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                            eval_freq, eval_iter):
    # Initialize lists to track losses and examples seen
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    examples_seen, global_step = 0, -1

    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()  # Reset loss gradients from previous batch iteration
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()  # Calculate loss gradients
            optimizer.step()  # Update model weights using loss gradients
            examples_seen += input_batch.shape[0]  # New: track examples instead of tokens
            global_step += 1

            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # Calculate accuracy after each epoch
        train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
        val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
        print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
        print(f"Validation accuracy: {val_accuracy*100:.2f}%")
        train_accs.append(train_accuracy)
        val_accs.append(val_accuracy)

    return train_losses, val_losses, train_accs, val_accs, examples_seen
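

# Illustrative training-run sketch. The optimizer settings (AdamW with lr=5e-5,
# weight_decay=0.1), num_epochs=5, eval_freq=50, and eval_iter=5 are example
# values in the spirit of chapter 6; `model`, `train_loader`, `val_loader`, and
# `device` are assumed from the earlier sketches.
#
# optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
# num_epochs = 5
# train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
#     model, train_loader, val_loader, optimizer, device,
#     num_epochs=num_epochs, eval_freq=50, eval_iter=5
# )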


def plot_values(epochs_seen, examples_seen, train_values, val_values, label="loss"):
    fig, ax1 = plt.subplots(figsize=(5, 3))

    # Plot training and validation loss against epochs
    ax1.plot(epochs_seen, train_values, label=f"Training {label}")
    ax1.plot(epochs_seen, val_values, linestyle="-.", label=f"Validation {label}")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel(label.capitalize())
    ax1.legend()

    # Create a second x-axis for examples seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(examples_seen, train_values, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Examples seen")

    fig.tight_layout()  # Adjust layout to make room
    plt.savefig(f"{label}-plot.pdf")
    plt.show()
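

# Illustrative plotting sketch. The linspace tensors spread the epoch count and
# the number of examples seen evenly over the recorded values; it assumes the
# variables returned by train_classifier_simple in the sketch above.
#
# epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
# examples_seen_tensor = torch.linspace(0, examples_seen, len(train_losses))
# plot_values(epochs_tensor, examples_seen_tensor, train_losses, val_losses, label="loss")
#
# epochs_tensor = torch.linspace(0, num_epochs, len(train_accs))
# examples_seen_tensor = torch.linspace(0, examples_seen, len(train_accs))
# plot_values(epochs_tensor, examples_seen_tensor, train_accs, val_accs, label="accuracy")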


def classify_review(text, model, tokenizer, device, max_length=None, pad_token_id=50256):
    model.eval()

    # Prepare inputs to the model
    input_ids = tokenizer.encode(text)
    supported_context_length = model.pos_emb.weight.shape[0]
    # Note: In the book, this was originally written as pos_emb.weight.shape[1] by mistake
    # It didn't break the code but would have caused unnecessary truncation (to 768 instead of 1024)

    # Fall back to the model's context length if no max_length is given
    # (min(None, ...) below would otherwise raise a TypeError)
    if max_length is None:
        max_length = supported_context_length

    # Truncate sequences if they are too long
    input_ids = input_ids[:min(max_length, supported_context_length)]

    # Pad sequences to max_length
    input_ids += [pad_token_id] * (max_length - len(input_ids))
    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)  # add batch dimension

    # Model inference
    with torch.no_grad():
        logits = model(input_tensor)[:, -1, :]  # Logits of the last output token
    predicted_label = torch.argmax(logits, dim=-1).item()

    # Return the classified result
    return "spam" if predicted_label == 1 else "not spam"
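

# Illustrative sketch of classifying new text with the fine-tuned model. The
# example message is made up; `model`, `tokenizer`, `device`, and `train_dataset`
# are assumed from the earlier sketches.
#
# text = "You are a winner! Reply now to claim your free prize."
# print(classify_review(
#     text, model, tokenizer, device, max_length=train_dataset.max_length
# ))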