浏览代码

Qwen3 and Llama3 equivalency teests with HF transformers (#768)

* Qwen3 and Llama3 equivalency teests with HF transformers

* update
Sebastian Raschka 3 月之前
父节点
当前提交
b14325e56d

+ 1 - 1
.github/workflows/basic-tests-pixi.yml

@@ -28,7 +28,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-latest, macos-latest, windows-latest]
+        os: [ubuntu-latest, windows-latest]
 
     steps:
       - uses: actions/checkout@v4

+ 5 - 1
.gitignore

@@ -1,4 +1,3 @@
-
 # Configs and keys
 ch05/07_gpt_to_llama/config.json
 ch07/02_dataset-utilities/config.json
@@ -78,6 +77,11 @@ ch07/01_main-chapter-code/gpt2-medium355M-sft-standalone.pth
 ch07/01_main-chapter-code/Smalltestmodel-sft-standalone.pth
 ch07/01_main-chapter-code/gpt2/
 
+Qwen3-0.6B-Base/
+Qwen3-0.6B/
+tokenizer-base.json
+tokenizer.json
+
 # Datasets
 the-verdict.txt
 

+ 3 - 1
pkg/llms_from_scratch/README.md

@@ -132,7 +132,8 @@ For more information about KV caching, please see the [KV cache README](../../ch
 
 ```python
 from llms_from_scratch.llama3 import (
-    Llama3Model,
+		load_weights_into_llama,
+  	Llama3Model,
     Llama3ModelFast,
     Llama3Tokenizer,
     ChatFormat,
@@ -154,6 +155,7 @@ For more information about KV caching, please see the [KV cache README](../../ch
 
 ```python
 from llms_from_scratch.qwen3 import (
+    load_weights_into_qwen
     Qwen3Model,
     Qwen3Tokenizer,
 )

+ 74 - 0
pkg/llms_from_scratch/llama3.py

@@ -497,3 +497,77 @@ class Llama3ModelFast(nn.Module):
         x = self.final_norm(x)
         logits = self.out_head(x.to(self.cfg["dtype"]))
         return logits
+
+
+def assign(left, right, tensor_name="unknown"):
+    if left.shape != right.shape:
+        raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")
+
+    if isinstance(right, torch.Tensor):
+        return torch.nn.Parameter(right.clone().detach())
+    else:
+        return torch.nn.Parameter(torch.tensor(right))
+
+
+def load_weights_into_llama(model, param_config, params):
+    model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
+
+    for l in range(param_config["n_layers"]):
+
+        # Load attention weights
+        model.trf_blocks[l].att.W_query.weight = assign(
+            model.trf_blocks[l].att.W_query.weight,
+            params[f"model.layers.{l}.self_attn.q_proj.weight"],
+            f"model.layers.{l}.self_attn.q_proj.weight"
+        )
+        model.trf_blocks[l].att.W_key.weight = assign(
+            model.trf_blocks[l].att.W_key.weight,
+            params[f"model.layers.{l}.self_attn.k_proj.weight"],
+            f"model.layers.{l}.self_attn.k_proj.weight"
+        )
+        model.trf_blocks[l].att.W_value.weight = assign(
+            model.trf_blocks[l].att.W_value.weight,
+            params[f"model.layers.{l}.self_attn.v_proj.weight"],
+            f"model.layers.{l}.self_attn.v_proj.weight"
+        )
+        model.trf_blocks[l].att.out_proj.weight = assign(
+            model.trf_blocks[l].att.out_proj.weight,
+            params[f"model.layers.{l}.self_attn.o_proj.weight"],
+            f"model.layers.{l}.self_attn.o_proj.weight"
+        )
+        model.trf_blocks[l].norm1.weight = assign(
+            model.trf_blocks[l].norm1.weight,
+            params[f"model.layers.{l}.input_layernorm.weight"],
+            f"model.layers.{l}.input_layernorm.weight"
+        )
+
+        # Load FeedForward weights
+        model.trf_blocks[l].ff.fc1.weight = assign(
+            model.trf_blocks[l].ff.fc1.weight,
+            params[f"model.layers.{l}.mlp.gate_proj.weight"],
+            f"model.layers.{l}.mlp.gate_proj.weight"
+        )
+        model.trf_blocks[l].ff.fc2.weight = assign(
+            model.trf_blocks[l].ff.fc2.weight,
+            params[f"model.layers.{l}.mlp.up_proj.weight"],
+            f"model.layers.{l}.mlp.up_proj.weight"
+        )
+        model.trf_blocks[l].ff.fc3.weight = assign(
+            model.trf_blocks[l].ff.fc3.weight,
+            params[f"model.layers.{l}.mlp.down_proj.weight"],
+            f"model.layers.{l}.mlp.down_proj.weight"
+        )
+        model.trf_blocks[l].norm2.weight = assign(
+            model.trf_blocks[l].norm2.weight,
+            params[f"model.layers.{l}.post_attention_layernorm.weight"],
+            f"model.layers.{l}.post_attention_layernorm.weight"
+        )
+
+    # Load output layer weights
+    model.final_norm.weight = assign(model.final_norm.weight, params["model.norm.weight"], "model.norm.weight")
+
+    if "lm_head.weight" in params.keys():
+        model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight")
+    else:
+        model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
+        print("Model uses weight tying.")

+ 61 - 2
pkg/llms_from_scratch/tests/test_llama3.py

@@ -5,11 +5,12 @@
 
 from llms_from_scratch.ch04 import generate_text_simple
 from llms_from_scratch.llama3 import (
-    compute_rope_params,
     apply_rope,
-    LLAMA32_CONFIG_1B,
+    compute_rope_params,
     GroupedQueryAttention,
     GroupedQueryAttentionFast,
+    load_weights_into_llama,
+    LLAMA32_CONFIG_1B,
     Llama3Model,
 )
 from llms_from_scratch.kv_cache.llama3 import Llama3Model as Llama3ModelKV
@@ -246,3 +247,61 @@ def test_rmsnorm_equivalence():
     out2 = lit_norm(x)
 
     torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)
+
+
+@torch.inference_mode()
+@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
+def test_llama3_base_equivalence_with_transformers():
+    from transformers.models.llama import LlamaConfig, LlamaForCausalLM
+    cfg = {
+        "vocab_size": 257,
+        "context_length": 8192,
+        "emb_dim": 32,
+        "n_heads": 4,
+        "n_layers": 2,
+        "hidden_dim": 64,
+        "n_kv_groups": 2,
+        "rope_base": 500_000.0,
+        "rope_freq": {
+            "factor": 32.0,
+            "low_freq_factor": 1.0,
+            "high_freq_factor": 4.0,
+            "original_context_length": 8192,
+        },
+        "dtype": torch.float32,
+    }
+
+    ours = Llama3Model(cfg)
+
+    hf_cfg = LlamaConfig(
+        vocab_size=cfg["vocab_size"],
+        hidden_size=cfg["emb_dim"],
+        num_attention_heads=cfg["n_heads"],
+        num_key_value_heads=cfg["n_kv_groups"],
+        num_hidden_layers=cfg["n_layers"],
+        intermediate_size=cfg["hidden_dim"],
+        max_position_embeddings=cfg["context_length"],
+        rms_norm_eps=1e-5,
+        attention_bias=False,
+        rope_theta=cfg["rope_base"],
+        tie_word_embeddings=False,
+        attn_implementation="eager",
+        torch_dtype=torch.float32,
+        rope_scaling={
+            "type": "llama3",
+            "factor": cfg["rope_freq"]["factor"],
+            "low_freq_factor": cfg["rope_freq"]["low_freq_factor"],
+            "high_freq_factor": cfg["rope_freq"]["high_freq_factor"],
+            "original_max_position_embeddings": cfg["rope_freq"]["original_context_length"],
+        },
+    )
+    theirs = LlamaForCausalLM(hf_cfg)
+
+    hf_state = theirs.state_dict()
+    load_weights_into_llama(ours, {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}, hf_state)
+
+    x = torch.randint(0, cfg["vocab_size"], (2, 8), dtype=torch.long)
+    ours_logits = ours(x)
+    theirs_logits = theirs(x).logits.to(ours_logits.dtype)
+
+    torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)

+ 55 - 3
pkg/llms_from_scratch/tests/test_qwen3.py

@@ -5,12 +5,13 @@
 
 from llms_from_scratch.ch04 import generate_text_simple
 from llms_from_scratch.qwen3 import (
-    compute_rope_params,
     apply_rope,
+    compute_rope_params,
+    load_weights_into_qwen,
     QWEN_CONFIG_06_B,
-    RMSNorm,
     Qwen3Model,
-    Qwen3Tokenizer
+    Qwen3Tokenizer,
+    RMSNorm,
 )
 from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
 from llms_from_scratch.kv_cache.utils import KVCache
@@ -87,6 +88,7 @@ def dummy_cfg_moe(dummy_cfg_base):
     return cfg
 
 
+@torch.inference_mode()
 def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input):
     torch.manual_seed(123)
     model = Qwen3Model(dummy_cfg_base)
@@ -95,6 +97,7 @@ def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input):
         f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
 
 
+@torch.inference_mode()
 def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
     torch.manual_seed(123)
     model = Qwen3Model(dummy_cfg_moe)
@@ -105,6 +108,7 @@ def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
         "Expected MoEFeedForward in at least one transformer block"
 
 
+@torch.inference_mode()
 @pytest.mark.parametrize("cfg_name", ["dummy_cfg_base", "dummy_cfg_moe"])
 def test_qwen3_kvcache_equivalence(cfg_name, request):
     cfg = request.getfixturevalue(cfg_name)
@@ -438,3 +442,51 @@ def test_tokenizer_equivalence():
         expected_pad_token = "<|endoftext|>"
         assert tokenizer.decode([tokenizer.eos_token_id]) == expected_eos_token
         assert tokenizer.decode([tokenizer.pad_token_id]) == expected_pad_token
+
+
+@torch.inference_mode()
+@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
+def test_qwen3_base_equivalence_with_transformers():
+
+    from transformers.models.qwen3 import Qwen3Config, Qwen3ForCausalLM
+
+    # Tiny config so the test is fast
+    cfg = {
+        "vocab_size": 257,
+        "context_length": 8,
+        "emb_dim": 32,
+        "n_heads": 4,
+        "n_layers": 2,
+        "hidden_dim": 64,
+        "head_dim": 8,
+        "qk_norm": True,
+        "n_kv_groups": 2,
+        "rope_base": 1_000_000.0,
+        "dtype": torch.float32,
+    }
+    model = Qwen3Model(cfg)
+
+    hf_cfg = Qwen3Config(
+        vocab_size=cfg["vocab_size"],
+        max_position_embeddings=cfg["context_length"],
+        hidden_size=cfg["emb_dim"],
+        num_attention_heads=cfg["n_heads"],
+        num_hidden_layers=cfg["n_layers"],
+        intermediate_size=cfg["hidden_dim"],
+        head_dim=cfg["head_dim"],
+        num_key_value_heads=cfg["n_kv_groups"],
+        rope_theta=cfg["rope_base"],
+        tie_word_embeddings=False,
+        attn_implementation="eager",
+        torch_dtype=torch.float32,
+    )
+    hf_model = Qwen3ForCausalLM(hf_cfg)
+
+    hf_state = hf_model.state_dict()
+    param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
+    load_weights_into_qwen(model, param_config, hf_state)
+
+    x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
+    ours_logits = model(x)
+    theirs_logits = hf_model(x).logits
+    torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)