
Handle other Qwen3 tokenizer settings (#716)

Sebastian Raschka 4 months ago
parent commit 0405b0c8e7
3 changed files with 16 additions and 25 deletions
  1. +15 -18  pkg/llms_from_scratch/qwen3.py
  2. +0 -6    pkg/llms_from_scratch/tests/test_qwen3.py
  3. +1 -1    pyproject.toml

+ 15 - 18
pkg/llms_from_scratch/qwen3.py

@@ -410,15 +410,11 @@ def load_weights_into_qwen(model, param_config, params):
 
 class Qwen3Tokenizer():
     def __init__(self, tokenizer_file_path="tokenizer.json",
-                 repo_id=None, add_generation_prompt=False, add_thinking=False):
+                 repo_id=None, apply_chat_template=True,
+                 add_generation_prompt=False, add_thinking=False):
         from tokenizers import Tokenizer
         self.tokenizer_file_path = tokenizer_file_path
-
-        if add_generation_prompt != add_thinking:
-            raise ValueError(
-                "Only add_generation_prompt==add_thinking settings are currently supported"
-            )
-
+        self.apply_chat_template = apply_chat_template
         self.add_generation_prompt = add_generation_prompt
         self.add_thinking = add_thinking
 
@@ -432,14 +428,15 @@ class Qwen3Tokenizer():
         self.tokenizer = Tokenizer.from_file(tokenizer_file_path)
 
     def encode(self, prompt):
-        messages = [
-            {"role": "user", "content": prompt}
-        ]
-        formatted_prompt = self.format_qwen_chat(
-            messages,
-            add_generation_prompt=self.add_generation_prompt,
-            add_thinking=self.add_thinking
-        )
+        if self.apply_chat_template:
+            messages = [{"role": "user", "content": prompt}]
+            formatted_prompt = self.format_qwen_chat(
+                messages,
+                add_generation_prompt=self.add_generation_prompt,
+                add_thinking=self.add_thinking
+            )
+        else:
+            formatted_prompt = prompt
         return self.tokenizer.encode(formatted_prompt).ids
 
     def decode(self, token_ids):
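
With the new apply_chat_template flag, the tokenizer can either wrap the prompt in the Qwen chat format or pass it through unchanged. A minimal usage sketch, assuming the package is importable as llms_from_scratch and a tokenizer.json is available locally (the repo id is illustrative):

```python
from llms_from_scratch.qwen3 import Qwen3Tokenizer  # assumed import path

# Chat-style encoding: the prompt is wrapped in <|im_start|>/<|im_end|> blocks.
chat_tok = Qwen3Tokenizer(
    tokenizer_file_path="tokenizer.json",
    repo_id="Qwen/Qwen3-0.6B",      # illustrative repo id
    apply_chat_template=True,
    add_generation_prompt=True,
    add_thinking=False,             # appends an empty <think>...</think> block
)
chat_ids = chat_tok.encode("Give me a short introduction to large language models.")

# Raw encoding: with apply_chat_template=False the prompt is tokenized as-is.
plain_tok = Qwen3Tokenizer(
    tokenizer_file_path="tokenizer.json",
    apply_chat_template=False,
)
plain_ids = plain_tok.encode("Give me a short introduction to large language models.")
assert plain_tok.decode(plain_ids).startswith("Give me")
```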
@@ -452,10 +449,10 @@ class Qwen3Tokenizer():
             prompt += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
         if add_generation_prompt:
             prompt += "<|im_start|>assistant"
-            if not add_thinking:
-                prompt += "<|think>\n\n<|/think>\n\n"
+            if add_thinking:
+                prompt += "\n"  # no <think> tags
             else:
-                prompt += "\n"
+                prompt += "\n<think>\n\n</think>\n\n"
         return prompt
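
For reference, a sketch of the strings the revised format_qwen_chat produces for a single user message "Hello" with add_generation_prompt=True (derived from the hunk above, with newlines written out explicitly):

```python
# add_thinking=True: no <think> tags are inserted, so the model is free to
# open its own thinking block.
with_thinking = (
    "<|im_start|>user\nHello<|im_end|>\n"
    "<|im_start|>assistant\n"
)

# add_thinking=False: an empty <think>...</think> block is appended, which
# effectively skips the thinking step.
without_thinking = (
    "<|im_start|>user\nHello<|im_end|>\n"
    "<|im_start|>assistant\n<think>\n\n</think>\n\n"
)
```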
 
 

+ 0 - 6
pkg/llms_from_scratch/tests/test_qwen3.py

@@ -117,12 +117,6 @@ def qwen3_weights_path(tmp_path_factory):
 @pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
 def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
 
-    # Skip incompatible combinations
-    if generate_fn is generate_text_simple and getattr(ModelClass, "reset_kv_cache", False):
-        return
-    if generate_fn is generate_text_simple_cached and not getattr(ModelClass, "reset_kv_cache", False):
-        return
-
     torch.manual_seed(123)
     model = ModelClass(QWEN_CONFIG_06_B)
     model.load_state_dict(torch.load(qwen3_weights_path))

+ 1 - 1
pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "llms-from-scratch"
-version = "1.0.15"
+version = "1.0.16"
 description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
 readme = "README.md"
 requires-python = ">=3.10"