
Add more sophisticated Qwen3 tokenizer (#729)

Sebastian Raschka, 4 months ago
Parent
Current commit
21c41721cc

+ 0 - 15
ch05/11_qwen3/standalone-qwen3.ipynb

@@ -487,21 +487,6 @@
     "        \"dtype\": torch.bfloat16,\n",
     "    } \n",
     "\n",
-    "elif CHOOSE_MODEL == \"8B\":\n",
-    "    QWEN3_CONFIG = {\n",
-    "        \"vocab_size\": 151_936,\n",
-    "        \"context_length\": 40_960,\n",
-    "        \"emb_dim\": 4096,                 # 60% larger than above\n",
-    "        \"n_heads\": 32,\n",
-    "        \"n_layers\": 36,                  # 26% larger than above\n",
-    "        \"hidden_dim\": 12288,\n",
-    "        \"head_dim\": 128,\n",
-    "        \"qk_norm\": True,\n",
-    "        \"n_kv_groups\": 8,\n",
-    "        \"rope_base\": 1_000_000.0,\n",
-    "        \"dtype\": torch.bfloat16,\n",
-    "    } \n",
-    "\n",
     "elif CHOOSE_MODEL == \"14B\":\n",
     "    QWEN3_CONFIG = {\n",
     "        \"vocab_size\": 151_936,\n",

+ 1 - 1
pkg/llms_from_scratch/llama3.py

@@ -64,7 +64,7 @@ class Llama3Model(nn.Module):
         self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
         self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
 
-        # Reusuable utilities
+        # Reusable utilities
         cos, sin = compute_rope_params(
             head_dim=cfg["emb_dim"] // cfg["n_heads"],
             theta_base=cfg["rope_base"],
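
The `compute_rope_params` call in the surrounding context precomputes the RoPE cos/sin tables once so every attention layer can reuse them. A minimal sketch of how such a precomputation typically looks; `rope_params_sketch` is a hypothetical stand-in, not the repository's `compute_rope_params`:

    import torch

    def rope_params_sketch(head_dim, theta_base=10_000.0, context_length=4096):
        # One inverse frequency per pair of head dimensions
        inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        positions = torch.arange(context_length, dtype=torch.float32)
        angles = positions[:, None] * inv_freq[None, :]   # (context_length, head_dim // 2)
        angles = torch.cat([angles, angles], dim=-1)       # (context_length, head_dim)
        return torch.cos(angles), torch.sin(angles)

    # Mirrors the call pattern above: per-head dimension plus the model's rope_base
    cos, sin = rope_params_sketch(head_dim=128, theta_base=1_000_000.0, context_length=8192)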

+ 64 - 38
pkg/llms_from_scratch/qwen3.py

@@ -5,6 +5,7 @@
 
 import os
 import json
+import re
 import urllib.request
 from pathlib import Path
 
@@ -115,7 +116,7 @@ class Qwen3Model(nn.Module):
         self.final_norm = RMSNorm(cfg["emb_dim"])
         self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
 
-        # Reusuable utilities
+        # Reusable utilities
         if cfg["head_dim"] is None:
             head_dim = cfg["emb_dim"] // cfg["n_heads"]
         else:
@@ -408,52 +409,77 @@ def load_weights_into_qwen(model, param_config, params):
     model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
 
 
-class Qwen3Tokenizer():
-    def __init__(self, tokenizer_file_path="tokenizer.json",
-                 repo_id=None, apply_chat_template=True,
-                 add_generation_prompt=False, add_thinking=False):
+class Qwen3Tokenizer:
+    _SPECIALS = [
+        "<|endoftext|>",
+        "<|im_start|>", "<|im_end|>",
+        "<|object_ref_start|>", "<|object_ref_end|>",
+        "<|box_start|>", "<|box_end|>",
+        "<|quad_start|>", "<|quad_end|>",
+        "<|vision_start|>", "<|vision_end|>",
+        "<|vision_pad|>", "<|image_pad|>", "<|video_pad|>",
+    ]
+    _SPLIT_RE = re.compile(r"(<\|[^>]+?\|>)")
+
+    def __init__(self, tokenizer_file_path="tokenizer.json", repo_id=None,
+                 apply_chat_template=True, add_generation_prompt=False, add_thinking=False):
         from tokenizers import Tokenizer
-        self.tokenizer_file_path = tokenizer_file_path
+
         self.apply_chat_template = apply_chat_template
         self.add_generation_prompt = add_generation_prompt
         self.add_thinking = add_thinking
 
-        tokenizer_file_path_obj = Path(tokenizer_file_path)
-        if not tokenizer_file_path_obj.is_file() and repo_id is not None:
-            _ = download_from_huggingface(
+        tok_file = Path(tokenizer_file_path)
+        if not tok_file.is_file() and repo_id:
+            download_from_huggingface(
                 repo_id=repo_id,
-                filename=str(tokenizer_file_path_obj.name),
-                local_dir=str(tokenizer_file_path_obj.parent.name)
-            )
-        self.tokenizer = Tokenizer.from_file(tokenizer_file_path)
-
-    def encode(self, prompt):
-        if self.apply_chat_template:
-            messages = [{"role": "user", "content": prompt}]
-            formatted_prompt = self.format_qwen_chat(
-                messages,
-                add_generation_prompt=self.add_generation_prompt,
-                add_thinking=self.add_thinking
+                filename=tok_file.name,
+                local_dir=str(tok_file.parent),
             )
+        self._tok = Tokenizer.from_file(str(tok_file))
+        self._special_to_id = {t: self._tok.token_to_id(t) for t in self._SPECIALS}
+
+        self.pad_token_id = self._special_to_id.get("<|endoftext|>")
+        self.eos_token_id = self.pad_token_id
+
+        if repo_id and "Base" not in repo_id:
+            eos_token = "<|im_end|>"
         else:
-            formatted_prompt = prompt
-        return self.tokenizer.encode(formatted_prompt).ids
-
-    def decode(self, token_ids):
-        return self.tokenizer.decode(token_ids, skip_special_tokens=False)
-
-    @staticmethod
-    def format_qwen_chat(messages, add_generation_prompt=False, add_thinking=False):
-        prompt = ""
-        for msg in messages:
-            prompt += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
-        if add_generation_prompt:
-            prompt += "<|im_start|>assistant"
-            if add_thinking:
-                prompt += "\n"  # no <think> tags
+            eos_token = "<|endoftext|>"
+        if eos_token in self._special_to_id:
+            self.eos_token_id = self._special_to_id[eos_token]
+
+    def encode(self, text, chat_wrapped=None):
+        if chat_wrapped is None:
+            chat_wrapped = self.apply_chat_template
+
+        stripped = text.strip()
+        if stripped in self._special_to_id and "\n" not in stripped:
+            return [self._special_to_id[stripped]]
+
+        if chat_wrapped:
+            text = self._wrap_chat(text)
+
+        ids = []
+        for part in filter(None, self._SPLIT_RE.split(text)):
+            if part in self._special_to_id:
+                ids.append(self._special_to_id[part])
+            else:
+                ids.extend(self._tok.encode(part).ids)
+        return ids
+
+    def decode(self, ids):
+        return self._tok.decode(ids, skip_special_tokens=False)
+
+    def _wrap_chat(self, user_msg):
+        s = f"<|im_start|>user\n{user_msg}<|im_end|>\n"
+        if self.add_generation_prompt:
+            s += "<|im_start|>assistant"
+            if self.add_thinking:
+                s += "\n"
             else:
-                prompt += "\n<think>\n\n</think>\n\n"
-        return prompt
+                s += "\n<think>\n\n</think>\n\n"
+        return s
 
 
 def download_from_huggingface(repo_id, filename, local_dir, revision="main"):
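
A minimal usage sketch of the new `Qwen3Tokenizer`, assuming `Qwen3-0.6B/tokenizer.json` already exists locally or can be fetched from the `Qwen/Qwen3-0.6B` Hugging Face repo (the same identifiers used in the tests below); the download path requires network access:

    from llms_from_scratch.qwen3 import Qwen3Tokenizer

    tokenizer = Qwen3Tokenizer(
        tokenizer_file_path="Qwen3-0.6B/tokenizer.json",
        repo_id="Qwen/Qwen3-0.6B",       # used to download tokenizer.json if it is missing
        apply_chat_template=True,
        add_generation_prompt=True,
        add_thinking=False,              # appends empty <think></think> tags to the assistant turn
    )

    prompt = "Give me a short introduction to large language models."
    ids = tokenizer.encode(prompt)       # wrapped as <|im_start|>user ... <|im_end|> ... before encoding
    text = tokenizer.decode(ids)         # round-trips with special tokens kept visible

    # Special tokens are looked up directly instead of being re-tokenized,
    # and <|endoftext|> doubles as the pad token:
    assert tokenizer.encode("<|endoftext|>") == [tokenizer.pad_token_id]

Since `repo_id` does not contain "Base", `eos_token_id` resolves to `<|im_end|>` rather than `<|endoftext|>`, matching the reasoning-model convention that the tokenizer-equivalence tests below check against the Hugging Face reference.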

+ 80 - 4
pkg/llms_from_scratch/tests/test_qwen3.py

@@ -15,6 +15,8 @@ from llms_from_scratch.qwen3 import (
 from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
 from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached
 
+# from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model as Qwen3ModelKVBatched
+# from llms_from_scratch.kv_cache_batched.generate import generate_text_simple as generate_text_simple_batched
 
 import importlib
 import pytest
@@ -113,7 +115,7 @@ def qwen3_weights_path(tmp_path_factory):
 
 
 @pytest.mark.parametrize("ModelClass", [Qwen3Model, Qwen3ModelKV])
-@pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
+@pytest.mark.parametrize("generate_fn", [generate_text_simple])
 def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
 
     torch.manual_seed(123)
@@ -137,7 +139,7 @@ def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
     print("Encoded input text:", input_token_ids)
     print("encoded_tensor.shape:", input_token_ids.shape)
 
-    out = generate_text_simple(
+    out = generate_fn(
         model=model,
         idx=input_token_ids,
         max_new_tokens=5,
@@ -152,6 +154,47 @@ def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
     assert torch.equal(expect, out)
 
 
+def test_model_KV_noKV(qwen3_weights_path):
+
+    torch.manual_seed(123)
+    model_KV = Qwen3ModelKV(QWEN_CONFIG_06_B)
+    model_KV.load_state_dict(torch.load(qwen3_weights_path))
+    model_KV.eval()
+
+    tokenizer = Qwen3Tokenizer(
+        tokenizer_file_path="tokenizer-base.json",
+        repo_id="rasbt/qwen3-from-scratch",
+        add_generation_prompt=False,
+        add_thinking=False
+    )
+
+    prompt = "Give me a short introduction to large language models."
+    input_token_ids = tokenizer.encode(prompt)
+    input_token_ids = torch.tensor([input_token_ids])
+
+    out_KV = generate_text_simple_cached(
+        model=model_KV,
+        idx=input_token_ids,
+        max_new_tokens=5,
+        context_size=QWEN_CONFIG_06_B["context_length"]
+    )
+    del model_KV
+
+    torch.manual_seed(123)
+    model_noKV = Qwen3Model(QWEN_CONFIG_06_B)
+    model_noKV.load_state_dict(torch.load(qwen3_weights_path))
+    model_noKV.eval()
+
+    out_noKV = generate_text_simple(
+        model=model_noKV,
+        idx=input_token_ids,
+        max_new_tokens=5,
+        context_size=QWEN_CONFIG_06_B["context_length"]
+    )
+
+    assert torch.equal(out_noKV, out_KV)
+
+
 def test_rmsnorm_equivalence():
     torch.manual_seed(42)
 
@@ -177,13 +220,16 @@ def test_rmsnorm_equivalence():
 @pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
 def test_tokenizer_equivalence():
     from transformers import AutoTokenizer
-    repo_id = "Qwen/Qwen3-0.6B"
-    tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
+
     prompt = "Give me a short introduction to large language models."
     messages = [
         {"role": "user", "content": prompt},
     ]
 
+    # Reasoning model tokenizer
+    repo_id = "Qwen/Qwen3-0.6B"
+    tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
+
     for states in ((True, True), (False, False)):
         tokenizer = Qwen3Tokenizer(
             tokenizer_file_path="Qwen3-0.6B/tokenizer.json",
@@ -203,3 +249,33 @@ def test_tokenizer_equivalence():
         output_text = tokenizer.decode(input_token_ids)
         out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
         assert output_text == out_text_ref, states
+
+        assert tokenizer_ref.eos_token_id == tokenizer.eos_token_id
+        assert tokenizer_ref.pad_token_id == tokenizer.pad_token_id
+
+    # Base model tokenizer
+    repo_id = "Qwen/Qwen3-0.6B-Base"
+    tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
+
+    for states in ((True, True), (False, False)):
+        tokenizer = Qwen3Tokenizer(
+            tokenizer_file_path="Qwen3-0.6B-Base/tokenizer.json",
+            repo_id=repo_id,
+            add_generation_prompt=states[0],
+            add_thinking=states[1]
+        )
+        input_token_ids = tokenizer.encode(prompt)
+        input_token_ids_ref = tokenizer_ref.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=states[0],
+            enable_thinking=states[1],
+        )
+        assert input_token_ids == input_token_ids_ref, states
+
+        output_text = tokenizer.decode(input_token_ids)
+        out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
+        assert output_text == out_text_ref, states
+
+        assert tokenizer_ref.eos_token_id == tokenizer.eos_token_id
+        assert tokenizer_ref.pad_token_id == tokenizer.pad_token_id
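
The new `test_model_KV_noKV` above relies on both generation paths being greedy, so the KV-cached and uncached models must produce identical continuations. A minimal sketch of such a greedy loop, written against a plain (uncached) forward pass; this illustrates the idea and is not the repository's `generate_text_simple`:

    import torch

    def greedy_decode_sketch(model, idx, max_new_tokens, context_size):
        # idx: (batch, seq_len) token IDs; returns the prompt plus greedily chosen tokens
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -context_size:]                 # crop to the supported context window
            with torch.no_grad():
                logits = model(idx_cond)                      # (batch, seq_len, vocab_size)
            next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            idx = torch.cat([idx, next_id], dim=1)            # append the most likely next token
        return idx

A KV-cached variant feeds only the newest token per step and reuses the stored keys and values; the equality assertion in the test verifies that this shortcut changes nothing about the greedy output.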