@@ -5,6 +5,7 @@
 
 import os
 import json
+import re
 import urllib.request
 from pathlib import Path
 
@@ -115,7 +116,7 @@ class Qwen3Model(nn.Module):
         self.final_norm = RMSNorm(cfg["emb_dim"])
         self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
 
-        # Reusuable utilities
+        # Reusable utilities
         if cfg["head_dim"] is None:
             head_dim = cfg["emb_dim"] // cfg["n_heads"]
         else:
@@ -408,52 +409,77 @@ def load_weights_into_qwen(model, param_config, params):
     model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
 
 
-class Qwen3Tokenizer():
-    def __init__(self, tokenizer_file_path="tokenizer.json",
-                 repo_id=None, apply_chat_template=True,
-                 add_generation_prompt=False, add_thinking=False):
+class Qwen3Tokenizer:
+    _SPECIALS = [
+        "<|endoftext|>",
+        "<|im_start|>", "<|im_end|>",
+        "<|object_ref_start|>", "<|object_ref_end|>",
+        "<|box_start|>", "<|box_end|>",
+        "<|quad_start|>", "<|quad_end|>",
+        "<|vision_start|>", "<|vision_end|>",
+        "<|vision_pad|>", "<|image_pad|>", "<|video_pad|>",
+    ]
+    _SPLIT_RE = re.compile(r"(<\|[^>]+?\|>)")
+
+    def __init__(self, tokenizer_file_path="tokenizer.json", repo_id=None,
+                 apply_chat_template=True, add_generation_prompt=False, add_thinking=False):
         from tokenizers import Tokenizer
-        self.tokenizer_file_path = tokenizer_file_path
+
         self.apply_chat_template = apply_chat_template
         self.add_generation_prompt = add_generation_prompt
         self.add_thinking = add_thinking
 
-        tokenizer_file_path_obj = Path(tokenizer_file_path)
-        if not tokenizer_file_path_obj.is_file() and repo_id is not None:
-            _ = download_from_huggingface(
+        tok_file = Path(tokenizer_file_path)
+        if not tok_file.is_file() and repo_id:
+            download_from_huggingface(
                 repo_id=repo_id,
-                filename=str(tokenizer_file_path_obj.name),
-                local_dir=str(tokenizer_file_path_obj.parent.name)
-            )
-        self.tokenizer = Tokenizer.from_file(tokenizer_file_path)
-
-    def encode(self, prompt):
-        if self.apply_chat_template:
-            messages = [{"role": "user", "content": prompt}]
-            formatted_prompt = self.format_qwen_chat(
-                messages,
-                add_generation_prompt=self.add_generation_prompt,
-                add_thinking=self.add_thinking
+                filename=tok_file.name,
+                local_dir=str(tok_file.parent),
             )
+        self._tok = Tokenizer.from_file(str(tok_file))
+        self._special_to_id = {t: self._tok.token_to_id(t) for t in self._SPECIALS}
+
+        self.pad_token_id = self._special_to_id.get("<|endoftext|>")
+        self.eos_token_id = self.pad_token_id
+
+        if repo_id and "Base" not in repo_id:
+            eos_token = "<|im_end|>"
         else:
-            formatted_prompt = prompt
-        return self.tokenizer.encode(formatted_prompt).ids
-
-    def decode(self, token_ids):
-        return self.tokenizer.decode(token_ids, skip_special_tokens=False)
-
-    @staticmethod
-    def format_qwen_chat(messages, add_generation_prompt=False, add_thinking=False):
-        prompt = ""
-        for msg in messages:
-            prompt += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
-        if add_generation_prompt:
-            prompt += "<|im_start|>assistant"
-            if add_thinking:
-                prompt += "\n"  # no <think> tags
+            eos_token = "<|endoftext|>"
+        if eos_token in self._special_to_id:
+            self.eos_token_id = self._special_to_id[eos_token]
+
+    def encode(self, text, chat_wrapped=None):
+        if chat_wrapped is None:
+            chat_wrapped = self.apply_chat_template
+
+        stripped = text.strip()
+        if stripped in self._special_to_id and "\n" not in stripped:
+            return [self._special_to_id[stripped]]
+
+        if chat_wrapped:
+            text = self._wrap_chat(text)
+
+        ids = []
+        for part in filter(None, self._SPLIT_RE.split(text)):
+            if part in self._special_to_id:
+                ids.append(self._special_to_id[part])
+            else:
+                ids.extend(self._tok.encode(part).ids)
+        return ids
+
+    def decode(self, ids):
+        return self._tok.decode(ids, skip_special_tokens=False)
+
+    def _wrap_chat(self, user_msg):
+        s = f"<|im_start|>user\n{user_msg}<|im_end|>\n"
+        if self.add_generation_prompt:
+            s += "<|im_start|>assistant"
+            if self.add_thinking:
+                s += "\n"
             else:
-                prompt += "\n<think>\n\n</think>\n\n"
-        return prompt
+                s += "\n<think>\n\n</think>\n\n"
+        return s
 
 
 def download_from_huggingface(repo_id, filename, local_dir, revision="main"):
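
For reference, a minimal usage sketch of the refactored Qwen3Tokenizer (the local file path and Hugging Face repo id below are illustrative; tokenizer.json is assumed to be downloadable from that repo):

# Usage sketch: encode/decode round trip with the new special-token handling.
tokenizer = Qwen3Tokenizer(
    tokenizer_file_path="Qwen3-0.6B/tokenizer.json",
    repo_id="Qwen/Qwen3-0.6B",  # no "Base" in the repo id, so eos_token_id maps to <|im_end|>
    add_generation_prompt=True,
    add_thinking=True,
)

ids = tokenizer.encode("Give me a short introduction to large language models.")
# _wrap_chat() wrapped the prompt as "<|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n",
# and _SPLIT_RE mapped each <|...|> marker to a single special-token id.
text = tokenizer.decode(ids)

# A lone special token short-circuits to exactly one id:
assert tokenizer.encode("<|endoftext|>") == [tokenizer.pad_token_id]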