# gpt_train.py

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

import os
import urllib.request

import matplotlib.pyplot as plt
import tiktoken
import torch

# Import from local files
from previous_chapters import GPTModel, create_dataloader_v1, generate_text_simple


def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())


def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
    # Flatten (batch, seq_len, vocab_size) logits and (batch, seq_len) targets
    # so that cross_entropy sees one prediction per token position
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss


def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Cap num_batches at the number of batches available in the data loader
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches
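

# Illustrative usage sketch (not part of the original script): before training,
# the initial loss can be estimated over a few batches of a loader, e.g.
#
#     with torch.no_grad():
#         val_loss = calc_loss_loader(val_loader, model, device, num_batches=5)
#
# where `val_loader` is assumed to be built with create_dataloader_v1, as done
# in main() below.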


def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss


def generate_and_print_sample(model, tokenizer, device, start_context):
    model.eval()
    context_size = model.pos_emb.weight.shape[0]
    encoded = text_to_token_ids(start_context, tokenizer).to(device)
    with torch.no_grad():
        token_ids = generate_text_simple(
            model=model, idx=encoded,
            max_new_tokens=50, context_size=context_size
        )
    decoded_text = token_ids_to_text(token_ids, tokenizer)
    print(decoded_text.replace("\n", " "))  # Compact print format
    model.train()


def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                       eval_freq, eval_iter, start_context, tokenizer):
    # Initialize lists to track losses and tokens seen
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen = 0
    global_step = -1

    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()  # Reset loss gradients from previous batch iteration
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()  # Calculate loss gradients
            optimizer.step()  # Update model weights using loss gradients
            tokens_seen += input_batch.numel()
            global_step += 1

            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # Print a sample text after each epoch
        generate_and_print_sample(
            model, tokenizer, device, start_context
        )

    return train_losses, val_losses, track_tokens_seen


def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
    fig, ax1 = plt.subplots()

    # Plot training and validation loss against epochs
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")

    # Create a second x-axis for tokens seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(tokens_seen, train_losses, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Tokens seen")

    fig.tight_layout()  # Adjust layout to make room
    # plt.show()


def main(gpt_config, settings):
    torch.manual_seed(123)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ##############################
    # Download data if necessary
    ##############################
    file_path = "the-verdict.txt"
    url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt"

    if not os.path.exists(file_path):
        with urllib.request.urlopen(url) as response:
            text_data = response.read().decode('utf-8')
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)
    else:
        with open(file_path, "r", encoding="utf-8") as file:
            text_data = file.read()

    ##############################
    # Initialize model
    ##############################
    model = GPTModel(gpt_config)
    model.to(device)  # no assignment model = model.to(device) necessary for nn.Module classes
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=settings["learning_rate"], weight_decay=settings["weight_decay"]
    )

    ##############################
    # Set up dataloaders
    ##############################
    # Train/validation ratio
    train_ratio = 0.90
    split_idx = int(train_ratio * len(text_data))

    train_loader = create_dataloader_v1(
        text_data[:split_idx],
        batch_size=settings["batch_size"],
        max_length=gpt_config["context_length"],
        stride=gpt_config["context_length"],
        drop_last=True,
        shuffle=True,
        num_workers=0
    )

    val_loader = create_dataloader_v1(
        text_data[split_idx:],
        batch_size=settings["batch_size"],
        max_length=gpt_config["context_length"],
        stride=gpt_config["context_length"],
        drop_last=False,
        shuffle=False,
        num_workers=0
    )

    ##############################
    # Train model
    ##############################
    tokenizer = tiktoken.get_encoding("gpt2")

    train_losses, val_losses, tokens_seen = train_model_simple(
        model, train_loader, val_loader, optimizer, device,
        num_epochs=settings["num_epochs"], eval_freq=5, eval_iter=1,
        start_context="Every effort moves you", tokenizer=tokenizer
    )

    return train_losses, val_losses, tokens_seen, model


if __name__ == "__main__":

    GPT_CONFIG_124M = {
        "vocab_size": 50257,    # Vocabulary size
        "context_length": 256,  # Shortened context length (orig: 1024)
        "emb_dim": 768,         # Embedding dimension
        "n_heads": 12,          # Number of attention heads
        "n_layers": 12,         # Number of layers
        "drop_rate": 0.1,       # Dropout rate
        "qkv_bias": False       # Query-key-value bias
    }

    OTHER_SETTINGS = {
        "learning_rate": 5e-4,
        "num_epochs": 10,
        "batch_size": 2,
        "weight_decay": 0.1
    }

    ###########################
    # Initiate training
    ###########################
    train_losses, val_losses, tokens_seen, model = main(GPT_CONFIG_124M, OTHER_SETTINGS)

    ###########################
    # After training
    ###########################

    # Plot results
    epochs_tensor = torch.linspace(0, OTHER_SETTINGS["num_epochs"], len(train_losses))
    plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)
    plt.savefig("loss.pdf")

    # Save and load model
    torch.save(model.state_dict(), "model.pth")
    model = GPTModel(GPT_CONFIG_124M)
    model.load_state_dict(torch.load("model.pth", weights_only=True))
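
    # ------------------------------------------------------------------------
    # Optional follow-up (illustrative sketch, not part of the original
    # script): generate a short sample with the reloaded weights, reusing the
    # helpers defined above. Assumes the same GPT_CONFIG_124M and GPT-2
    # tokenizer used during training.
    # ------------------------------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # disable dropout for inference

    tokenizer = tiktoken.get_encoding("gpt2")
    token_ids = generate_text_simple(
        model=model,
        idx=text_to_token_ids("Every effort moves you", tokenizer).to(device),
        max_new_tokens=25,
        context_size=GPT_CONFIG_124M["context_length"]
    )
    print(token_ids_to_text(token_ids, tokenizer))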