# test_appendix_d.py

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

from llms_from_scratch.ch02 import create_dataloader_v1
from llms_from_scratch.ch04 import GPTModel
from llms_from_scratch.appendix_d import train_model

import os
import urllib.request

import tiktoken
import torch
from torch.utils.data import Subset, DataLoader


def test_train(tmp_path):
    GPT_CONFIG_124M = {
        "vocab_size": 50257,     # Vocabulary size
        "context_length": 256,   # Shortened context length (orig: 1024)
        "emb_dim": 768,          # Embedding dimension
        "n_heads": 12,           # Number of attention heads
        "n_layers": 12,          # Number of layers
        "drop_rate": 0.1,        # Dropout rate
        "qkv_bias": False        # Query-key-value bias
    }

    OTHER_SETTINGS = {
        "learning_rate": 5e-4,
        "num_epochs": 2,
        "batch_size": 1,
        "weight_decay": 0.1
    }
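
    # Note: only "batch_size" from OTHER_SETTINGS is consumed below; the learning
    # rate, epoch count, and weight decay used in this smoke test are set
    # explicitly in the training section further down.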

    torch.manual_seed(123)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ##############################
    # Download data if necessary
    ##############################

    file_path = tmp_path / "the-verdict.txt"
    url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt"

    if not os.path.exists(file_path):
        with urllib.request.urlopen(url) as response:
            text_data = response.read().decode("utf-8")
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)
    else:
        with open(file_path, "r", encoding="utf-8") as file:
            text_data = file.read()

    ##############################
    # Initialize model
    ##############################

    model = GPTModel(GPT_CONFIG_124M)
    model.to(device)  # no assignment model = model.to(device) necessary for nn.Module classes

    ##############################
    # Set up dataloaders
    ##############################

    # Train/validation ratio
    train_ratio = 0.90
    split_idx = int(train_ratio * len(text_data))

    train_loader = create_dataloader_v1(
        text_data[:split_idx],
        batch_size=OTHER_SETTINGS["batch_size"],
        max_length=GPT_CONFIG_124M["context_length"],
        stride=GPT_CONFIG_124M["context_length"],
        drop_last=True,
        shuffle=True,
        num_workers=0
    )

    val_loader = create_dataloader_v1(
        text_data[split_idx:],
        batch_size=OTHER_SETTINGS["batch_size"],
        max_length=GPT_CONFIG_124M["context_length"],
        stride=GPT_CONFIG_124M["context_length"],
        drop_last=False,
        shuffle=False,
        num_workers=0
    )
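
    # Because stride equals the context length, the train and validation loaders
    # yield non-overlapping 256-token input/target windows over their splits.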

    ##############################
    # Train model
    ##############################

    tokenizer = tiktoken.get_encoding("gpt2")

    train_subset = Subset(train_loader.dataset, range(1))
    one_batch_train_loader = DataLoader(train_subset, batch_size=1)
    val_subset = Subset(val_loader.dataset, range(1))
    one_batch_val_loader = DataLoader(val_subset, batch_size=1)
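
    # Wrapping a single example from each dataset in a new DataLoader limits the
    # run to one training batch and one validation batch per epoch, so the test
    # stays fast while still exercising the full appendix D training loop.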

    peak_lr = 0.001  # this was originally set to 5e-4 in the book by mistake
    optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr, weight_decay=0.1)  # the book accidentally omitted the lr assignment

    n_epochs = 6
    warmup_steps = 1

    train_losses, val_losses, tokens_seen, lrs = train_model(
        model, one_batch_train_loader, one_batch_val_loader, optimizer, device,
        n_epochs=n_epochs, eval_freq=5, eval_iter=1,
        start_context="Every effort moves you",
        tokenizer=tokenizer, warmup_steps=warmup_steps,
        initial_lr=1e-5, min_lr=1e-5
    )

    assert round(train_losses[0], 1) == 10.9
    assert round(val_losses[0], 1) == 11.0
    assert train_losses[-1] < train_losses[0]
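

# Minimal sketch for running this test outside of pytest; the temporary
# directory here is an assumption standing in for pytest's tmp_path fixture.
if __name__ == "__main__":
    import tempfile
    from pathlib import Path

    with tempfile.TemporaryDirectory() as tmp_dir:
        test_train(Path(tmp_dir))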