# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

# File for internal use (unit tests)

import http.client
from urllib.parse import urlparse

import pytest

from gpt_train import main


@pytest.fixture
def gpt_config():
    return {
        "vocab_size": 50257,
        "context_length": 12,  # small for testing efficiency
        "emb_dim": 32,         # small for testing efficiency
        "n_heads": 4,          # small for testing efficiency
        "n_layers": 2,         # small for testing efficiency
        "drop_rate": 0.1,
        "qkv_bias": False
    }
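
# Note: for comparison, the full GPT-2 124M model uses context_length=1024,
# emb_dim=768, n_heads=12, and n_layers=12 (vocab_size stays 50257); the
# values above are scaled down so the training test finishes quickly on CPU.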


@pytest.fixture
def other_settings():
    return {
        "learning_rate": 5e-4,
        "num_epochs": 1,  # small for testing efficiency
        "batch_size": 2,
        "weight_decay": 0.1
    }


def test_main(gpt_config, other_settings):
    train_losses, val_losses, tokens_seen, model = main(gpt_config, other_settings)

    assert len(train_losses) == 39, "Unexpected number of training losses"
    assert len(val_losses) == 39, "Unexpected number of validation losses"
    assert len(tokens_seen) == 39, "Unexpected number of tokens seen"
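
# The expected count of 39 is tied to the defaults in gpt_train.main (the size
# of the training text, the batch size, and how often losses are recorded), so
# these assertions need updating if any of those defaults change.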


def check_file_size(url, expected_size):
    parsed_url = urlparse(url)
    if parsed_url.scheme == "https":
        conn = http.client.HTTPSConnection(parsed_url.netloc)
    else:
        conn = http.client.HTTPConnection(parsed_url.netloc)
    conn.request("HEAD", parsed_url.path)
    response = conn.getresponse()
    if response.status != 200:
        return False, f"{url} not accessible"
    size = response.getheader("Content-Length")
    if size is None:
        return False, "Content-Length header is missing"
    size = int(size)
    if size != expected_size:
        return False, f"{url} file has expected size {expected_size}, but got {size}"
    return True, f"{url} file size is correct"
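
# Example usage (hypothetical URL), illustrating the (bool, message) contract:
#
#   ok, msg = check_file_size("https://example.com/gpt2/vocab.bpe", 456318)
#   assert ok, msg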


def test_model_files():
    def check_model_files(base_url):
        model_size = "124M"
        files = {
            "checkpoint": 77,
            "encoder.json": 1042301,
            "hparams.json": 90,
            "model.ckpt.data-00000-of-00001": 497759232,
            "model.ckpt.index": 5215,
            "model.ckpt.meta": 471155,
            "vocab.bpe": 456318
        }

        for file_name, expected_size in files.items():
            url = f"{base_url}/{model_size}/{file_name}"
            valid, message = check_file_size(url, expected_size)
            assert valid, message

        model_size = "355M"
        files = {
            "checkpoint": 77,
            "encoder.json": 1042301,
            "hparams.json": 91,
            "model.ckpt.data-00000-of-00001": 1419292672,
            "model.ckpt.index": 10399,
            "model.ckpt.meta": 926519,
            "vocab.bpe": 456318
        }

        for file_name, expected_size in files.items():
            url = f"{base_url}/{model_size}/{file_name}"
            valid, message = check_file_size(url, expected_size)
            assert valid, message

    check_model_files(base_url="https://openaipublic.blob.core.windows.net/gpt-2/models")
    check_model_files(base_url="https://f001.backblazeb2.com/file/LLMs-from-scratch/gpt2")
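
# These tests are intended to be run with pytest, e.g.:
#
#   pytest tests.py
#
# Note that test_model_files requires network access, since it issues HEAD
# requests against the hosted GPT-2 checkpoint mirrors above.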