appendix_d.py

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

from .ch05 import calc_loss_batch, evaluate_model, generate_and_print_sample

import math
import torch


def find_highest_gradient(model):
    max_grad = None
    for param in model.parameters():
        if param.grad is not None:
            grad_values = param.grad.data.flatten()
            max_grad_param = grad_values.max()
            if max_grad is None or max_grad_param > max_grad:
                max_grad = max_grad_param
    return max_grad
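

# A minimal, self-contained sketch (not part of the original module) showing what
# find_highest_gradient reports: after a backward pass, it returns the largest
# single gradient entry across all parameters. The tiny linear model below is
# purely illustrative.
def _demo_find_highest_gradient():
    torch.manual_seed(123)
    tiny_model = torch.nn.Linear(4, 2)
    loss = tiny_model(torch.randn(8, 4)).sum()
    loss.backward()
    print("Highest gradient entry:", find_highest_gradient(tiny_model))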


def train_model(model, train_loader, val_loader, optimizer, device,
                n_epochs, eval_freq, eval_iter, start_context, tokenizer,
                warmup_steps, initial_lr=3e-05, min_lr=1e-6, orig_book_version=False):
    train_losses, val_losses, track_tokens_seen, track_lrs = [], [], [], []
    tokens_seen, global_step = 0, -1

    # Retrieve the maximum learning rate from the optimizer
    peak_lr = optimizer.param_groups[0]["lr"]

    # Calculate the total number of iterations in the training process
    total_training_steps = len(train_loader) * n_epochs

    # Calculate the learning rate increment during the warmup phase
    lr_increment = (peak_lr - initial_lr) / warmup_steps

    for epoch in range(n_epochs):
        model.train()
        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()
            global_step += 1

            # Adjust the learning rate based on the current phase (warmup or cosine annealing)
            if global_step < warmup_steps:
                # Linear warmup
                lr = initial_lr + global_step * lr_increment
            else:
                # Cosine annealing after warmup
                progress = ((global_step - warmup_steps) /
                            (total_training_steps - warmup_steps))
                lr = min_lr + (peak_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

            # Apply the calculated learning rate to the optimizer
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr
            track_lrs.append(lr)  # Store the current learning rate

            # Calculate and backpropagate the loss
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()

            # Apply gradient clipping after the warmup phase to avoid exploding gradients
            if orig_book_version:
                if global_step > warmup_steps:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            else:
                # The book originally used global_step > warmup_steps, which led to
                # a skipped clipping step after warmup
                if global_step >= warmup_steps:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
            tokens_seen += input_batch.numel()

            # Periodically evaluate the model on the training and validation sets
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader,
                    device, eval_iter
                )
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)
                # Print the current losses
                print(f"Ep {epoch+1} (Iter {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, "
                      f"Val loss {val_loss:.3f}")

        # Generate and print a sample from the model to monitor progress
        generate_and_print_sample(
            model, tokenizer, device, start_context
        )

    return train_losses, val_losses, track_tokens_seen, track_lrs
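

# A small, self-contained sketch (not part of the original module) that reproduces
# the learning-rate schedule used in train_model above: linear warmup from
# initial_lr to peak_lr, followed by cosine annealing down toward min_lr. It can be
# handy for inspecting the schedule before launching a full training run.
# The default step counts below are arbitrary illustration values.
def _demo_lr_schedule(total_training_steps=100, warmup_steps=20,
                      peak_lr=5e-4, initial_lr=3e-5, min_lr=1e-6):
    lr_increment = (peak_lr - initial_lr) / warmup_steps
    lrs = []
    for global_step in range(total_training_steps):
        if global_step < warmup_steps:
            # Linear warmup
            lr = initial_lr + global_step * lr_increment
        else:
            # Cosine annealing after warmup
            progress = ((global_step - warmup_steps) /
                        (total_training_steps - warmup_steps))
            lr = min_lr + (peak_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))
        lrs.append(lr)
    # The schedule reaches peak_lr exactly at global_step == warmup_steps and then
    # decays, approaching min_lr toward the final step.
    return lrs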