@@ -15,6 +15,8 @@ from llms_from_scratch.qwen3 import (
 from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
 from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached
+# from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model as Qwen3ModelKVBatched
+# from llms_from_scratch.kv_cache_batched.generate import generate_text_simple as generate_text_simple_batched

 import importlib
 import pytest
@@ -113,7 +115,7 @@ def qwen3_weights_path(tmp_path_factory):


 @pytest.mark.parametrize("ModelClass", [Qwen3Model, Qwen3ModelKV])
-@pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
+@pytest.mark.parametrize("generate_fn", [generate_text_simple])
 def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
     torch.manual_seed(123)
@@ -137,7 +139,7 @@ def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
     print("Encoded input text:", input_token_ids)
     print("encoded_tensor.shape:", input_token_ids.shape)

-    out = generate_text_simple(
+    out = generate_fn(
         model=model,
         idx=input_token_ids,
         max_new_tokens=5,
@@ -152,6 +154,47 @@ def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
     assert torch.equal(expect, out)


+def test_model_KV_noKV(qwen3_weights_path):
+
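+    # Both paths load the same checkpoint and prompt; the KV-cache model
+    # should reproduce the non-cached model's output tokens exactly.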
+    torch.manual_seed(123)
+    model_KV = Qwen3ModelKV(QWEN_CONFIG_06_B)
+    model_KV.load_state_dict(torch.load(qwen3_weights_path))
+    model_KV.eval()
+
+    tokenizer = Qwen3Tokenizer(
+        tokenizer_file_path="tokenizer-base.json",
+        repo_id="rasbt/qwen3-from-scratch",
+        add_generation_prompt=False,
+        add_thinking=False
+    )
+
+    prompt = "Give me a short introduction to large language models."
+    input_token_ids = tokenizer.encode(prompt)
+    input_token_ids = torch.tensor([input_token_ids])
+
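+    # Generate with the KV-cache implementation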
+    out_KV = generate_text_simple_cached(
+        model=model_KV,
+        idx=input_token_ids,
+        max_new_tokens=5,
+        context_size=QWEN_CONFIG_06_B["context_length"]
+    )
+    del model_KV
+
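+    # Generate the same continuation with the regular (non-cached) model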
+    torch.manual_seed(123)
+    model_noKV = Qwen3Model(QWEN_CONFIG_06_B)
+    model_noKV.load_state_dict(torch.load(qwen3_weights_path))
+    model_noKV.eval()
+
+    out_noKV = generate_text_simple(
+        model=model_noKV,
+        idx=input_token_ids,
+        max_new_tokens=5,
+        context_size=QWEN_CONFIG_06_B["context_length"]
+    )
+
+    assert torch.equal(out_noKV, out_KV)
+
+
 def test_rmsnorm_equivalence():
     torch.manual_seed(42)
@@ -177,13 +220,16 @@ def test_rmsnorm_equivalence():

 @pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
 def test_tokenizer_equivalence():
     from transformers import AutoTokenizer
-    repo_id = "Qwen/Qwen3-0.6B"
-    tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
+
     prompt = "Give me a short introduction to large language models."
     messages = [
         {"role": "user", "content": prompt},
     ]
+    # Reasoning model tokenizer
+    repo_id = "Qwen/Qwen3-0.6B"
+    tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
+
     for states in ((True, True), (False, False)):
         tokenizer = Qwen3Tokenizer(
             tokenizer_file_path="Qwen3-0.6B/tokenizer.json",
@@ -203,3 +249,33 @@ def test_tokenizer_equivalence():
         output_text = tokenizer.decode(input_token_ids)
         out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
         assert output_text == out_text_ref, states
+
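+        # Special-token IDs should also match the reference tokenizer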
+        assert tokenizer_ref.eos_token_id == tokenizer.eos_token_id
+        assert tokenizer_ref.pad_token_id == tokenizer.pad_token_id
+
+    # Base model tokenizer
+    repo_id = "Qwen/Qwen3-0.6B-Base"
+    tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
+
+    for states in ((True, True), (False, False)):
+        tokenizer = Qwen3Tokenizer(
+            tokenizer_file_path="Qwen3-0.6B-Base/tokenizer.json",
+            repo_id=repo_id,
+            add_generation_prompt=states[0],
+            add_thinking=states[1]
+        )
+        input_token_ids = tokenizer.encode(prompt)
+        input_token_ids_ref = tokenizer_ref.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=states[0],
+            enable_thinking=states[1],
+        )
+        assert input_token_ids == input_token_ids_ref, states
+
+        output_text = tokenizer.decode(input_token_ids)
+        out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
+        assert output_text == out_text_ref, states
+
+        assert tokenizer_ref.eos_token_id == tokenizer.eos_token_id
+        assert tokenizer_ref.pad_token_id == tokenizer.pad_token_id