# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

from llms_from_scratch.ch04 import generate_text_simple
from llms_from_scratch.qwen3 import (
    compute_rope_params,
    apply_rope,
    QWEN_CONFIG_06_B,
    RMSNorm,
    Qwen3Model,
    Qwen3Tokenizer
)
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached

import importlib.util

import pytest
import torch
import torch.nn as nn


class Qwen3RMSNorm(nn.Module):
    # Source: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py
    # License: Apache License, Version 2.0 (see file above)
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
transformers_installed = importlib.util.find_spec("transformers") is not None
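

# Compares the from-scratch RoPE implementation against the Hugging Face
# Qwen3 reference implementation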
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_rope():
    from transformers.models.qwen3.modeling_qwen3 import Qwen3RotaryEmbedding, apply_rotary_pos_emb

    # Settings
    batch_size = 1
    context_len = 8192
    num_heads = 4
    head_dim = 16
    rope_theta = 1_000_000

    # Instantiate RoPE parameters
    cos, sin = compute_rope_params(
        head_dim=head_dim,
        theta_base=rope_theta,
        context_length=context_len,
    )

    # Dummy query and key tensors
    torch.manual_seed(123)
    queries = torch.randn(batch_size, num_heads, context_len, head_dim)
    keys = torch.randn(batch_size, num_heads, context_len, head_dim)

    # Apply rotary position embeddings
    queries_rot = apply_rope(queries, cos, sin)
    keys_rot = apply_rope(keys, cos, sin)

    # Generate reference RoPE via HF
    class RoPEConfig:
        rope_type = "qwen3"
        factor = 1.0
        dim: int = head_dim
        rope_theta = 1_000_000
        max_position_embeddings: int = 8192
        hidden_size = head_dim * num_heads
        num_attention_heads = num_heads

    config = RoPEConfig()

    rot_emb = Qwen3RotaryEmbedding(config=config)
    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
    ref_cos, ref_sin = rot_emb(queries, position_ids)
    ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)

    torch.testing.assert_close(sin, ref_sin.squeeze(0))
    torch.testing.assert_close(cos, ref_cos.squeeze(0))
    torch.testing.assert_close(keys_rot, ref_keys_rot)
    torch.testing.assert_close(queries_rot, ref_queries_rot)


@pytest.fixture(scope="session")
def qwen3_weights_path(tmp_path_factory):
    """Creates and saves a deterministic model for testing."""
    path = tmp_path_factory.mktemp("models") / "qwen3_test_weights.pt"

    if not path.exists():
        torch.manual_seed(123)
        model = Qwen3Model(QWEN_CONFIG_06_B)
        torch.save(model.state_dict(), path)

    return path
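

# Runs both the standard and the KV-cache model implementations against both
# generation functions and checks that all combinations produce the same token IDs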
@pytest.mark.parametrize("ModelClass", [Qwen3Model, Qwen3ModelKV])
@pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
    torch.manual_seed(123)
    model = ModelClass(QWEN_CONFIG_06_B)
    model.load_state_dict(torch.load(qwen3_weights_path))
    model.eval()

    tokenizer = Qwen3Tokenizer(
        tokenizer_file_path="tokenizer-base.json",
        repo_id="rasbt/qwen3-from-scratch",
        add_generation_prompt=False,
        add_thinking=False
    )

    prompt = "Give me a short introduction to large language models."
    input_token_ids = tokenizer.encode(prompt)
    input_token_ids = torch.tensor([input_token_ids])

    print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
    print("\nInput text:", prompt)
    print("Encoded input text:", input_token_ids)
    print("encoded_tensor.shape:", input_token_ids.shape)

    out = generate_fn(
        model=model,
        idx=input_token_ids,
        max_new_tokens=5,
        context_size=QWEN_CONFIG_06_B["context_length"]
    )
- print("Encoded output text:", out)
- expect = torch.tensor([
- [151644, 872, 198, 35127, 752, 264, 2805, 16800, 311,
- 3460, 4128, 4119, 13, 151645, 198, 112120, 83942, 60483,
- 102652, 7414]
- ])
- assert torch.equal(expect, out)
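

# Checks that the from-scratch RMSNorm matches the Hugging Face reference
# implementation defined at the top of this file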
def test_rmsnorm_equivalence():
    torch.manual_seed(42)

    hidden_size = 64
    batch_size = 8
    seq_len = 16

    rms_norm = RMSNorm(hidden_size)
    ref_norm = Qwen3RMSNorm(hidden_size)

    # Sync weights so both layers use the same scale parameters
    with torch.no_grad():
        ref_norm.weight.copy_(rms_norm.weight)

    x = torch.randn(batch_size, seq_len, hidden_size)

    out1 = rms_norm(x)
    out2 = ref_norm(x)

    torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)
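

# Verifies that Qwen3Tokenizer reproduces the Hugging Face tokenizer's
# chat-template encoding and decoding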
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_tokenizer_equivalence():
    from transformers import AutoTokenizer

    repo_id = "Qwen/Qwen3-0.6B"
    tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
    prompt = "Give me a short introduction to large language models."
    messages = [
        {"role": "user", "content": prompt},
    ]
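
    # Exercise both chat-template configurations:
    # (add_generation_prompt, add_thinking) = (True, True) and (False, False)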
    for states in ((True, True), (False, False)):
        tokenizer = Qwen3Tokenizer(
            tokenizer_file_path="Qwen3-0.6B/tokenizer.json",
            repo_id=repo_id,
            add_generation_prompt=states[0],
            add_thinking=states[1]
        )
        input_token_ids = tokenizer.encode(prompt)
        input_token_ids_ref = tokenizer_ref.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=states[0],
            enable_thinking=states[1],
        )
        assert input_token_ids == input_token_ids_ref, states

        output_text = tokenizer.decode(input_token_ids)
        out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
        assert output_text == out_text_ref, states
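

# A minimal sketch for running this test module (assumes pytest is installed;
# the file name below is hypothetical):
#
#   pytest test_qwen3.py -v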