@@ -5,12 +5,13 @@

 from llms_from_scratch.ch04 import generate_text_simple
 from llms_from_scratch.qwen3 import (
-    compute_rope_params,
     apply_rope,
+    compute_rope_params,
+    load_weights_into_qwen,
     QWEN_CONFIG_06_B,
-    RMSNorm,
     Qwen3Model,
-    Qwen3Tokenizer
+    Qwen3Tokenizer,
+    RMSNorm,
 )
 from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
 from llms_from_scratch.kv_cache.utils import KVCache
@@ -87,6 +88,7 @@ def dummy_cfg_moe(dummy_cfg_base):
     return cfg


+@torch.inference_mode()
 def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input):
     torch.manual_seed(123)
     model = Qwen3Model(dummy_cfg_base)
@@ -95,6 +97,7 @@ def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input):
         f"Expected shape (1, seq_len, vocab_size), got {out.shape}"


+@torch.inference_mode()
 def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
     torch.manual_seed(123)
     model = Qwen3Model(dummy_cfg_moe)
@@ -105,6 +108,7 @@ def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
         "Expected MoEFeedForward in at least one transformer block"


+@torch.inference_mode()
 @pytest.mark.parametrize("cfg_name", ["dummy_cfg_base", "dummy_cfg_moe"])
 def test_qwen3_kvcache_equivalence(cfg_name, request):
     cfg = request.getfixturevalue(cfg_name)
@@ -438,3 +442,51 @@ def test_tokenizer_equivalence():
     expected_pad_token = "<|endoftext|>"
     assert tokenizer.decode([tokenizer.eos_token_id]) == expected_eos_token
     assert tokenizer.decode([tokenizer.pad_token_id]) == expected_pad_token
+
+
+@torch.inference_mode()
+@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
+def test_qwen3_base_equivalence_with_transformers():
+
+    from transformers.models.qwen3 import Qwen3Config, Qwen3ForCausalLM
+
+    # Tiny config so the test is fast
+    cfg = {
+        "vocab_size": 257,
+        "context_length": 8,
+        "emb_dim": 32,
+        "n_heads": 4,
+        "n_layers": 2,
+        "hidden_dim": 64,
+        "head_dim": 8,
+        "qk_norm": True,
+        "n_kv_groups": 2,
+        "rope_base": 1_000_000.0,
+        "dtype": torch.float32,
+    }
+    model = Qwen3Model(cfg)
+
+    hf_cfg = Qwen3Config(
+        vocab_size=cfg["vocab_size"],
+        max_position_embeddings=cfg["context_length"],
+        hidden_size=cfg["emb_dim"],
+        num_attention_heads=cfg["n_heads"],
+        num_hidden_layers=cfg["n_layers"],
+        intermediate_size=cfg["hidden_dim"],
+        head_dim=cfg["head_dim"],
+        num_key_value_heads=cfg["n_kv_groups"],
+        rope_theta=cfg["rope_base"],
+        tie_word_embeddings=False,
+        attn_implementation="eager",
+        torch_dtype=torch.float32,
+    )
+    hf_model = Qwen3ForCausalLM(hf_cfg)
+
+    hf_state = hf_model.state_dict()
+    param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
+    load_weights_into_qwen(model, param_config, hf_state)
+
+    x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
+    ours_logits = model(x)
+    theirs_logits = hf_model(x).logits
+    torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
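
Note (outside the patch): the change decorates each forward-pass test with
torch.inference_mode(), which runs the test body with autograd disabled. A
minimal standalone sketch of that behavior, assuming only a stock PyTorch
install (the forward_only helper below is illustrative, not part of the patch):

    import torch

    @torch.inference_mode()
    def forward_only(model, x):
        # Autograd is disabled inside this call: no computation graph is
        # recorded, which saves time and memory for forward-only checks.
        return model(x)

    model = torch.nn.Linear(8, 2)
    out = forward_only(model, torch.randn(1, 8))
    print(out.requires_grad)  # False: the output is an inference tensor

Tensors produced under inference mode cannot participate in autograd later,
which is the contract a forward-only equivalence test wants: if a test
accidentally tried to backpropagate, it would fail loudly instead of silently
building a graph.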