test_qwen3.py

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

from llms_from_scratch.ch04 import generate_text_simple
from llms_from_scratch.qwen3 import (
    compute_rope_params,
    apply_rope,
    QWEN_CONFIG_06_B,
    RMSNorm,
    Qwen3Model,
    Qwen3Tokenizer,
)
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached
import importlib.util
import pytest
import torch
import torch.nn as nn


class Qwen3RMSNorm(nn.Module):
    # Source: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py
    # License: Apache License, Version 2.0 (see file above)
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Normalize in float32 for numerical stability, then cast back
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


transformers_installed = importlib.util.find_spec("transformers") is not None


@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_rope():
    from transformers.models.qwen3.modeling_qwen3 import Qwen3RotaryEmbedding, apply_rotary_pos_emb

    # Settings
    batch_size = 1
    context_len = 8192
    num_heads = 4
    head_dim = 16
    rope_theta = 1_000_000

    # Instantiate RoPE parameters
    cos, sin = compute_rope_params(
        head_dim=head_dim,
        theta_base=rope_theta,
        context_length=context_len,
    )

    # Dummy query and key tensors
    torch.manual_seed(123)
    queries = torch.randn(batch_size, num_heads, context_len, head_dim)
    keys = torch.randn(batch_size, num_heads, context_len, head_dim)

    # Apply rotary position embeddings
    queries_rot = apply_rope(queries, cos, sin)
    keys_rot = apply_rope(keys, cos, sin)

    # Generate reference RoPE via HF
    class RoPEConfig:
        rope_type = "qwen3"
        factor = 1.0
        dim: int = head_dim
        rope_theta = 1_000_000
        max_position_embeddings: int = 8192
        hidden_size = head_dim * num_heads
        num_attention_heads = num_heads

    config = RoPEConfig()
    rot_emb = Qwen3RotaryEmbedding(config=config)
    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
    ref_cos, ref_sin = rot_emb(queries, position_ids)
    ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)

    torch.testing.assert_close(sin, ref_sin.squeeze(0))
    torch.testing.assert_close(cos, ref_cos.squeeze(0))
    torch.testing.assert_close(keys_rot, ref_keys_rot)
    torch.testing.assert_close(queries_rot, ref_queries_rot)
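

# Extra sanity check (a minimal sketch, assuming compute_rope_params and
# apply_rope follow the standard rotate-half formulation exercised above):
# RoPE rotates feature pairs by position-dependent angles, so it should
# leave the per-position vector norms unchanged.
def test_rope_preserves_norm():
    torch.manual_seed(123)
    head_dim = 16
    context_len = 64
    cos, sin = compute_rope_params(
        head_dim=head_dim,
        theta_base=1_000_000,
        context_length=context_len,
    )
    x = torch.randn(1, 4, context_len, head_dim)
    x_rot = apply_rope(x, cos, sin)
    torch.testing.assert_close(x_rot.norm(dim=-1), x.norm(dim=-1))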


@pytest.fixture(scope="session")
def qwen3_weights_path(tmp_path_factory):
    """Creates and saves a deterministic model for testing."""
    path = tmp_path_factory.mktemp("models") / "qwen3_test_weights.pt"
    if not path.exists():
        torch.manual_seed(123)
        model = Qwen3Model(QWEN_CONFIG_06_B)
        torch.save(model.state_dict(), path)
    return path


# Pair each model variant with its matching generation function so the
# KV-cache model is exercised with the cached generator
@pytest.mark.parametrize(
    "ModelClass, generate_fn",
    [
        (Qwen3Model, generate_text_simple),
        (Qwen3ModelKV, generate_text_simple_cached),
    ],
)
def test_model_variants(ModelClass, generate_fn, qwen3_weights_path):
    torch.manual_seed(123)
    model = ModelClass(QWEN_CONFIG_06_B)
    model.load_state_dict(torch.load(qwen3_weights_path))
    model.eval()

    tokenizer = Qwen3Tokenizer(
        tokenizer_file_path="tokenizer-base.json",
        repo_id="rasbt/qwen3-from-scratch",
        add_generation_prompt=False,
        add_thinking=False
    )

    prompt = "Give me a short introduction to large language models."
    input_token_ids = tokenizer.encode(prompt)
    input_token_ids = torch.tensor([input_token_ids])

    print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
    print("\nInput text:", prompt)
    print("Encoded input text:", input_token_ids)
    print("encoded_tensor.shape:", input_token_ids.shape)

    out = generate_fn(
        model=model,
        idx=input_token_ids,
        max_new_tokens=5,
        context_size=QWEN_CONFIG_06_B["context_length"]
    )
    print("Encoded output text:", out)

    expect = torch.tensor([
        [151644, 872, 198, 35127, 752, 264, 2805, 16800, 311,
         3460, 4128, 4119, 13, 151645, 198, 112120, 83942, 60483,
         102652, 7414]
    ])
    assert torch.equal(expect, out)
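

# Extra sanity check (a minimal sketch, assuming both generate functions
# accept the keyword arguments used above): the cached and uncached paths
# should agree token-for-token given the same weights and input ids.
def test_cached_matches_uncached(qwen3_weights_path):
    torch.manual_seed(123)
    model = Qwen3Model(QWEN_CONFIG_06_B)
    model.load_state_dict(torch.load(qwen3_weights_path))
    model.eval()

    model_kv = Qwen3ModelKV(QWEN_CONFIG_06_B)
    model_kv.load_state_dict(torch.load(qwen3_weights_path))
    model_kv.eval()

    # A few arbitrary-but-valid token ids; any prompt would do
    idx = torch.tensor([[151644, 872, 198]])

    out = generate_text_simple(
        model=model,
        idx=idx,
        max_new_tokens=5,
        context_size=QWEN_CONFIG_06_B["context_length"]
    )
    out_cached = generate_text_simple_cached(
        model=model_kv,
        idx=idx,
        max_new_tokens=5,
        context_size=QWEN_CONFIG_06_B["context_length"]
    )
    assert torch.equal(out, out_cached)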


def test_rmsnorm_equivalence():
    torch.manual_seed(42)

    hidden_size = 64
    batch_size = 8
    seq_len = 16

    rms_norm = RMSNorm(hidden_size)
    ref_norm = Qwen3RMSNorm(hidden_size)

    # Sync weights: copy the custom module's parameters into the reference
    # module so both apply identical scales
    with torch.no_grad():
        for ref_param, param in zip(ref_norm.parameters(), rms_norm.parameters()):
            ref_param.copy_(param)

    x = torch.randn(batch_size, seq_len, hidden_size)

    out1 = rms_norm(x)
    out2 = ref_norm(x)

    torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)
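

# Extra sanity check (a minimal sketch, based on the reference forward pass
# above): the class upcasts to float32 internally and casts back, so a
# half-precision module should return half-precision outputs.
def test_rmsnorm_dtype_roundtrip():
    norm = Qwen3RMSNorm(hidden_size=16).to(torch.float16)
    x = torch.randn(2, 4, 16).to(torch.float16)
    out = norm(x)
    assert out.dtype == torch.float16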


@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_tokenizer_equivalence():
    from transformers import AutoTokenizer

    repo_id = "Qwen/Qwen3-0.6B"
    tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
    prompt = "Give me a short introduction to large language models."
    messages = [
        {"role": "user", "content": prompt},
    ]

    for states in ((True, True), (False, False)):
        tokenizer = Qwen3Tokenizer(
            tokenizer_file_path="Qwen3-0.6B/tokenizer.json",
            repo_id=repo_id,
            add_generation_prompt=states[0],
            add_thinking=states[1]
        )
        input_token_ids = tokenizer.encode(prompt)

        input_token_ids_ref = tokenizer_ref.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=states[0],
            enable_thinking=states[1],
        )
        assert input_token_ids == input_token_ids_ref, states

        output_text = tokenizer.decode(input_token_ids)
        out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
        assert output_text == out_text_ref, states
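

# Extra sanity check (a minimal sketch reusing the tokenizer files from the
# test above): the chat template wraps the user prompt in special tokens, so
# the decoded ids should still contain the original prompt verbatim.
def test_tokenizer_roundtrip_contains_prompt():
    prompt = "Give me a short introduction to large language models."
    tokenizer = Qwen3Tokenizer(
        tokenizer_file_path="Qwen3-0.6B/tokenizer.json",
        repo_id="Qwen/Qwen3-0.6B",
        add_generation_prompt=False,
        add_thinking=False
    )
    token_ids = tokenizer.encode(prompt)
    assert prompt in tokenizer.decode(token_ids)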