# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

from llms_from_scratch.ch04 import generate_text_simple
from llms_from_scratch.qwen3 import (
    compute_rope_params,
    apply_rope,
    QWEN_CONFIG_06_B,
    RMSNorm,
    Qwen3Model,
    Qwen3Tokenizer
)
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached
# from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model as Qwen3ModelKVBatched
# from llms_from_scratch.kv_cache_batched.generate import generate_text_simple as generate_text_simple_batched

import importlib.util

import pytest
import torch
import torch.nn as nn


class Qwen3RMSNorm(nn.Module):
    # Source: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py
    # License: Apache License, Version 2.0 (see file above)
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Normalize in float32 for numerical stability, then cast back
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
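

# A minimal hand-checkable sketch of the RMSNorm arithmetic implemented above:
# y = weight * x / sqrt(mean(x**2, dim=-1) + eps). With weight left at its
# initial value of 1 and eps = 0, the vector [3, 4] has mean square
# (9 + 16) / 2 = 12.5, so every element is simply divided by sqrt(12.5).
def test_rmsnorm_formula_sketch():
    x = torch.tensor([[3.0, 4.0]])
    norm = Qwen3RMSNorm(hidden_size=2, eps=0.0)
    expected = x / torch.sqrt(torch.tensor(12.5))
    torch.testing.assert_close(norm(x), expected)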


transformers_installed = importlib.util.find_spec("transformers") is not None


@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_rope():
    from transformers.models.qwen3.modeling_qwen3 import Qwen3RotaryEmbedding, apply_rotary_pos_emb

    # Settings
    batch_size = 1
    context_len = 8192
    num_heads = 4
    head_dim = 16
    rope_theta = 1_000_000

    # Instantiate RoPE parameters
    cos, sin = compute_rope_params(
        head_dim=head_dim,
        theta_base=rope_theta,
        context_length=context_len,
    )

    # Dummy query and key tensors
    torch.manual_seed(123)
    queries = torch.randn(batch_size, num_heads, context_len, head_dim)
    keys = torch.randn(batch_size, num_heads, context_len, head_dim)

    # Apply rotary position embeddings
    queries_rot = apply_rope(queries, cos, sin)
    keys_rot = apply_rope(keys, cos, sin)

    # Generate reference RoPE via HF
    class RoPEConfig:
        rope_type = "qwen3"
        factor = 1.0
        dim: int = head_dim
        rope_theta = 1_000_000
        max_position_embeddings: int = 8192
        hidden_size = head_dim * num_heads
        num_attention_heads = num_heads

    config = RoPEConfig()

    rot_emb = Qwen3RotaryEmbedding(config=config)
    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
    ref_cos, ref_sin = rot_emb(queries, position_ids)
    ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)

    torch.testing.assert_close(sin, ref_sin.squeeze(0))
    torch.testing.assert_close(cos, ref_cos.squeeze(0))
    torch.testing.assert_close(keys_rot, ref_keys_rot)
    torch.testing.assert_close(queries_rot, ref_queries_rot)
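

# Quick sanity sketch of the rotation itself: RoPE rotates each feature pair
# by an angle proportional to the token position, so at position 0 the angle
# is 0 (cos = 1, sin = 0) and apply_rope must leave that position unchanged.
# Uses only the repo's own compute_rope_params/apply_rope imported above.
def test_rope_identity_at_position_zero():
    torch.manual_seed(123)
    cos, sin = compute_rope_params(head_dim=16, theta_base=1_000_000, context_length=4)
    x = torch.randn(1, 2, 4, 16)  # (batch, num_heads, seq_len, head_dim)
    x_rot = apply_rope(x, cos, sin)
    torch.testing.assert_close(x_rot[:, :, 0, :], x[:, :, 0, :])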


@pytest.fixture(scope="session")
def qwen3_weights_path(tmp_path_factory):
    """Creates and saves a deterministic model for testing."""
    path = tmp_path_factory.mktemp("models") / "qwen3_test_weights.pt"

    if not path.exists():
        torch.manual_seed(123)
        model = Qwen3Model(QWEN_CONFIG_06_B)
        torch.save(model.state_dict(), path)

    return path


@pytest.mark.parametrize("ModelClass", [Qwen3Model, Qwen3ModelKV])
@pytest.mark.parametrize("generate_fn", [generate_text_simple])
def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
    torch.manual_seed(123)
    model = ModelClass(QWEN_CONFIG_06_B)
    model.load_state_dict(torch.load(qwen3_weights_path))
    model.eval()

    tokenizer = Qwen3Tokenizer(
        tokenizer_file_path="tokenizer-base.json",
        repo_id="rasbt/qwen3-from-scratch",
        add_generation_prompt=False,
        add_thinking=False
    )

    prompt = "Give me a short introduction to large language models."
    input_token_ids = tokenizer.encode(prompt)
    input_token_ids = torch.tensor([input_token_ids])

    print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
    print("\nInput text:", prompt)
    print("Encoded input text:", input_token_ids)
    print("encoded_tensor.shape:", input_token_ids.shape)

    out = generate_fn(
        model=model,
        idx=input_token_ids,
        max_new_tokens=5,
        context_size=QWEN_CONFIG_06_B["context_length"]
    )
    print("Encoded output text:", out)

    expect = torch.tensor([
        [151644, 872, 198, 35127, 752, 264, 2805, 16800, 311,
         3460, 4128, 4119, 13, 151645, 198, 112120, 83942, 60483,
         102652, 7414]
    ])
    assert torch.equal(expect, out)


def test_model_KV_noKV(qwen3_weights_path):
    torch.manual_seed(123)
    model_KV = Qwen3ModelKV(QWEN_CONFIG_06_B)
    model_KV.load_state_dict(torch.load(qwen3_weights_path))
    model_KV.eval()

    tokenizer = Qwen3Tokenizer(
        tokenizer_file_path="tokenizer-base.json",
        repo_id="rasbt/qwen3-from-scratch",
        add_generation_prompt=False,
        add_thinking=False
    )

    prompt = "Give me a short introduction to large language models."
    input_token_ids = tokenizer.encode(prompt)
    input_token_ids = torch.tensor([input_token_ids])

    # Output of the model with KV cache
    out_KV = generate_text_simple_cached(
        model=model_KV,
        idx=input_token_ids,
        max_new_tokens=5,
        context_size=QWEN_CONFIG_06_B["context_length"]
    )
    del model_KV

    # Output of the regular model without KV cache
    torch.manual_seed(123)
    model_noKV = Qwen3Model(QWEN_CONFIG_06_B)
    model_noKV.load_state_dict(torch.load(qwen3_weights_path))
    model_noKV.eval()

    out_noKV = generate_text_simple(
        model=model_noKV,
        idx=input_token_ids,
        max_new_tokens=5,
        context_size=QWEN_CONFIG_06_B["context_length"]
    )

    assert torch.equal(out_KV, out_noKV)
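

# Sketch of the key/value-caching idea the equivalence test above relies on:
# instead of recomputing attention keys/values for the whole prefix at every
# decoding step, a KV cache appends the new token's keys/values to the stored
# ones. Appending step by step rebuilds exactly the tensors a full forward
# pass would produce, which is why cached and uncached greedy decoding must
# yield identical tokens. (Toy tensors only; this mirrors the pattern, not
# the exact cache API in llms_from_scratch.kv_cache.)
def test_kv_cache_concat_sketch():
    torch.manual_seed(0)
    full_keys = torch.randn(1, 4, 6, 8)  # (batch, num_heads, seq_len, head_dim)
    cache = full_keys[:, :, :0, :]       # start from an empty cache
    for t in range(full_keys.shape[2]):  # append one position per step
        cache = torch.cat([cache, full_keys[:, :, t:t + 1, :]], dim=2)
    torch.testing.assert_close(cache, full_keys)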


def test_rmsnorm_equivalence():
    torch.manual_seed(42)

    hidden_size = 64
    batch_size = 8
    seq_len = 16

    rms_norm = RMSNorm(hidden_size)
    ref_norm = Qwen3RMSNorm(hidden_size)

    # Sync weights so both norms apply the same learned scale
    with torch.no_grad():
        ref_norm.weight.copy_(rms_norm.weight)

    x = torch.randn(batch_size, seq_len, hidden_size)

    out1 = rms_norm(x)
    out2 = ref_norm(x)

    torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)


@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_tokenizer_equivalence():
    from transformers import AutoTokenizer

    prompt = "Give me a short introduction to large language models."
    messages = [
        {"role": "user", "content": prompt},
    ]

    # Reasoning model tokenizer
    repo_id = "Qwen/Qwen3-0.6B"
    tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
    for states in ((True, True), (False, False)):
        tokenizer = Qwen3Tokenizer(
            tokenizer_file_path="Qwen3-0.6B/tokenizer.json",
            repo_id=repo_id,
            add_generation_prompt=states[0],
            add_thinking=states[1]
        )
        input_token_ids = tokenizer.encode(prompt)

        input_token_ids_ref = tokenizer_ref.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=states[0],
            enable_thinking=states[1],
        )
        assert input_token_ids == input_token_ids_ref, states

        output_text = tokenizer.decode(input_token_ids)
        out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
        assert output_text == out_text_ref, states
        assert tokenizer_ref.eos_token_id == tokenizer.eos_token_id
        assert tokenizer_ref.pad_token_id == tokenizer.pad_token_id

    # Base model tokenizer
    repo_id = "Qwen/Qwen3-0.6B-Base"
    tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
    for states in ((True, True), (False, False)):
        tokenizer = Qwen3Tokenizer(
            tokenizer_file_path="Qwen3-0.6B-Base/tokenizer.json",
            repo_id=repo_id,
            add_generation_prompt=states[0],
            add_thinking=states[1]
        )
        input_token_ids = tokenizer.encode(prompt)

        input_token_ids_ref = tokenizer_ref.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=states[0],
            enable_thinking=states[1],
        )
        assert input_token_ids == input_token_ids_ref, states

        output_text = tokenizer.decode(input_token_ids)
        out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
        assert output_text == out_text_ref, states
        assert tokenizer_ref.eos_token_id == tokenizer.eos_token_id
        assert tokenizer_ref.pad_token_id == tokenizer.pad_token_id