test_ch05.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
  2. # Source for "Build a Large Language Model From Scratch"
  3. # - https://www.manning.com/books/build-a-large-language-model-from-scratch
  4. # Code: https://github.com/rasbt/LLMs-from-scratch
  5. from llms_from_scratch.ch02 import create_dataloader_v1
  6. from llms_from_scratch.ch04 import GPTModel
  7. from llms_from_scratch.ch05 import train_model_simple
  8. import os
  9. import urllib
  10. import pytest
  11. import tiktoken
  12. import torch
  13. from torch.utils.data import Subset, DataLoader
  @pytest.mark.parametrize("file_name", ["the-verdict.txt"])
  def test_train_simple(tmp_path, file_name):
      """Smoke-test ``train_model_simple`` on a single training sample.

      The test obtains the text named by ``file_name`` (downloading it into
      ``tmp_path`` when it is not already there), builds a ``GPTModel`` with
      the 124M configuration, and trains it for ``num_epochs`` epochs on
      loaders that each contain exactly one sample.  It then checks the first
      recorded train/validation losses against reference values and verifies
      that the training loss went down.

      Args:
          tmp_path: pytest's per-test temporary directory fixture.
          file_name: Name of the text file to fetch and train on.
      """
      # ``import urllib`` on its own does not load the ``request`` submodule,
      # so pull in ``urlopen`` explicitly instead of relying on some other
      # import having loaded it as a side effect.
      from urllib.request import urlopen

      GPT_CONFIG_124M = {
          "vocab_size": 50257,     # Vocabulary size
          "context_length": 256,   # Shortened context length (orig: 1024)
          "emb_dim": 768,          # Embedding dimension
          "n_heads": 12,           # Number of attention heads
          "n_layers": 12,          # Number of layers
          "drop_rate": 0.1,        # Dropout rate
          "qkv_bias": False,       # Query-key-value bias
      }

      OTHER_SETTINGS = {
          "learning_rate": 5e-4,
          "num_epochs": 2,
          "batch_size": 1,
          "weight_decay": 0.1,
      }

      # Fixed seed: the loss assertions at the end compare against constants.
      torch.manual_seed(123)
      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

      ##############################
      # Download data if necessary
      ##############################

      # Use the parametrized file name for both the local path and the
      # download URL so that the parametrization actually takes effect.
      file_path = tmp_path / file_name
      url = (
          "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/"
          f"main/ch02/01_main-chapter-code/{file_name}"
      )

      if not file_path.exists():
          with urlopen(url) as response:
              text_data = response.read().decode("utf-8")
          file_path.write_text(text_data, encoding="utf-8")
      else:
          text_data = file_path.read_text(encoding="utf-8")

      ##############################
      # Initialize model
      ##############################

      model = GPTModel(GPT_CONFIG_124M)
      model.to(device)  # no assignment model = model.to(device) necessary for nn.Module classes
      optimizer = torch.optim.AdamW(
          model.parameters(),
          lr=OTHER_SETTINGS["learning_rate"],
          weight_decay=OTHER_SETTINGS["weight_decay"],
      )

      ##############################
      # Set up dataloaders
      ##############################

      # Train/validation ratio (split is done on characters of the raw text)
      train_ratio = 0.90
      split_idx = int(train_ratio * len(text_data))

      train_loader = create_dataloader_v1(
          text_data[:split_idx],
          batch_size=OTHER_SETTINGS["batch_size"],
          max_length=GPT_CONFIG_124M["context_length"],
          stride=GPT_CONFIG_124M["context_length"],
          drop_last=True,
          shuffle=True,
          num_workers=0,
      )

      val_loader = create_dataloader_v1(
          text_data[split_idx:],
          batch_size=OTHER_SETTINGS["batch_size"],
          max_length=GPT_CONFIG_124M["context_length"],
          stride=GPT_CONFIG_124M["context_length"],
          drop_last=False,
          shuffle=False,
          num_workers=0,
      )

      ##############################
      # Train model
      ##############################

      tokenizer = tiktoken.get_encoding("gpt2")

      # Restrict both datasets to their first sample so the test stays fast.
      train_subset = Subset(train_loader.dataset, range(1))
      one_batch_train_loader = DataLoader(train_subset, batch_size=1)
      val_subset = Subset(val_loader.dataset, range(1))
      one_batch_val_loader = DataLoader(val_subset, batch_size=1)

      train_losses, val_losses, tokens_seen = train_model_simple(
          model, one_batch_train_loader, one_batch_val_loader, optimizer, device,
          num_epochs=OTHER_SETTINGS["num_epochs"], eval_freq=1, eval_iter=1,
          start_context="Every effort moves you", tokenizer=tokenizer
      )

      # Reference values for the seeded run, compared at one decimal place.
      assert round(train_losses[0], 1) == 7.6
      assert round(val_losses[0], 1) == 10.1
      assert train_losses[-1] < train_losses[0]