train-gpt.py

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

import argparse
from pathlib import Path
import time

import pandas as pd
import tiktoken
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from gpt_download import download_and_load_gpt2
from previous_chapters import GPTModel, load_weights_into_gpt
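
# Example invocation (a sketch; it assumes train.csv, validation.csv, and test.csv
# with "text" and "label" columns exist in the working directory, and that
# gpt_download.py and previous_chapters.py from the book's repository are importable):
#
#   python train-gpt.py --model_size "gpt2-small (124M)" --weights pretrained \
#       --trainable_layers last_block --trainable_token last --context_length 256
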

class IMDBDataset(Dataset):
    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
        self.data = pd.read_csv(csv_file)
        self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer)

        # Pre-tokenize texts and truncate them to max_length
        self.encoded_texts = [
            tokenizer.encode(text)[:self.max_length]
            for text in self.data["text"]
        ]

        # Pad all sequences to max_length with the pad token
        self.encoded_texts = [
            et + [pad_token_id] * (self.max_length - len(et))
            for et in self.encoded_texts
        ]

    def __getitem__(self, index):
        encoded = self.encoded_texts[index]
        label = self.data.iloc[index]["label"]
        return torch.tensor(encoded, dtype=torch.long), torch.tensor(label, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def _longest_encoded_length(self, tokenizer):
        max_length = 0
        for text in self.data["text"]:
            encoded_length = len(tokenizer.encode(text))
            if encoded_length > max_length:
                max_length = encoded_length
        return max_length
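
# A minimal usage sketch for the dataset class (hypothetical example; the real CSVs
# come from the accompanying data-preparation step and must contain "text" and
# "label" columns):
#
#   tokenizer = tiktoken.get_encoding("gpt2")
#   dataset = IMDBDataset("train.csv", tokenizer=tokenizer, max_length=256)
#   inputs, label = dataset[0]  # inputs: LongTensor of shape (256,), label: scalar LongTensor
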

def instantiate_model(choose_model, load_weights):

    BASE_CONFIG = {
        "vocab_size": 50257,     # Vocabulary size
        "context_length": 1024,  # Context length
        "drop_rate": 0.0,        # Dropout rate
        "qkv_bias": True         # Query-key-value bias
    }

    model_configs = {
        "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
        "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
        "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
        "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
    }

    BASE_CONFIG.update(model_configs[choose_model])

    if not load_weights:
        torch.manual_seed(123)
    model = GPTModel(BASE_CONFIG)

    if load_weights:
        model_size = choose_model.split(" ")[-1].lstrip("(").rstrip(")")
        settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
        load_weights_into_gpt(model, params)

    model.eval()
    return model
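
# For example, the smallest pretrained variant could be loaded like this (a sketch;
# download_and_load_gpt2 from the book's repository fetches the OpenAI GPT-2
# checkpoint into the local "gpt2" directory on first use):
#
#   model = instantiate_model("gpt2-small (124M)", load_weights=True)
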

def calc_loss_batch(input_batch, target_batch, model, device, trainable_token=-1):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)[:, trainable_token, :]  # Logits of the selected token (last output token by default)
    loss = torch.nn.functional.cross_entropy(logits, target_batch)
    return loss
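
# Note: in the __main__ section below, model.out_head is replaced with a 2-class linear
# layer, so model(input_batch) has shape (batch_size, num_tokens, 2); the slice above
# keeps the (batch_size, 2) logits of the selected token for the cross-entropy loss.
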

def calc_loss_loader(data_loader, model, device, num_batches=None, trainable_token=-1):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Reduce the number of batches to match the total number of batches in the data loader
        # if num_batches exceeds the number of batches in the data loader
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device, trainable_token=trainable_token)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches

@torch.no_grad()  # Disable gradient tracking for efficiency
def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable_token=-1):
    model.eval()
    correct_predictions, num_examples = 0, 0

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            input_batch, target_batch = input_batch.to(device), target_batch.to(device)
            logits = model(input_batch)[:, trainable_token, :]  # Logits of the selected token (last output token by default)
            predicted_labels = torch.argmax(logits, dim=-1)

            num_examples += predicted_labels.shape[0]
            correct_predictions += (predicted_labels == target_batch).sum().item()
        else:
            break
    return correct_predictions / num_examples

def evaluate_model(model, train_loader, val_loader, device, eval_iter, trainable_token=-1):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
    model.train()
    return train_loss, val_loss

def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                            eval_freq, eval_iter, max_steps=None, trainable_token=-1):
    # Initialize lists to track losses and examples seen
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    examples_seen, global_step = 0, -1

    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()  # Reset loss gradients from previous batch iteration
            loss = calc_loss_batch(input_batch, target_batch, model, device, trainable_token=trainable_token)
            loss.backward()  # Calculate loss gradients
            optimizer.step()  # Update model weights using loss gradients
            examples_seen += input_batch.shape[0]  # New: track examples instead of tokens
            global_step += 1

            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter, trainable_token=trainable_token)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

            if max_steps is not None and global_step > max_steps:
                break

        # New: Calculate accuracy after each epoch
        train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
        val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
        print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
        print(f"Validation accuracy: {val_accuracy*100:.2f}%")
        train_accs.append(train_accuracy)
        val_accs.append(val_accuracy)

        if max_steps is not None and global_step > max_steps:
            break

    return train_losses, val_losses, train_accs, val_accs, examples_seen
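
# A minimal plotting sketch for the returned history (assumes matplotlib is installed;
# not part of the original training script):
#
#   import matplotlib.pyplot as plt
#   plt.plot(train_losses, label="train")
#   plt.plot(val_losses, label="validation")
#   plt.xlabel("Evaluation step")
#   plt.ylabel("Loss")
#   plt.legend()
#   plt.savefig("loss.pdf")
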

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_size",
        type=str,
        default="gpt2-small (124M)",
        help=(
            "Which GPT model to use. Options: 'gpt2-small (124M)', 'gpt2-medium (355M)',"
            " 'gpt2-large (774M)', 'gpt2-xl (1558M)'."
        )
    )
    parser.add_argument(
        "--weights",
        type=str,
        default="pretrained",
        help=(
            "Whether to use 'pretrained' or 'random' weights."
        )
    )
    parser.add_argument(
        "--trainable_layers",
        type=str,
        default="last_block",
        help=(
            "Which layers to train. Options: 'all', 'last_block', 'last_layer'."
        )
    )
    parser.add_argument(
        "--trainable_token",
        type=str,
        default="last",
        help=(
            "Which token position to train on. Options: 'first', 'last'."
        )
    )
    parser.add_argument(
        "--context_length",
        type=str,
        default="256",
        help=(
            "The context length of the data inputs. "
            "Options: 'longest_training_example', 'model_context_length', or an integer value."
        )
    )
    args = parser.parse_args()

    if args.trainable_token == "first":
        args.trainable_token = 0
    elif args.trainable_token == "last":
        args.trainable_token = -1
    else:
        raise ValueError("Invalid --trainable_token argument")

    ###############################
    # Load model
    ###############################

    if args.weights == "pretrained":
        load_weights = True
    elif args.weights == "random":
        load_weights = False
    else:
        raise ValueError("Invalid --weights argument.")

    model = instantiate_model(args.model_size, load_weights)
    for param in model.parameters():
        param.requires_grad = False

    if args.model_size == "gpt2-small (124M)":
        in_features = 768
    elif args.model_size == "gpt2-medium (355M)":
        in_features = 1024
    elif args.model_size == "gpt2-large (774M)":
        in_features = 1280
    elif args.model_size == "gpt2-xl (1558M)":
        in_features = 1600
    else:
        raise ValueError("Invalid --model_size argument")

    torch.manual_seed(123)
    model.out_head = torch.nn.Linear(in_features=in_features, out_features=2)

    if args.trainable_layers == "last_layer":
        pass
    elif args.trainable_layers == "last_block":
        for param in model.trf_blocks[-1].parameters():
            param.requires_grad = True
        for param in model.final_norm.parameters():
            param.requires_grad = True
    elif args.trainable_layers == "all":
        for param in model.parameters():
            param.requires_grad = True
    else:
        raise ValueError("Invalid --trainable_layers argument.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
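
    # Note: model.out_head is re-created above *after* the requires_grad=False loop,
    # so the new 2-class output layer is always trainable, in addition to whatever
    # --trainable_layers unfreezes.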

    ###############################
    # Instantiate dataloaders
    ###############################

    base_path = Path(".")

    tokenizer = tiktoken.get_encoding("gpt2")

    train_dataset = None
    if args.context_length == "model_context_length":
        max_length = model.pos_emb.weight.shape[0]
    elif args.context_length == "longest_training_example":
        train_dataset = IMDBDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer)
        max_length = train_dataset.max_length
    else:
        try:
            max_length = int(args.context_length)
        except ValueError:
            raise ValueError("Invalid --context_length argument")

    if train_dataset is None:
        train_dataset = IMDBDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer)
    val_dataset = IMDBDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer)
    test_dataset = IMDBDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer)

    num_workers = 0
    batch_size = 8

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )

    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=False,
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=False,
    )

    ###############################
    # Train model
    ###############################

    start_time = time.time()
    torch.manual_seed(123)

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)

    num_epochs = 3
    train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
        model, train_loader, val_loader, optimizer, device,
        num_epochs=num_epochs, eval_freq=50, eval_iter=20,
        max_steps=None, trainable_token=args.trainable_token
    )

    end_time = time.time()
    execution_time_minutes = (end_time - start_time) / 60
    print(f"Training completed in {execution_time_minutes:.2f} minutes.")

    ###############################
    # Evaluate model
    ###############################

    print("\nEvaluating on the full datasets ...\n")

    train_accuracy = calc_accuracy_loader(train_loader, model, device, trainable_token=args.trainable_token)
    val_accuracy = calc_accuracy_loader(val_loader, model, device, trainable_token=args.trainable_token)
    test_accuracy = calc_accuracy_loader(test_loader, model, device, trainable_token=args.trainable_token)

    print(f"Training accuracy: {train_accuracy*100:.2f}%")
    print(f"Validation accuracy: {val_accuracy*100:.2f}%")
    print(f"Test accuracy: {test_accuracy*100:.2f}%")
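
    # Optional follow-up (a sketch, not part of the original script): persist the
    # fine-tuned classifier so it can be reloaded later without retraining, e.g.
    #
    #   torch.save(model.state_dict(), "review_classifier.pth")  # hypothetical file name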