# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

import os
import urllib.request
from pathlib import Path

import torch
import torch.nn as nn


# 0.6B model
QWEN_CONFIG_06_B = {
    "vocab_size": 151_936,  # Vocabulary size
    "context_length": 40_960,  # Context length that was used to train the model
    "emb_dim": 1024,  # Embedding dimension
    "n_heads": 16,  # Number of attention heads
    "n_layers": 28,  # Number of layers
    "hidden_dim": 3072,  # Size of the intermediate dimension in FeedForward
    "head_dim": 128,  # Size of the heads in GQA
    "qk_norm": True,  # Whether to normalize queries and keys in GQA
    "n_kv_groups": 8,  # Key-Value groups for grouped-query attention
    "rope_base": 1_000_000.0,  # The base in RoPE's "theta"
    "dtype": torch.bfloat16,  # Lower-precision dtype to reduce memory usage
}
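

# Illustrative sketch (not part of the original file): a quick look at the sizes
# implied by the config above, assuming the GQA projection layout used by
# `GroupedQueryAttention` further below. `_show_config_derived_sizes` is a
# hypothetical helper name; call it manually if useful.
def _show_config_derived_sizes(cfg=QWEN_CONFIG_06_B):
    head_dim = cfg["head_dim"] if cfg["head_dim"] is not None else cfg["emb_dim"] // cfg["n_heads"]
    q_dim = cfg["n_heads"] * head_dim       # query projection width: 16 * 128 = 2048
    kv_dim = cfg["n_kv_groups"] * head_dim  # key/value projection width: 8 * 128 = 1024
    print(f"head_dim={head_dim}, q_proj_out={q_dim}, kv_proj_out={kv_dim}, "
          f"heads_per_kv_group={cfg['n_heads'] // cfg['n_kv_groups']}")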


class Qwen3Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        # Main model parameters
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])

        self.trf_blocks = nn.ModuleList(  # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`
            [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )
        self.final_norm = RMSNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

        # Reusable utilities
        if cfg["head_dim"] is None:
            head_dim = cfg["emb_dim"] // cfg["n_heads"]
        else:
            head_dim = cfg["head_dim"]
        cos, sin = compute_rope_params(
            head_dim=head_dim,
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"]
        )
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.cfg = cfg

    def forward(self, in_idx):
        # Forward pass
        tok_embeds = self.tok_emb(in_idx)
        x = tok_embeds

        num_tokens = x.shape[1]
        mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)

        for block in self.trf_blocks:
            x = block(x, mask, self.cos, self.sin)
        x = self.final_norm(x)
        logits = self.out_head(x.to(self.cfg["dtype"]))
        return logits
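

# Illustrative sketch (not part of the original file): a tiny smoke test for
# Qwen3Model. The scaled-down config below is a made-up example with the same
# keys as QWEN_CONFIG_06_B so the forward pass runs quickly on CPU.
def _demo_tiny_forward():
    tiny_cfg = {
        "vocab_size": 100, "context_length": 32, "emb_dim": 64, "n_heads": 4,
        "n_layers": 2, "hidden_dim": 128, "head_dim": 16, "qk_norm": True,
        "n_kv_groups": 2, "rope_base": 10_000.0, "dtype": torch.float32,
    }
    model = Qwen3Model(tiny_cfg)
    token_ids = torch.randint(0, tiny_cfg["vocab_size"], (1, 8))  # (batch_size, num_tokens)
    logits = model(token_ids)
    print(logits.shape)  # expected: torch.Size([1, 8, 100])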


class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = GroupedQueryAttention(
            d_in=cfg["emb_dim"],
            num_heads=cfg["n_heads"],
            head_dim=cfg["head_dim"],
            num_kv_groups=cfg["n_kv_groups"],
            qk_norm=cfg["qk_norm"],
            dtype=cfg["dtype"]
        )
        self.ff = FeedForward(cfg)
        self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6)
        self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6)

    def forward(self, x, mask, cos, sin):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x, mask, cos, sin)  # Shape [batch_size, num_tokens, emb_size]
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = x + shortcut  # Add the original input back

        return x


class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)

    def forward(self, x):
        x_fc1 = self.fc1(x)
        x_fc2 = self.fc2(x)
        x = nn.functional.silu(x_fc1) * x_fc2
        return self.fc3(x)
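

# Illustrative sketch (not part of the original file): the FeedForward module above
# is a SwiGLU-style MLP -- fc1 acts as the gate projection, fc2 as the up
# projection, and fc3 projects back down to the embedding dimension.
def _demo_feedforward_shapes():
    ff = FeedForward({"emb_dim": 64, "hidden_dim": 128, "dtype": torch.float32})
    out = ff(torch.randn(1, 8, 64))
    print(out.shape)  # expected: torch.Size([1, 8, 64])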


class GroupedQueryAttention(nn.Module):
    def __init__(
        self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None
    ):
        super().__init__()
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups

        if head_dim is None:
            assert d_in % num_heads == 0, "`d_in` must be divisible by `num_heads` if `head_dim` is not set"
            head_dim = d_in // num_heads

        self.head_dim = head_dim
        self.d_out = num_heads * head_dim

        self.W_query = nn.Linear(d_in, self.d_out, bias=False, dtype=dtype)
        self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)

        if qk_norm:
            self.q_norm = RMSNorm(head_dim, eps=1e-6)
            self.k_norm = RMSNorm(head_dim, eps=1e-6)
        else:
            self.q_norm = self.k_norm = None

    def forward(self, x, mask, cos, sin):
        b, num_tokens, _ = x.shape

        # Apply projections
        queries = self.W_query(x)  # (b, num_tokens, num_heads * head_dim)
        keys = self.W_key(x)       # (b, num_tokens, num_kv_groups * head_dim)
        values = self.W_value(x)   # (b, num_tokens, num_kv_groups * head_dim)

        # Reshape
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)

        # Optional normalization
        if self.q_norm:
            queries = self.q_norm(queries)
        if self.k_norm:
            keys = self.k_norm(keys)

        # Apply RoPE
        queries = apply_rope(queries, cos, sin)
        keys = apply_rope(keys, cos, sin)

        # Expand K and V to match number of heads
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)

        # Attention
        attn_scores = queries @ keys.transpose(2, 3)
        attn_scores = attn_scores.masked_fill(mask, -torch.inf)
        attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)

        context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)
        return self.out_proj(context)
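

# Illustrative sketch (not part of the original file): how GroupedQueryAttention is
# called with a boolean causal mask and precomputed RoPE tables (compute_rope_params
# and apply_rope are defined further below in this file).
def _demo_gqa_shapes():
    b, num_tokens, emb_dim, head_dim = 1, 8, 64, 16
    att = GroupedQueryAttention(d_in=emb_dim, num_heads=4, num_kv_groups=2,
                                head_dim=head_dim, qk_norm=True)
    cos, sin = compute_rope_params(head_dim=head_dim, context_length=num_tokens)
    mask = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1)
    out = att(torch.randn(b, num_tokens, emb_dim), mask, cos, sin)
    print(out.shape)  # expected: torch.Size([1, 8, 64])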


def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))

    # Generate position indices
    positions = torch.arange(context_length, dtype=dtype)

    # Compute the angles
    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)

    # Expand angles to match the head_dim
    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)

    # Precompute sine and cosine
    cos = torch.cos(angles)
    sin = torch.sin(angles)
    return cos, sin


def apply_rope(x, cos, sin):
    # x: (batch_size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into first half and second half
    x1 = x[..., : head_dim // 2]  # First half
    x2 = x[..., head_dim // 2:]   # Second half

    # Adjust sin and cos shapes
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    # It's ok to use lower-precision after applying cos and sin rotation
    return x_rotated.to(dtype=x.dtype)
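

# Illustrative sketch (not part of the original file): RoPE rotates each
# (x1, x2) pair by a position-dependent angle, so it changes the directions of the
# query/key vectors but leaves their norms (approximately) unchanged.
def _demo_rope_preserves_norm():
    cos, sin = compute_rope_params(head_dim=16, context_length=32)
    x = torch.randn(1, 4, 8, 16)  # (batch_size, num_heads, seq_len, head_dim)
    x_rot = apply_rope(x, cos, sin)
    print(torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-5))  # expected: True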


class RMSNorm(nn.Module):
    def __init__(self, emb_dim, eps=1e-6, bias=False, qwen3_compatible=True):
        super().__init__()
        self.eps = eps
        self.qwen3_compatible = qwen3_compatible
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None

    def forward(self, x):
        input_dtype = x.dtype

        if self.qwen3_compatible:
            x = x.to(torch.float32)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        norm_x = x * torch.rsqrt(variance + self.eps)
        norm_x = norm_x * self.scale

        if self.shift is not None:
            norm_x = norm_x + self.shift

        return norm_x.to(input_dtype)
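

# Illustrative sketch (not part of the original file): with its default scale of
# ones, RMSNorm rescales each vector to (approximately) unit root-mean-square.
def _demo_rmsnorm_unit_rms():
    norm = RMSNorm(emb_dim=64)
    out = norm(torch.randn(2, 5, 64) * 10)
    rms = out.pow(2).mean(dim=-1).sqrt()
    print(torch.allclose(rms, torch.ones_like(rms), atol=1e-3))  # expected: True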


def load_weights_into_qwen(model, param_config, params):
    def assign(left, right, tensor_name="unknown"):
        if left.shape != right.shape:
            raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")
        return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))

    model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")

    for l in range(param_config["n_layers"]):
        block = model.trf_blocks[l]
        att = block.att

        # Q, K, V projections
        att.W_query.weight = assign(
            att.W_query.weight,
            params[f"model.layers.{l}.self_attn.q_proj.weight"],
            f"model.layers.{l}.self_attn.q_proj.weight"
        )
        att.W_key.weight = assign(
            att.W_key.weight,
            params[f"model.layers.{l}.self_attn.k_proj.weight"],
            f"model.layers.{l}.self_attn.k_proj.weight"
        )
        att.W_value.weight = assign(
            att.W_value.weight,
            params[f"model.layers.{l}.self_attn.v_proj.weight"],
            f"model.layers.{l}.self_attn.v_proj.weight"
        )

        # Output projection
        att.out_proj.weight = assign(
            att.out_proj.weight,
            params[f"model.layers.{l}.self_attn.o_proj.weight"],
            f"model.layers.{l}.self_attn.o_proj.weight"
        )

        # QK norms
        if hasattr(att, "q_norm") and att.q_norm is not None:
            att.q_norm.scale = assign(
                att.q_norm.scale,
                params[f"model.layers.{l}.self_attn.q_norm.weight"],
                f"model.layers.{l}.self_attn.q_norm.weight"
            )
        if hasattr(att, "k_norm") and att.k_norm is not None:
            att.k_norm.scale = assign(
                att.k_norm.scale,
                params[f"model.layers.{l}.self_attn.k_norm.weight"],
                f"model.layers.{l}.self_attn.k_norm.weight"
            )

        # Attention layernorm
        block.norm1.scale = assign(
            block.norm1.scale,
            params[f"model.layers.{l}.input_layernorm.weight"],
            f"model.layers.{l}.input_layernorm.weight"
        )

        # Feedforward weights
        block.ff.fc1.weight = assign(
            block.ff.fc1.weight,
            params[f"model.layers.{l}.mlp.gate_proj.weight"],
            f"model.layers.{l}.mlp.gate_proj.weight"
        )
        block.ff.fc2.weight = assign(
            block.ff.fc2.weight,
            params[f"model.layers.{l}.mlp.up_proj.weight"],
            f"model.layers.{l}.mlp.up_proj.weight"
        )
        block.ff.fc3.weight = assign(
            block.ff.fc3.weight,
            params[f"model.layers.{l}.mlp.down_proj.weight"],
            f"model.layers.{l}.mlp.down_proj.weight"
        )
        block.norm2.scale = assign(
            block.norm2.scale,
            params[f"model.layers.{l}.post_attention_layernorm.weight"],
            f"model.layers.{l}.post_attention_layernorm.weight"
        )

    # Final normalization and output head
    model.final_norm.scale = assign(model.final_norm.scale, params["model.norm.weight"], "model.norm.weight")

    # Model uses weight tying, hence we reuse the embedding layer weights here
    model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
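

# Illustrative sketch (not part of the original file) of how load_weights_into_qwen
# might be used with pretrained weights. It assumes the `huggingface_hub` and
# `safetensors` packages are installed and that the checkpoint (e.g. "Qwen/Qwen3-0.6B")
# ships a single "model.safetensors" file; adjust the repo id and filename as needed.
def _demo_load_pretrained(repo_id="Qwen/Qwen3-0.6B"):
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file

    weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
    params = load_file(weights_path)

    model = Qwen3Model(QWEN_CONFIG_06_B)
    load_weights_into_qwen(model, QWEN_CONFIG_06_B, params)
    model.eval()
    return model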


class Qwen3Tokenizer():
    def __init__(self, tokenizer_file_path="tokenizer.json",
                 repo_id=None, add_generation_prompt=False, add_thinking=False):
        from tokenizers import Tokenizer
        self.tokenizer_file_path = tokenizer_file_path

        if add_generation_prompt != add_thinking:
            raise ValueError(
                "Only add_generation_prompt==add_thinking settings are currently supported"
            )

        self.add_generation_prompt = add_generation_prompt
        self.add_thinking = add_thinking

        tokenizer_file_path_obj = Path(tokenizer_file_path)
        if not tokenizer_file_path_obj.is_file() and repo_id is not None:
            _ = download_from_huggingface(
                repo_id=repo_id,
                filename=str(tokenizer_file_path_obj.name),
                local_dir=str(tokenizer_file_path_obj.parent.name)
            )
        self.tokenizer = Tokenizer.from_file(tokenizer_file_path)

    def encode(self, prompt):
        messages = [
            {"role": "user", "content": prompt}
        ]
        formatted_prompt = self.format_qwen_chat(
            messages,
            add_generation_prompt=self.add_generation_prompt,
            add_thinking=self.add_thinking
        )
        return self.tokenizer.encode(formatted_prompt).ids

    def decode(self, token_ids):
        return self.tokenizer.decode(token_ids, skip_special_tokens=False)

    @staticmethod
    def format_qwen_chat(messages, add_generation_prompt=False, add_thinking=False):
        prompt = ""
        for msg in messages:
            prompt += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
        if add_generation_prompt:
            prompt += "<|im_start|>assistant"
            if not add_thinking:
                prompt += "<think>\n\n</think>\n\n"
            else:
                prompt += "\n"
        return prompt
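

# Illustrative sketch (not part of the original file): format_qwen_chat is a
# static method, so the chat template can be inspected without a tokenizer.json
# on disk.
def _demo_chat_format():
    messages = [{"role": "user", "content": "Give me a short introduction to LLMs."}]
    print(Qwen3Tokenizer.format_qwen_chat(messages, add_generation_prompt=True, add_thinking=True))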


def download_from_huggingface(repo_id, filename, local_dir, revision="main"):
    base_url = "https://huggingface.co"
    url = f"{base_url}/{repo_id}/resolve/{revision}/{filename}"
    Path(local_dir).mkdir(parents=True, exist_ok=True)
    dest_path = os.path.join(local_dir, filename)
    print(f"Downloading {url} to {dest_path}...")
    urllib.request.urlretrieve(url, dest_path)
    return dest_path
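

# Illustrative sketch (not part of the original file): downloading tokenizer.json
# via the helper above and round-tripping a prompt. This assumes the `tokenizers`
# package is installed, network access, and that the repo id (e.g. "Qwen/Qwen3-0.6B")
# provides a tokenizer.json file.
def _demo_tokenizer_roundtrip(repo_id="Qwen/Qwen3-0.6B"):
    tokenizer = Qwen3Tokenizer(
        tokenizer_file_path="Qwen3-0.6B/tokenizer.json",
        repo_id=repo_id,
        add_generation_prompt=True,
        add_thinking=True,
    )
    token_ids = tokenizer.encode("Give me a short introduction to large language models.")
    print(token_ids[:10])
    print(tokenizer.decode(token_ids))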