# test_llama3.py
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

from llms_from_scratch.ch04 import generate_text_simple
from llms_from_scratch.llama3 import (
    compute_rope_params,
    apply_rope,
    LLAMA32_CONFIG_1B,
    GroupedQueryAttention,
    GroupedQueryAttentionFast,
    Llama3Model,
)

import importlib
import pytest
import tiktoken
import torch


class LitGPTRMSNorm(torch.nn.Module):
    """Root Mean Square Layer Normalization.

    From https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
    Apache 2.0 License: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
    """

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim
        self.add_unit_offset = add_unit_offset

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        x = x.float()
        # NOTE: the original RMSNorm paper implementation is not equivalent
        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
        x_normed = x * torch.rsqrt(norm_x + self.eps)
        weight = (1 + self.weight) if self.add_unit_offset else self.weight
        return (x_normed * weight.float()).to(dtype=dtype)

    def reset_parameters(self) -> None:
        torch.nn.init.ones_(self.weight)
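

# Not part of the original test file: a minimal sanity check (sketch) that
# LitGPTRMSNorm matches the hand-computed RMSNorm formula
# x * rsqrt(mean(x**2) + eps) * weight, using only the class defined above.
def test_litgpt_rmsnorm_manual_formula():
    torch.manual_seed(0)
    eps = 1e-6
    x = torch.randn(4, 10, 32)
    norm = LitGPTRMSNorm(size=32, eps=eps)

    # Same computation written out by hand (weight is initialized to ones)
    expected = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * norm.weight

    torch.testing.assert_close(norm(x), expected)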


transformers_installed = importlib.util.find_spec("transformers") is not None


@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_rope():
    from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb

    # Settings
    batch_size = 1
    context_len = 8192
    num_heads = 4
    head_dim = 16
    rope_theta = 500_000

    rope_config = {
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_context_length": 8192,
    }

    # Instantiate RoPE parameters
    cos, sin = compute_rope_params(
        head_dim=head_dim,
        theta_base=rope_theta,
        context_length=context_len,
        freq_config=rope_config,
    )

    # Dummy query and key tensors
    torch.manual_seed(123)
    queries = torch.randn(batch_size, num_heads, context_len, head_dim)
    keys = torch.randn(batch_size, num_heads, context_len, head_dim)

    # Apply rotary position embeddings
    queries_rot = apply_rope(queries, cos, sin)
    keys_rot = apply_rope(keys, cos, sin)

    # Generate reference RoPE via HF
    hf_rope_params = {
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    }

    class RoPEConfig:
        rope_type = "llama3"
        rope_scaling = hf_rope_params
        factor = 1.0
        dim: int = head_dim
        rope_theta = 500_000
        max_position_embeddings: int = 8192
        hidden_size = head_dim * num_heads
        num_attention_heads = num_heads

    config = RoPEConfig()

    rot_emb = LlamaRotaryEmbedding(config=config)
    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
    ref_cos, ref_sin = rot_emb(queries, position_ids)
    ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)

    torch.testing.assert_close(sin, ref_sin.squeeze(0))
    torch.testing.assert_close(cos, ref_cos.squeeze(0))
    torch.testing.assert_close(keys_rot, ref_keys_rot)
    torch.testing.assert_close(queries_rot, ref_queries_rot)
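

# Not part of the original tests: a small property check (sketch). RoPE applies
# pairwise 2D rotations per position, so it should preserve the L2 norm of each
# head vector; this assumes the apply_rope/compute_rope_params behavior used above.
def test_rope_preserves_vector_norm():
    torch.manual_seed(123)
    context_len = 64
    head_dim = 16

    cos, sin = compute_rope_params(
        head_dim=head_dim,
        theta_base=500_000,
        context_length=context_len,
        freq_config={
            "factor": 8.0,
            "low_freq_factor": 1.0,
            "high_freq_factor": 4.0,
            "original_context_length": context_len,
        },
    )

    x = torch.randn(1, 2, context_len, head_dim)
    x_rot = apply_rope(x, cos, sin)

    # A rotation changes directions but not lengths
    torch.testing.assert_close(x_rot.norm(dim=-1), x.norm(dim=-1), atol=1e-5, rtol=1e-5)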


GPT_CONFIG_124M = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "emb_dim": 768,          # Embedding dimension
    "n_heads": 12,           # Number of attention heads
    "n_layers": 12,          # Number of layers
    "drop_rate": 0.1,        # Dropout rate
    "qkv_bias": False        # Query-Key-Value bias
}


def test_grouped_query_attention_equivalence():
    torch.manual_seed(42)

    b, t, d_in, d_out, num_heads, num_kv_groups = 2, 8, 32, 64, 4, 2
    x = torch.randn(b, t, d_in)

    cos, sin = compute_rope_params(
        head_dim=d_out // num_heads,
        theta_base=50_000,
        context_length=t,
        freq_config={
            "factor": 32.0,
            "low_freq_factor": 1.0,
            "high_freq_factor": 4.0,
            "original_context_length": t,
        }
    )

    # Causal mask for the slow version
    mask = torch.triu(torch.ones(t, t, dtype=torch.bool), diagonal=1)

    attn1 = GroupedQueryAttention(d_in, d_out, num_heads, num_kv_groups)
    attn2 = GroupedQueryAttentionFast(d_in, d_out, num_heads, num_kv_groups)

    # Copy weights to make both models identical
    attn2.load_state_dict(attn1.state_dict())

    # Run both
    y1 = attn1(x, mask, cos, sin)
    y2 = attn2(x, cos, sin)

    # Compare outputs
    max_diff = (y1 - y2).abs().max().item()
    print(f"Max difference between slow and fast outputs: {max_diff:.4e}")
    assert torch.allclose(y1, y2, atol=1e-4)
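

# Not part of the original tests: a rough sketch checking that grouped-query
# attention shrinks the key/value projections. It only assumes the constructor
# signature used above, (d_in, d_out, num_heads, num_kv_groups), and compares
# total parameter counts rather than specific attribute names.
def test_grouped_query_attention_fewer_params_than_mha():
    d_in, d_out, num_heads = 32, 64, 4

    gqa = GroupedQueryAttention(d_in, d_out, num_heads, num_kv_groups=2)
    mha_like = GroupedQueryAttention(d_in, d_out, num_heads, num_kv_groups=num_heads)

    n_params_gqa = sum(p.numel() for p in gqa.parameters())
    n_params_mha = sum(p.numel() for p in mha_like.parameters())

    # Fewer key/value groups -> smaller key/value projection matrices
    assert n_params_gqa < n_params_mha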


@pytest.fixture(scope="session")
def llama3_weights_path(tmp_path_factory):
    """Creates and saves a deterministic Llama3 model for testing."""
    path = tmp_path_factory.mktemp("models") / "llama3_test_weights.pt"

    if not path.exists():
        torch.manual_seed(123)
        model = Llama3Model(LLAMA32_CONFIG_1B)
        torch.save(model.state_dict(), path)

    return path


@pytest.mark.parametrize("ModelClass", [Llama3Model])
def test_gpt_model_variants(ModelClass, llama3_weights_path):
    torch.manual_seed(123)
    model = ModelClass(LLAMA32_CONFIG_1B)
    model.load_state_dict(torch.load(llama3_weights_path))
    model.eval()

    start_context = "Llamas eat"

    tokenizer = tiktoken.get_encoding("gpt2")
    encoded = tokenizer.encode(start_context)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)

    print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
    print("\nInput text:", start_context)
    print("Encoded input text:", encoded)
    print("encoded_tensor.shape:", encoded_tensor.shape)

    out = generate_text_simple(
        model=model,
        idx=encoded_tensor,
        max_new_tokens=5,
        context_size=LLAMA32_CONFIG_1B["context_length"]
    )
    print("Encoded output text:", out)

    expect = torch.tensor([
        [43, 2543, 292, 4483, 100383, 8113, 76873, 42175, 72641]
    ])
    assert torch.equal(expect, out)
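

# Not part of the original tests: a light consistency check (sketch) derived from
# the expected tensor above. Assuming generate_text_simple appends max_new_tokens=5
# tokens to the prompt, the first entries of `expect` should be the GPT-2 token IDs
# of "Llamas eat" and the total length should be the prompt length plus 5.
def test_expected_output_starts_with_prompt_tokens():
    tokenizer = tiktoken.get_encoding("gpt2")
    encoded = tokenizer.encode("Llamas eat")

    expect = [43, 2543, 292, 4483, 100383, 8113, 76873, 42175, 72641]

    assert expect[:len(encoded)] == encoded
    assert len(expect) == len(encoded) + 5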


def test_rmsnorm_equivalence():
    torch.manual_seed(42)

    hidden_size = 64
    batch_size = 8
    seq_len = 16

    rms_norm = torch.nn.RMSNorm(hidden_size, eps=1e-6)
    lit_norm = LitGPTRMSNorm(hidden_size)

    # Sync weights (copy the PyTorch layer's weights into the LitGPT layer)
    with torch.no_grad():
        lit_norm.weight.copy_(rms_norm.weight)

    x = torch.randn(batch_size, seq_len, hidden_size)

    out1 = rms_norm(x)
    out2 = lit_norm(x)

    torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)
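

# Not part of the original tests: a sketch of RMSNorm's scale-invariance property.
# Up to the small eps term, RMSNorm(alpha * x) == RMSNorm(x) for alpha > 0, which is
# a quick way to catch accidental mean subtraction or bias terms in the implementation.
def test_rmsnorm_scale_invariance():
    torch.manual_seed(42)

    norm = LitGPTRMSNorm(64)
    x = torch.randn(8, 16, 64)

    out_ref = norm(x)
    out_scaled = norm(3.0 * x)

    # eps introduces a tiny difference, hence the loose tolerances
    torch.testing.assert_close(out_ref, out_scaled, atol=1e-4, rtol=1e-4)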