train.py

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

import matplotlib.pyplot as plt
import os
import torch
import urllib.request

# Import from local files
from previous_chapters import GPTModel, create_dataloader_v1, generate_text_simple


def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())


def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
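    # Cross entropy expects 2D logits and 1D targets, so collapse the batch and
    # sequence dimensions: (batch, seq_len, vocab) -> (batch*seq_len, vocab) and
    # (batch, seq_len) -> (batch*seq_len)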
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss


def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches


def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss


def generate_and_print_sample(model, tokenizer, device, start_context):
    model.eval()
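    # The supported context length equals the number of rows in the model's
    # positional-embedding table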
    context_size = model.pos_emb.weight.shape[0]
    encoded = text_to_token_ids(start_context, tokenizer).to(device)
    with torch.no_grad():
        token_ids = generate_text_simple(
            model=model, idx=encoded,
            max_new_tokens=50, context_size=context_size
        )
    decoded_text = token_ids_to_text(token_ids, tokenizer)
    print(decoded_text.replace("\n", " "))  # Compact print format
    model.train()


def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                       eval_freq, eval_iter, start_context):
    # Initialize lists to track losses and tokens seen
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen = 0
    global_step = -1

    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()  # Reset loss gradients from the previous batch iteration
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()  # Calculate loss gradients
            optimizer.step()  # Update model weights using loss gradients
            tokens_seen += input_batch.numel()
            global_step += 1

            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # Print a sample text after each epoch
        generate_and_print_sample(
            model, train_loader.dataset.tokenizer, device, start_context
        )

    return train_losses, val_losses, track_tokens_seen


def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
    fig, ax1 = plt.subplots()

    # Plot training and validation loss against epochs
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")

    # Create a second x-axis for tokens seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(tokens_seen, train_losses, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Tokens seen")

    fig.tight_layout()  # Adjust layout to make room
    # plt.show()


def main(gpt_config, hparams):
    torch.manual_seed(123)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ##############################
    # Download data if necessary
    ##############################

    file_path = "the-verdict.txt"
    url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt"

    if not os.path.exists(file_path):
        with urllib.request.urlopen(url) as response:
            text_data = response.read().decode('utf-8')
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)
    else:
        with open(file_path, "r", encoding="utf-8") as file:
            text_data = file.read()
    ##############################
    # Initialize model
    ##############################

    model = GPTModel(gpt_config)
    model.to(device)  # no assignment model = model.to(device) necessary for nn.Module classes
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=hparams["learning_rate"], weight_decay=hparams["weight_decay"]
    )
    ##############################
    # Set up dataloaders
    ##############################

    # Train/validation ratio
    train_ratio = 0.90
    split_idx = int(train_ratio * len(text_data))
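    # Note: the split is applied to the raw character string; tokenization
    # happens inside create_dataloader_v1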
    train_loader = create_dataloader_v1(
        text_data[:split_idx],
        batch_size=hparams["batch_size"],
        max_length=gpt_config["ctx_len"],
        stride=gpt_config["ctx_len"],
        drop_last=True,
        shuffle=True
    )

    val_loader = create_dataloader_v1(
        text_data[split_idx:],
        batch_size=hparams["batch_size"],
        max_length=gpt_config["ctx_len"],
        stride=gpt_config["ctx_len"],
        drop_last=False,
        shuffle=False
    )
    ##############################
    # Train model
    ##############################

    train_losses, val_losses, tokens_seen = train_model_simple(
        model, train_loader, val_loader, optimizer, device,
        num_epochs=hparams["num_epochs"], eval_freq=5, eval_iter=1,
        start_context="Every effort moves you",
    )

    return train_losses, val_losses, tokens_seen, model


if __name__ == "__main__":

    GPT_CONFIG_124M = {
        "vocab_size": 50257,  # Vocabulary size
        "ctx_len": 256,       # Shortened context length (orig: 1024)
        "emb_dim": 768,       # Embedding dimension
        "n_heads": 12,        # Number of attention heads
        "n_layers": 12,       # Number of layers
        "drop_rate": 0.1,     # Dropout rate
        "qkv_bias": False     # Query-key-value bias
    }

    OTHER_HPARAMS = {
        "learning_rate": 5e-4,
        "num_epochs": 10,
        "batch_size": 2,
        "weight_decay": 0.1
    }
    ###########################
    # Initiate training
    ###########################

    train_losses, val_losses, tokens_seen, model = main(GPT_CONFIG_124M, OTHER_HPARAMS)

    ###########################
    # After training
    ###########################

    # Plot results
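    # torch.linspace spreads the recorded evaluation points (logged every
    # eval_freq steps) evenly across the epoch axis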
    epochs_tensor = torch.linspace(0, OTHER_HPARAMS["num_epochs"], len(train_losses))
    plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)
    plt.savefig("loss.pdf")

    # Save and load model
    torch.save(model.state_dict(), "model.pth")
    model = GPTModel(GPT_CONFIG_124M)
    model.load_state_dict(torch.load("model.pth"))
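
    # Optional: a minimal sketch for sanity-checking the reloaded weights by
    # sampling a short continuation. This is not part of the training run above;
    # it assumes the GPT-2 encoding from tiktoken (the tokenizer used with the
    # book's dataloaders) and the generate_text_simple signature imported above.
    import tiktoken

    tokenizer = tiktoken.get_encoding("gpt2")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    with torch.no_grad():
        token_ids = generate_text_simple(
            model=model,
            idx=text_to_token_ids("Every effort moves you", tokenizer).to(device),
            max_new_tokens=25,
            context_size=GPT_CONFIG_124M["ctx_len"]
        )
    print(token_ids_to_text(token_ids, tokenizer))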