tests.py

import os
import sys
import io
import nbformat
import types
import pytest
import tiktoken


def import_definitions_from_notebook(fullname, names):
    """Loads function definitions from a Jupyter notebook file into a module."""
    path = os.path.join(os.path.dirname(__file__), fullname + ".ipynb")
    path = os.path.normpath(path)

    if not os.path.exists(path):
        raise FileNotFoundError(f"Notebook file not found at: {path}")

    with io.open(path, "r", encoding="utf-8") as f:
        nb = nbformat.read(f, as_version=4)

    mod = types.ModuleType(fullname)
    sys.modules[fullname] = mod

    # Execute all code cells to capture dependencies
    for cell in nb.cells:
        if cell.cell_type == "code":
            exec(cell.source, mod.__dict__)

    # Ensure required names are in module
    missing_names = [name for name in names if name not in mod.__dict__]
    if missing_names:
        raise ImportError(f"Missing definitions in notebook: {missing_names}")

    return mod

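# Illustrative usage of the notebook loader above (a sketch, not executed by
# pytest; it assumes "bpe-from-scratch.ipynb" sits next to this test file):
#
#     mod = import_definitions_from_notebook(
#         "bpe-from-scratch", ["BPETokenizerSimple", "download_file_if_absent"]
#     )
#     tokenizer = mod.BPETokenizerSimple()
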
@pytest.fixture(scope="module")
def imported_module():
    fullname = "bpe-from-scratch"
    names = ["BPETokenizerSimple", "download_file_if_absent"]
    return import_definitions_from_notebook(fullname, names)

@pytest.fixture(scope="module")
def verdict_file(imported_module):
    """Fixture to handle downloading The Verdict file."""
    download_file_if_absent = getattr(imported_module, "download_file_if_absent", None)

    verdict_path = download_file_if_absent(
        url=(
            "https://raw.githubusercontent.com/rasbt/"
            "LLMs-from-scratch/main/ch02/01_main-chapter-code/"
            "the-verdict.txt"
        ),
        filename="the-verdict.txt",
        search_dirs=["ch02/01_main-chapter-code/", "../01_main-chapter-code/", "."]
    )
    return verdict_path

@pytest.fixture(scope="module")
def gpt2_files(imported_module):
    """Fixture to handle downloading GPT-2 files."""
    download_file_if_absent = getattr(imported_module, "download_file_if_absent", None)

    search_directories = [
        "ch02/02_bonus_bytepair-encoder/gpt2_model/",
        "../02_bonus_bytepair-encoder/gpt2_model/",
        ".",
    ]
    files_to_download = {
        "https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe": "vocab.bpe",
        "https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json": "encoder.json",
    }
    paths = {
        filename: download_file_if_absent(url, filename, search_directories)
        for url, filename in files_to_download.items()
    }
    return paths

def test_tokenizer_training(imported_module, verdict_file):
    BPETokenizerSimple = getattr(imported_module, "BPETokenizerSimple", None)

    tokenizer = BPETokenizerSimple()
    with open(verdict_file, "r", encoding="utf-8") as f:
        text = f.read()

    tokenizer.train(text, vocab_size=1000, allowed_special={"<|endoftext|>"})
    assert len(tokenizer.vocab) == 1000, "Tokenizer vocabulary size mismatch."
    assert len(tokenizer.bpe_merges) == 742, "Tokenizer BPE merges count mismatch."

    input_text = "Jack embraced beauty through art and life."
    token_ids = tokenizer.encode(input_text)
    assert token_ids == [424, 256, 654, 531, 302, 311, 256, 296, 97, 465, 121, 595, 841, 116, 287, 466, 256, 326, 972, 46], "Token IDs do not match expected output."
    assert tokenizer.decode(token_ids) == input_text, "Decoded text does not match the original input."

    tokenizer.save_vocab_and_merges(vocab_path="vocab.json", bpe_merges_path="bpe_merges.txt")

    tokenizer2 = BPETokenizerSimple()
    tokenizer2.load_vocab_and_merges(vocab_path="vocab.json", bpe_merges_path="bpe_merges.txt")
    assert tokenizer2.decode(token_ids) == input_text, "Decoded text mismatch after reloading tokenizer."

def test_gpt2_tokenizer_openai_simple(imported_module, gpt2_files):
    BPETokenizerSimple = getattr(imported_module, "BPETokenizerSimple", None)

    tokenizer_gpt2 = BPETokenizerSimple()
    tokenizer_gpt2.load_vocab_and_merges_from_openai(
        vocab_path=gpt2_files["encoder.json"], bpe_merges_path=gpt2_files["vocab.bpe"]
    )
    assert len(tokenizer_gpt2.vocab) == 50257, "GPT-2 tokenizer vocabulary size mismatch."

    input_text = "This is some text"
    token_ids = tokenizer_gpt2.encode(input_text)
    assert token_ids == [1212, 318, 617, 2420], "Tokenized output does not match expected GPT-2 encoding."

def test_gpt2_tokenizer_openai_edgecases(imported_module, gpt2_files):
    BPETokenizerSimple = getattr(imported_module, "BPETokenizerSimple", None)

    tokenizer_gpt2 = BPETokenizerSimple()
    tokenizer_gpt2.load_vocab_and_merges_from_openai(
        vocab_path=gpt2_files["encoder.json"], bpe_merges_path=gpt2_files["vocab.bpe"]
    )
    tik_tokenizer = tiktoken.get_encoding("gpt2")

    test_cases = [
        ("Hello,", [15496, 11]),
        ("Implementations", [3546, 26908, 602]),
        ("asdf asdfasdf a!!, @aba 9asdf90asdfk", [292, 7568, 355, 7568, 292, 7568, 257, 3228, 11, 2488, 15498, 860, 292, 7568, 3829, 292, 7568, 74]),
        ("Hello, world. Is this-- a test?", [15496, 11, 995, 13, 1148, 428, 438, 257, 1332, 30]),
    ]

    errors = []
    for input_text, expected_tokens in test_cases:
        tik_tokens = tik_tokenizer.encode(input_text)
        gpt2_tokens = tokenizer_gpt2.encode(input_text)

        print(f"Text: {input_text}")
        print(f"Expected Tokens: {expected_tokens}")
        print(f"tiktoken Output: {tik_tokens}")
        print(f"BPETokenizerSimple Output: {gpt2_tokens}")
        print("-" * 40)

        if tik_tokens != expected_tokens:
            errors.append(
                f"tiktoken output does not match the expected GPT-2 encoding for '{input_text}'.\n"
                f"Expected: {expected_tokens}, Got: {tik_tokens}"
            )
        if gpt2_tokens != expected_tokens:
            errors.append(
                f"BPETokenizerSimple output does not match the expected GPT-2 encoding for '{input_text}'.\n"
                f"Expected: {expected_tokens}, Got: {gpt2_tokens}"
            )

    if errors:
        pytest.fail("\n".join(errors))
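
# Minimal sketch for running this file directly (assumes pytest is installed);
# the usual entry point is simply `pytest tests.py -v`.
if __name__ == "__main__":
    raise SystemExit(pytest.main([__file__, "-v"]))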