
Update Qwen3 tokenizer test (#727)

* Update Qwen3 tokenizer test

* add tokenizers to dev dependencies

Sebastian Raschka 4 months ago
Parent
Commit
9cf64170ed
3 changed files with 20 additions and 12 deletions
  1. .github/workflows/check-links.yml (+1 -0)
  2. pkg/llms_from_scratch/tests/test_qwen3.py (+18 -12)
  3. pyproject.toml (+1 -0)

+ 1 - 0
.github/workflows/check-links.yml

@@ -23,6 +23,7 @@ jobs:
     - name: Install dependencies
       run: |
         curl -LsSf https://astral.sh/uv/install.sh | sh
+        uv sync --dev
         uv add pytest-ruff pytest-check-links
 
     - name: Check links
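
Note: the added uv sync --dev step installs the project's dev dependencies before the pytest-based link check runs, so that the tokenizers package added to pyproject.toml below is available when the test suite is collected.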

+ 18 - 12
pkg/llms_from_scratch/tests/test_qwen3.py

@@ -18,7 +18,6 @@ from llms_from_scratch.kv_cache.generate import generate_text_simple as generate
 
 import importlib
 import pytest
-import tiktoken
 import torch
 import torch.nn as nn
 
@@ -102,8 +101,8 @@ def test_rope():
 
 @pytest.fixture(scope="session")
 def qwen3_weights_path(tmp_path_factory):
-    """Creates and saves a deterministic Llama3 model for testing."""
-    path = tmp_path_factory.mktemp("models") / "llama3_test_weights.pt"
+    """Creates and saves a deterministic model for testing."""
+    path = tmp_path_factory.mktemp("models") / "qwen3_test_weights.pt"
 
     if not path.exists():
         torch.manual_seed(123)
@@ -122,26 +121,33 @@ def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
     model.load_state_dict(torch.load(qwen3_weights_path))
     model.eval()
 
-    start_context = "Llamas eat"
+    tokenizer = Qwen3Tokenizer(
+        tokenizer_file_path="tokenizer-base.json",
+        repo_id="rasbt/qwen3-from-scratch",
+        add_generation_prompt=False,
+        add_thinking=False
+    )
 
-    tokenizer = tiktoken.get_encoding("gpt2")
-    encoded = tokenizer.encode(start_context)
-    encoded_tensor = torch.tensor(encoded).unsqueeze(0)
+    prompt = "Give me a short introduction to large language models."
+    input_token_ids = tokenizer.encode(prompt)
+    input_token_ids = torch.tensor([input_token_ids])
 
     print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
-    print("\nInput text:", start_context)
-    print("Encoded input text:", encoded)
-    print("encoded_tensor.shape:", encoded_tensor.shape)
+    print("\nInput text:", prompt)
+    print("Encoded input text:", input_token_ids)
+    print("encoded_tensor.shape:", input_token_ids.shape)
 
     out = generate_text_simple(
         model=model,
-        idx=encoded_tensor,
+        idx=input_token_ids,
         max_new_tokens=5,
         context_size=QWEN_CONFIG_06_B["context_length"]
     )
     print("Encoded output text:", out)
     expect = torch.tensor([
-        [43, 2543, 292, 4483, 115206, 459, 43010, 104223, 55553]
+        [151644, 872, 198, 35127, 752, 264, 2805, 16800, 311,
+         3460, 4128,  4119, 13, 151645, 198, 112120, 83942, 60483,
+         102652, 7414]
     ])
     assert torch.equal(expect, out)
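
For reference, a minimal standalone sketch of the updated tokenizer usage, mirroring the test above. The import path for Qwen3Tokenizer is assumed from the package layout; the constructor arguments are the ones used in the test:

import torch
from llms_from_scratch.qwen3 import Qwen3Tokenizer  # import path assumed

tokenizer = Qwen3Tokenizer(
    tokenizer_file_path="tokenizer-base.json",
    repo_id="rasbt/qwen3-from-scratch",  # source for tokenizer-base.json if missing locally
    add_generation_prompt=False,
    add_thinking=False,
)

prompt = "Give me a short introduction to large language models."
input_token_ids = torch.tensor([tokenizer.encode(prompt)])
print(input_token_ids.shape)  # torch.Size([1, 15]): 20 expected IDs minus 5 generated tokens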
 

+ 1 - 0
pyproject.toml

@@ -29,6 +29,7 @@ dev = [
     "build>=1.2.2.post1",
     "llms-from-scratch",
     "twine>=6.1.0",
+    "tokenizers>=0.21.1",
 ]
 
 [tool.ruff]
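
The new dev dependency is the Hugging Face tokenizers package, which the updated test depends on. As a minimal sketch, the same tokenizer file can be loaded with it directly (whether Qwen3Tokenizer wraps it exactly this way is an assumption):

from tokenizers import Tokenizer  # the package pinned above

tok = Tokenizer.from_file("tokenizer-base.json")  # the file the test downloads
ids = tok.encode("Give me a short introduction to large language models.").ids
print(ids)  # raw IDs only; the chat-template specials in the test (e.g. 151644) come from Qwen3Tokenizer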