# Code to test the GPT model implementation against the KV cache variants

import pytest
import torch
import tiktoken

from gpt_ch04 import GPTModel as GPTModelBase
from gpt_ch04 import generate_text_simple
from gpt_with_kv_cache import GPTModel as GPTModelKV1
from gpt_with_kv_cache import generate_text_simple_cached
from gpt_with_kv_cache_optimized import GPTModel as GPTModelKV2

GPT_CONFIG_124M = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "emb_dim": 768,          # Embedding dimension
    "n_heads": 12,           # Number of attention heads
    "n_layers": 12,          # Number of layers
    "drop_rate": 0.1,        # Dropout rate
    "qkv_bias": False,       # Query-key-value bias
}

# Run on GPU if one is available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@pytest.mark.parametrize("ModelClass", [GPTModelBase, GPTModelKV1, GPTModelKV2])
def test_gpt_model_equivalence_not_cached(ModelClass):
    torch.manual_seed(123)

    model = ModelClass(GPT_CONFIG_124M).to(device)
    model.eval()

    tokenizer = tiktoken.get_encoding("gpt2")
    prompt = "Hello, I am"
    encoded = tokenizer.encode(prompt)
    encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)

    model_name = ModelClass.__module__ + "." + ModelClass.__name__

    token_ids = generate_text_simple(
        model=model,
        idx=encoded_tensor,
        max_new_tokens=30,
        context_size=GPT_CONFIG_124M["context_length"]
    )

    # Collect the output of each parametrized run on the function object;
    # once all three model variants have run, compare each against the first (baseline)
    if not hasattr(test_gpt_model_equivalence_not_cached, "results"):
        test_gpt_model_equivalence_not_cached.results = []
    test_gpt_model_equivalence_not_cached.results.append((model_name, token_ids))

    if len(test_gpt_model_equivalence_not_cached.results) == 3:
        base_name, base_output = test_gpt_model_equivalence_not_cached.results[0]
        for other_name, other_output in test_gpt_model_equivalence_not_cached.results[1:]:
            assert torch.equal(base_output, other_output), (
                f"Mismatch between {base_name} and {other_name}"
            )
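

# Cached variant of the equivalence test: the baseline model still generates
# with generate_text_simple, while the two KV-cache models go through
# generate_text_simple_cached; the generated token IDs must match exactly.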
@pytest.mark.parametrize("ModelClass", [GPTModelBase, GPTModelKV1, GPTModelKV2])
def test_gpt_model_equivalence_cached(ModelClass):
    torch.manual_seed(123)

    model = ModelClass(GPT_CONFIG_124M).to(device)
    model.eval()

    tokenizer = tiktoken.get_encoding("gpt2")
    prompt = "Hello, I am"
    encoded_tensor = torch.tensor(tokenizer.encode(prompt), device=device).unsqueeze(0)

    model_name = ModelClass.__module__ + "." + ModelClass.__name__

    if ModelClass is GPTModelBase:
        token_ids = generate_text_simple(
            model=model,
            idx=encoded_tensor,
            max_new_tokens=30,
            context_size=GPT_CONFIG_124M["context_length"]
        )
    else:
        token_ids = generate_text_simple_cached(
            model=model,
            idx=encoded_tensor,
            max_new_tokens=30,
            context_size=GPT_CONFIG_124M["context_length"]
        )

    if not hasattr(test_gpt_model_equivalence_cached, "results"):
        test_gpt_model_equivalence_cached.results = []
    test_gpt_model_equivalence_cached.results.append((model_name, token_ids))

    if len(test_gpt_model_equivalence_cached.results) == 3:
        base_name, base_output = test_gpt_model_equivalence_cached.results[0]
        for other_name, other_output in test_gpt_model_equivalence_cached.results[1:]:
            assert torch.equal(base_output, other_output), (
                f"Mismatch between {base_name} and {other_name}"
            )
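

# A minimal way to run these tests, assuming pytest is installed and the three
# model files (gpt_ch04.py, gpt_with_kv_cache.py, gpt_with_kv_cache_optimized.py)
# live in the same directory as this file:
#
#   pytest tests.py -q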