hparam_search.py 7.6 KB

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
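"""Hyperparameter grid search for the GPT model from the book.

For every combination of values in HPARAM_GRID, the script trains a GPT model
on "the-verdict.txt" and records the configuration with the lowest validation
loss. The search can be stopped early with Ctrl-C; the best configuration
found so far is printed on exit.
"""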
import itertools
import math
import os

import tiktoken
import torch

from previous_chapters import GPTModel, create_dataloader_v1

# Define a grid of hyperparameters to search over
HPARAM_GRID = {
    "batch_size": [2, 4, 8, 16],
    "drop_rate": [0.0, 0.1, 0.2],
    "warmup_iters": [10, 20, 30],
    "weight_decay": [0.1, 0.01, 0.0],
    "peak_lr": [0.0001, 0.0005, 0.001, 0.005],
    "initial_lr": [0.00005, 0.0001],
    "min_lr": [0.00005, 0.00001, 0.0001],
    "n_epochs": [5, 10, 15, 20, 25],
}
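# Note: the full grid above contains 4 * 3 * 3 * 3 * 4 * 2 * 3 * 5 = 12,960
# combinations, which is the number the script reports at startup.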


def calc_loss_loader(data_loader, model, device, num_batches=None):
    """Return the average loss over up to `num_batches` batches of `data_loader`."""
    total_loss = 0.
    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches


def calc_loss_batch(input_batch, target_batch, model, device):
    """Compute the cross-entropy loss for a single input/target batch."""
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
    logits = logits.view(-1, logits.size(-1))
    loss = torch.nn.functional.cross_entropy(logits, target_batch.view(-1))
    return loss


def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    """Evaluate the model on up to `eval_iter` batches of the train and validation sets."""
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss


def train_model(model, train_loader, val_loader, optimizer, device,
                n_epochs, eval_freq, eval_iter,
                encoded_start_context, tokenizer, warmup_iters=10,
                initial_lr=3e-05, min_lr=1e-6):
    """Train the model with linear learning-rate warmup followed by cosine decay."""
    global_step = 0

    max_lr = optimizer.param_groups[0]["lr"]

    # Calculate total number of iterations
    total_training_iters = len(train_loader) * n_epochs

    # Calculate the learning rate increment at each step during warmup
    lr_increment = (optimizer.param_groups[0]["lr"] - initial_lr) / warmup_iters

    for epoch in range(n_epochs):
        model.train()
        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()

            # Increment the global step at the beginning of the iteration
            global_step += 1

            # Warmup: adjust learning rate linearly
            if global_step < warmup_iters:
                lr = initial_lr + global_step * lr_increment
            # Cosine annealing phase
            else:
                progress = (global_step - warmup_iters) / (total_training_iters - warmup_iters)
                lr = min_lr + (max_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

            # Apply the calculated learning rate
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()

            # Apply gradient clipping after the warmup phase
            if global_step >= warmup_iters:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

    train_loss, val_loss = evaluate_model(model, train_loader, val_loader, device, eval_iter)
    return train_loss, val_loss
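# Illustrative example of the schedule above (values chosen here, not from the book):
# with warmup_iters=10, initial_lr=1e-4, peak_lr=5e-4, min_lr=1e-5, and
# 100 total training iterations,
#   step 1   -> lr = 1.4e-4   (linear warmup)
#   step 10  -> lr = 5.0e-4   (peak learning rate, start of cosine decay)
#   step 55  -> lr = 2.55e-4  (halfway through the cosine phase)
#   step 100 -> lr = 1.0e-5   (min_lr)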


if __name__ == "__main__":

    # Generate all combinations of hyperparameters
    hyperparameter_combinations = list(itertools.product(*HPARAM_GRID.values()))
    total_combinations = len(hyperparameter_combinations)
    print(f"Total hyperparameter configurations: {total_combinations}")

    # Placeholders for the best losses and best hyperparameters
    best_val_loss = float('inf')
    best_train_loss = float('inf')
    best_hparams = {}

    script_path = os.path.abspath(__file__)
    script_dir = os.path.dirname(script_path)
    with open(os.path.join(script_dir, "the-verdict.txt"), "r", encoding="utf-8") as file:
        text_data = file.read()

    tokenizer = tiktoken.get_encoding("gpt2")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_ratio = 0.95
    split_idx = int(train_ratio * len(text_data))

    torch.manual_seed(123)

    interrupted = False
    current_config = 0
    for combination in hyperparameter_combinations:

        try:
            current_config += 1
            print(f"Evaluating configuration {current_config} of {total_combinations}")

            # Unpack the current combination of hyperparameters
            HPARAM_CONFIG = dict(zip(HPARAM_GRID.keys(), combination))

            GPT_CONFIG_124M = {
                "vocab_size": 50257,     # Vocabulary size
                "context_length": 256,   # Context length -- shortened from original 1024 tokens
                "emb_dim": 768,          # Embedding dimension
                "n_heads": 12,           # Number of attention heads
                "n_layers": 12,          # Number of layers
                "drop_rate": HPARAM_CONFIG["drop_rate"],
                "qkv_bias": False,       # Query-Key-Value bias
            }

            torch.manual_seed(123)

            train_loader = create_dataloader_v1(
                text_data[:split_idx],
                batch_size=HPARAM_CONFIG["batch_size"],
                max_length=GPT_CONFIG_124M["context_length"],
                stride=GPT_CONFIG_124M["context_length"],
                drop_last=True,
                shuffle=True,
                num_workers=0
            )

            val_loader = create_dataloader_v1(
                text_data[split_idx:],
                batch_size=HPARAM_CONFIG["batch_size"],
                max_length=GPT_CONFIG_124M["context_length"],
                stride=GPT_CONFIG_124M["context_length"],
                drop_last=False,
                shuffle=False,
                num_workers=0
            )

            model = GPTModel(GPT_CONFIG_124M)
            model.to(device)

            optimizer = torch.optim.AdamW(
                model.parameters(),
                lr=HPARAM_CONFIG["peak_lr"],
                weight_decay=HPARAM_CONFIG["weight_decay"]
            )

            encoded_start_context = tokenizer.encode("Nevertheless")
            encoded_tensor = torch.tensor(encoded_start_context).unsqueeze(0)

            train_loss, val_loss = train_model(
                model, train_loader, val_loader, optimizer, device,
                n_epochs=HPARAM_CONFIG["n_epochs"],
                eval_freq=5, eval_iter=1,
                encoded_start_context=encoded_tensor,
                tokenizer=tokenizer,
                warmup_iters=HPARAM_CONFIG["warmup_iters"],
                initial_lr=HPARAM_CONFIG["initial_lr"],
                min_lr=HPARAM_CONFIG["min_lr"]
            )

            # Log the best hyperparameters based on validation loss
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_train_loss = train_loss
                best_hparams = HPARAM_CONFIG
        except KeyboardInterrupt:
            print("Hyperparameter search interrupted.")
            print(f"Best hyperparameters so far: {best_hparams}")
            print(f"Best Val loss: {best_val_loss} | Training loss: {best_train_loss}")
            interrupted = True
            break

    if not interrupted:
        print("Hyperparameter search completed.")
        print(f"Best hyperparameters: {best_hparams}")
        print(f"Best Val loss: {best_val_loss} | Training loss: {best_train_loss}")