# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

import os
import urllib.request
from pathlib import Path

import torch
import torch.nn as nn

# 0.6B model
QWEN_CONFIG_06_B = {
    "vocab_size": 151_936,     # Vocabulary size
    "context_length": 40_960,  # Context length that was used to train the model
    "emb_dim": 1024,           # Embedding dimension
    "n_heads": 16,             # Number of attention heads
    "n_layers": 28,            # Number of layers
    "hidden_dim": 3072,        # Size of the intermediate dimension in FeedForward
    "head_dim": 128,           # Size of the heads in GQA
    "qk_norm": True,           # Whether to normalize queries and keys in GQA
    "n_kv_groups": 8,          # Key-Value groups for grouped-query attention
    "rope_base": 1_000_000.0,  # The base in RoPE's "theta"
    "dtype": torch.bfloat16,   # Lower-precision dtype to reduce memory usage
}

class Qwen3Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        # Main model parameters
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])

        self.trf_blocks = nn.ModuleList(  # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`
            [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )
        self.final_norm = RMSNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

        # Reusable utilities
        if cfg["head_dim"] is None:
            head_dim = cfg["emb_dim"] // cfg["n_heads"]
        else:
            head_dim = cfg["head_dim"]
        cos, sin = compute_rope_params(
            head_dim=head_dim,
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"]
        )
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.cfg = cfg

    def forward(self, in_idx):
        # Forward pass
        tok_embeds = self.tok_emb(in_idx)
        x = tok_embeds

        num_tokens = x.shape[1]
        mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)

        for block in self.trf_blocks:
            x = block(x, mask, self.cos, self.sin)
        x = self.final_norm(x)
        logits = self.out_head(x.to(self.cfg["dtype"]))
        return logits
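
# Hypothetical usage sketch (not part of the original file): instantiate the
# 0.6B configuration with random weights and run a dummy forward pass to
# check the output shape. The batch size, sequence length, and seed are
# illustrative assumptions; this allocates the full 0.6B-parameter model.
def demo_random_forward_pass():
    torch.manual_seed(123)
    model = Qwen3Model(QWEN_CONFIG_06_B)
    dummy_input = torch.randint(0, QWEN_CONFIG_06_B["vocab_size"], (1, 8))  # (batch_size=1, num_tokens=8)
    with torch.no_grad():
        logits = model(dummy_input)
    print(logits.shape)  # Expected: torch.Size([1, 8, 151936])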

class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = GroupedQueryAttention(
            d_in=cfg["emb_dim"],
            num_heads=cfg["n_heads"],
            head_dim=cfg["head_dim"],
            num_kv_groups=cfg["n_kv_groups"],
            qk_norm=cfg["qk_norm"],
            dtype=cfg["dtype"]
        )
        self.ff = FeedForward(cfg)
        self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6)
        self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6)

    def forward(self, x, mask, cos, sin):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x, mask, cos, sin)  # Shape [batch_size, num_tokens, emb_size]
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = x + shortcut  # Add the original input back

        return x

class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)

    def forward(self, x):
        x_fc1 = self.fc1(x)
        x_fc2 = self.fc2(x)
        x = nn.functional.silu(x_fc1) * x_fc2
        return self.fc3(x)

class GroupedQueryAttention(nn.Module):
    def __init__(
        self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None
    ):
        super().__init__()
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups

        if head_dim is None:
            assert d_in % num_heads == 0, "`d_in` must be divisible by `num_heads` if `head_dim` is not set"
            head_dim = d_in // num_heads

        self.head_dim = head_dim
        self.d_out = num_heads * head_dim

        self.W_query = nn.Linear(d_in, self.d_out, bias=False, dtype=dtype)
        self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)

        self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)

        if qk_norm:
            self.q_norm = RMSNorm(head_dim, eps=1e-6)
            self.k_norm = RMSNorm(head_dim, eps=1e-6)
        else:
            self.q_norm = self.k_norm = None

    def forward(self, x, mask, cos, sin):
        b, num_tokens, _ = x.shape

        # Apply projections
        queries = self.W_query(x)  # (b, num_tokens, num_heads * head_dim)
        keys = self.W_key(x)       # (b, num_tokens, num_kv_groups * head_dim)
        values = self.W_value(x)   # (b, num_tokens, num_kv_groups * head_dim)

        # Reshape
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)

        # Optional normalization
        if self.q_norm:
            queries = self.q_norm(queries)
        if self.k_norm:
            keys = self.k_norm(keys)

        # Apply RoPE
        queries = apply_rope(queries, cos, sin)
        keys = apply_rope(keys, cos, sin)

        # Expand K and V to match the number of heads
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)

        # Attention
        attn_scores = queries @ keys.transpose(2, 3)
        attn_scores = attn_scores.masked_fill(mask, -torch.inf)
        attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)

        context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)
        return self.out_proj(context)

def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))

    # Generate position indices
    positions = torch.arange(context_length, dtype=dtype)

    # Compute the angles
    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)

    # Expand angles to match the head_dim
    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)

    # Precompute sine and cosine
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin

def apply_rope(x, cos, sin):
    # x: (batch_size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into first half and second half
    x1 = x[..., : head_dim // 2]  # First half
    x2 = x[..., head_dim // 2:]   # Second half

    # Adjust sin and cos shapes
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    # It's ok to use lower-precision after applying cos and sin rotation
    return x_rotated.to(dtype=x.dtype)
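
# Hypothetical sanity-check sketch (not part of the original file): RoPE is a
# pure rotation in each 2D frequency plane, so applying it should preserve the
# per-head vector norms. The tensor sizes below are illustrative assumptions.
def demo_rope_preserves_norms():
    head_dim, context_length = 16, 32
    cos, sin = compute_rope_params(head_dim=head_dim, context_length=context_length)
    queries = torch.randn(1, 2, context_length, head_dim)  # (batch, heads, seq_len, head_dim)
    rotated = apply_rope(queries, cos, sin)
    assert rotated.shape == queries.shape
    assert torch.allclose(rotated.norm(dim=-1), queries.norm(dim=-1), atol=1e-5)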

class RMSNorm(nn.Module):
    def __init__(self, emb_dim, eps=1e-6, bias=False, qwen3_compatible=True):
        super().__init__()
        self.eps = eps
        self.qwen3_compatible = qwen3_compatible
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None

    def forward(self, x):
        input_dtype = x.dtype

        if self.qwen3_compatible:
            x = x.to(torch.float32)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        norm_x = x * torch.rsqrt(variance + self.eps)
        norm_x = norm_x * self.scale

        if self.shift is not None:
            norm_x = norm_x + self.shift

        return norm_x.to(input_dtype)

def load_weights_into_qwen(model, param_config, params):
    def assign(left, right, tensor_name="unknown"):
        if left.shape != right.shape:
            raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")
        return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))

    model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")

    for l in range(param_config["n_layers"]):
        block = model.trf_blocks[l]
        att = block.att

        # Q, K, V projections
        att.W_query.weight = assign(
            att.W_query.weight,
            params[f"model.layers.{l}.self_attn.q_proj.weight"],
            f"model.layers.{l}.self_attn.q_proj.weight"
        )
        att.W_key.weight = assign(
            att.W_key.weight,
            params[f"model.layers.{l}.self_attn.k_proj.weight"],
            f"model.layers.{l}.self_attn.k_proj.weight"
        )
        att.W_value.weight = assign(
            att.W_value.weight,
            params[f"model.layers.{l}.self_attn.v_proj.weight"],
            f"model.layers.{l}.self_attn.v_proj.weight"
        )

        # Output projection
        att.out_proj.weight = assign(
            att.out_proj.weight,
            params[f"model.layers.{l}.self_attn.o_proj.weight"],
            f"model.layers.{l}.self_attn.o_proj.weight"
        )

        # QK norms
        if hasattr(att, "q_norm") and att.q_norm is not None:
            att.q_norm.scale = assign(
                att.q_norm.scale,
                params[f"model.layers.{l}.self_attn.q_norm.weight"],
                f"model.layers.{l}.self_attn.q_norm.weight"
            )
        if hasattr(att, "k_norm") and att.k_norm is not None:
            att.k_norm.scale = assign(
                att.k_norm.scale,
                params[f"model.layers.{l}.self_attn.k_norm.weight"],
                f"model.layers.{l}.self_attn.k_norm.weight"
            )

        # Attention layernorm
        block.norm1.scale = assign(
            block.norm1.scale,
            params[f"model.layers.{l}.input_layernorm.weight"],
            f"model.layers.{l}.input_layernorm.weight"
        )

        # Feedforward weights
        block.ff.fc1.weight = assign(
            block.ff.fc1.weight,
            params[f"model.layers.{l}.mlp.gate_proj.weight"],
            f"model.layers.{l}.mlp.gate_proj.weight"
        )
        block.ff.fc2.weight = assign(
            block.ff.fc2.weight,
            params[f"model.layers.{l}.mlp.up_proj.weight"],
            f"model.layers.{l}.mlp.up_proj.weight"
        )
        block.ff.fc3.weight = assign(
            block.ff.fc3.weight,
            params[f"model.layers.{l}.mlp.down_proj.weight"],
            f"model.layers.{l}.mlp.down_proj.weight"
        )
        block.norm2.scale = assign(
            block.norm2.scale,
            params[f"model.layers.{l}.post_attention_layernorm.weight"],
            f"model.layers.{l}.post_attention_layernorm.weight"
        )

    # Final normalization and output head
    model.final_norm.scale = assign(model.final_norm.scale, params["model.norm.weight"], "model.norm.weight")

    # Model uses weight tying, hence we reuse the embedding layer weights here
    model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
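
# Hypothetical usage sketch (not part of the original file): load pretrained
# weights from a local safetensors checkpoint via `safetensors.torch.load_file`
# and copy them into a freshly instantiated model. The "model.safetensors"
# file name is an illustrative assumption about the checkpoint layout.
def demo_load_pretrained_weights(weights_path="model.safetensors"):
    from safetensors.torch import load_file

    params = load_file(weights_path)  # dict: parameter name -> tensor
    model = Qwen3Model(QWEN_CONFIG_06_B)
    load_weights_into_qwen(model, QWEN_CONFIG_06_B, params)
    model.eval()
    return model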

class Qwen3Tokenizer:
    def __init__(self, tokenizer_file_path="tokenizer.json",
                 repo_id=None, add_generation_prompt=False, add_thinking=False):
        from tokenizers import Tokenizer
        self.tokenizer_file_path = tokenizer_file_path

        if add_generation_prompt != add_thinking:
            raise ValueError(
                "Currently, `add_generation_prompt` and `add_thinking` must be set to the same value"
            )

        self.add_generation_prompt = add_generation_prompt
        self.add_thinking = add_thinking

        tokenizer_file_path_obj = Path(tokenizer_file_path)
        if not tokenizer_file_path_obj.is_file() and repo_id is not None:
            _ = download_from_huggingface(
                repo_id=repo_id,
                filename=str(tokenizer_file_path_obj.name),
                local_dir=str(tokenizer_file_path_obj.parent.name)
            )
        self.tokenizer = Tokenizer.from_file(tokenizer_file_path)

    def encode(self, prompt):
        messages = [
            {"role": "user", "content": prompt}
        ]
        formatted_prompt = self.format_qwen_chat(
            messages,
            add_generation_prompt=self.add_generation_prompt,
            add_thinking=self.add_thinking
        )
        return self.tokenizer.encode(formatted_prompt).ids

    def decode(self, token_ids):
        return self.tokenizer.decode(token_ids, skip_special_tokens=False)

    @staticmethod
    def format_qwen_chat(messages, add_generation_prompt=False, add_thinking=False):
        prompt = ""
        for msg in messages:
            prompt += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"

        if add_generation_prompt:
            prompt += "<|im_start|>assistant"
            if not add_thinking:
                # Qwen3 uses the <think>...</think> special tokens; an empty block disables thinking
                prompt += "<think>\n\n</think>\n\n"
            else:
                prompt += "\n"
        return prompt
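
# Hypothetical usage sketch (not part of the original file): wrap a prompt in
# Qwen's chat format and round-trip it through the tokenizer. The local path
# and the "Qwen/Qwen3-0.6B" repo id are illustrative assumptions; the file is
# downloaded on first use if it is not present locally.
def demo_tokenizer_roundtrip():
    tokenizer = Qwen3Tokenizer(
        tokenizer_file_path="Qwen3-0.6B/tokenizer.json",
        repo_id="Qwen/Qwen3-0.6B",
        add_generation_prompt=True,
        add_thinking=True,
    )
    token_ids = tokenizer.encode("Give me a short introduction to large language models.")
    print(token_ids[:10])
    print(tokenizer.decode(token_ids))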

def download_from_huggingface(repo_id, filename, local_dir, revision="main"):
    base_url = "https://huggingface.co"
    url = f"{base_url}/{repo_id}/resolve/{revision}/{filename}"
    Path(local_dir).mkdir(parents=True, exist_ok=True)
    dest_path = os.path.join(local_dir, filename)
    print(f"Downloading {url} to {dest_path}...")
    urllib.request.urlretrieve(url, dest_path)
    return dest_path