| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573 |
- # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
- # Source for "Build a Large Language Model From Scratch"
- # - https://www.manning.com/books/build-a-large-language-model-from-scratch
- # Code: https://github.com/rasbt/LLMs-from-scratch
- import os
- from pathlib import Path
- import torch
- import torch.nn as nn
- import tiktoken
- from tiktoken.load import load_tiktoken_bpe
# Hyperparameters for the configuration named LLAMA32_CONFIG_1B.
# The model classes in this file read these entries via cfg["..."] lookups.
LLAMA32_CONFIG_1B = {
    # Vocabulary size
    "vocab_size": 128_256,
    # Context length that was used to train the model
    "context_length": 131_072,
    # Embedding dimension
    "emb_dim": 2048,
    # Number of attention heads
    "n_heads": 32,
    # Number of layers
    "n_layers": 16,
    # Size of the intermediate dimension in FeedForward
    "hidden_dim": 8192,
    # Key-Value groups for grouped-query attention
    "n_kv_groups": 8,
    # The base in RoPE's "theta"
    "rope_base": 500_000.0,
    # Lower-precision dtype to reduce memory usage
    "dtype": torch.bfloat16,
    # RoPE frequency scaling
    "rope_freq": dict(
        factor=32.0,
        low_freq_factor=1.0,
        high_freq_factor=4.0,
        original_context_length=8192,
    ),
}
# Hyperparameters for the configuration named LLAMA32_CONFIG_3B.
# The model classes in this file read these entries via cfg["..."] lookups.
LLAMA32_CONFIG_3B = {
    # Vocabulary size
    "vocab_size": 128_256,
    # Context length that was used to train the model
    "context_length": 131_072,
    # Embedding dimension
    "emb_dim": 3072,
    # Number of attention heads
    "n_heads": 24,
    # Number of layers
    "n_layers": 28,
    # Size of the intermediate dimension in FeedForward
    "hidden_dim": 8192,
    # Key-Value groups for grouped-query attention
    "n_kv_groups": 8,
    # The base in RoPE's "theta"
    "rope_base": 500_000.0,
    # Lower-precision dtype to reduce memory usage
    "dtype": torch.bfloat16,
    # RoPE frequency scaling
    "rope_freq": dict(
        factor=32.0,
        low_freq_factor=1.0,
        high_freq_factor=4.0,
        original_context_length=8192,
    ),
}
class Llama3Model(nn.Module):
    """Decoder stack that maps token IDs to vocabulary logits.

    Layout: token embedding -> ``cfg["n_layers"]`` TransformerBlock layers ->
    final RMSNorm -> bias-free linear head over the vocabulary.

    Args:
        cfg: Configuration mapping. Keys read here are ``vocab_size``,
            ``emb_dim``, ``dtype``, ``n_layers``, ``n_heads``, ``rope_base``,
            ``context_length`` and ``rope_freq``. The same mapping is passed
            to every TransformerBlock and kept on ``self.cfg``.
    """

    def __init__(self, cfg):
        super().__init__()
        vocab_size = cfg["vocab_size"]
        emb_dim = cfg["emb_dim"]
        param_dtype = cfg["dtype"]

        # Main model parameters
        self.tok_emb = nn.Embedding(vocab_size, emb_dim, dtype=param_dtype)
        # ModuleList (not Sequential): each block takes `x, mask, cos, sin`,
        # and Sequential can only forward a single input.
        blocks = [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        self.trf_blocks = nn.ModuleList(blocks)
        self.final_norm = nn.RMSNorm(emb_dim, eps=1e-5, dtype=param_dtype)
        self.out_head = nn.Linear(emb_dim, vocab_size, bias=False, dtype=param_dtype)

        # RoPE cos/sin tables are built once for the full context length and
        # shared by all blocks; registered as non-persistent buffers.
        rope_cos, rope_sin = compute_rope_params(
            head_dim=emb_dim // cfg["n_heads"],
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"],
            freq_config=cfg["rope_freq"],
        )
        self.register_buffer("cos", rope_cos, persistent=False)
        self.register_buffer("sin", rope_sin, persistent=False)
        self.cfg = cfg

    def forward(self, in_idx):
        """Return logits for ``in_idx``.

        Assumes ``in_idx`` holds token IDs with the token axis at dim 1
        (the code reads the sequence length from ``shape[1]`` of the
        embedded input).
        """
        x = self.tok_emb(in_idx)
        seq_len = x.shape[1]

        # Boolean mask that is True strictly above the diagonal; the attention
        # layers fill those (future) positions with -inf.
        all_true = torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool)
        causal_mask = torch.triu(all_true, diagonal=1)

        for block in self.trf_blocks:
            x = block(x, causal_mask, self.cos, self.sin)

        normed = self.final_norm(x)
        return self.out_head(normed.to(self.cfg["dtype"]))
class TransformerBlock(nn.Module):
    """One transformer layer with two residual sub-layers.

    Each sub-layer normalizes its input with RMSNorm first, then applies
    grouped-query attention (first sub-layer) or the feed-forward module
    (second sub-layer), and finally adds the un-normalized input back.

    Args:
        cfg: Configuration mapping; reads ``emb_dim``, ``n_heads``,
            ``n_kv_groups`` and ``dtype`` here and forwards the whole mapping
            to FeedForward.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        self.att = GroupedQueryAttention(
            d_in=emb_dim,
            d_out=emb_dim,
            num_heads=cfg["n_heads"],
            num_kv_groups=cfg["n_kv_groups"],
            dtype=cfg["dtype"],
        )
        self.ff = FeedForward(cfg)
        self.norm1 = nn.RMSNorm(emb_dim, eps=1e-5, dtype=cfg["dtype"])
        self.norm2 = nn.RMSNorm(emb_dim, eps=1e-5, dtype=cfg["dtype"])

    def forward(self, x, mask, cos, sin):
        """Apply attention and feed-forward sub-layers, each with a shortcut."""
        # Attention sub-layer: output shape [batch_size, num_tokens, emb_size]
        attn_out = self.att(self.norm1(x), mask, cos, sin)
        x = attn_out + x  # residual around attention

        # Feed-forward sub-layer
        ff_out = self.ff(self.norm2(x))
        return ff_out + x  # residual around feed-forward
class FeedForward(nn.Module):
    """Gated feed-forward module: ``fc3(silu(fc1(x)) * fc2(x))``.

    ``fc1`` and ``fc2`` both project ``emb_dim -> hidden_dim``; ``fc3``
    projects back ``hidden_dim -> emb_dim``. None of the layers has a bias.

    Args:
        cfg: Configuration mapping; reads ``emb_dim``, ``hidden_dim`` and
            ``dtype``.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim, hidden_dim = cfg["emb_dim"], cfg["hidden_dim"]
        shared = {"dtype": cfg["dtype"], "bias": False}
        self.fc1 = nn.Linear(emb_dim, hidden_dim, **shared)
        self.fc2 = nn.Linear(emb_dim, hidden_dim, **shared)
        self.fc3 = nn.Linear(hidden_dim, emb_dim, **shared)

    def forward(self, x):
        """Return the gated projection of ``x`` (last dim stays ``emb_dim``)."""
        gate = nn.functional.silu(self.fc1(x))
        return self.fc3(gate * self.fc2(x))
class GroupedQueryAttention(nn.Module):
    """Causal self-attention where several query heads share one key/value head.

    Queries are projected to ``num_heads`` heads, keys and values only to
    ``num_kv_groups`` heads; each key/value head is then repeated
    ``num_heads // num_kv_groups`` times so it lines up with its query heads.

    Args:
        d_in: Input feature size.
        d_out: Output feature size; must be divisible by ``num_heads``.
        num_heads: Number of query heads; must be divisible by
            ``num_kv_groups``.
        num_kv_groups: Number of key/value heads.
        dtype: Parameter dtype passed to every linear layer.
    """

    def __init__(
        self, d_in, d_out, num_heads, num_kv_groups, dtype=None
    ):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.d_out = d_out
        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.head_dim = d_out // num_heads
        self.group_size = num_heads // num_kv_groups

        # Key/value projections are narrower than the query projection.
        kv_width = num_kv_groups * self.head_dim
        self.W_key = nn.Linear(d_in, kv_width, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, kv_width, bias=False, dtype=dtype)
        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

    def forward(self, x, mask, cos, sin):
        """Attend over ``x`` of shape (b, num_tokens, d_in).

        ``mask`` is a boolean matrix whose True entries are filled with -inf
        before the softmax; ``cos``/``sin`` are the RoPE tables passed to
        ``apply_rope``. Returns a tensor of shape (b, num_tokens, d_out).
        """
        b, num_tokens, _ = x.shape

        def split_heads(projected, n_heads):
            # (b, num_tokens, n_heads * head_dim) -> (b, n_heads, num_tokens, head_dim)
            return projected.view(b, num_tokens, n_heads, self.head_dim).transpose(1, 2)

        queries = split_heads(self.W_query(x), self.num_heads)
        keys = split_heads(self.W_key(x), self.num_kv_groups)
        values = split_heads(self.W_value(x), self.num_kv_groups)

        # Rotary position embedding on keys and queries (values are untouched)
        keys = apply_rope(keys, cos, sin)
        queries = apply_rope(queries, cos, sin)

        # Repeat each key/value head so there is one per query head:
        # [K1, K2] -> [K1, K1, K2, K2] (a plain repeat would give [K1, K2, K1, K2]).
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)

        # Scores: (b, num_heads, num_tokens, num_tokens), masked then scaled
        scores = queries @ keys.transpose(2, 3)
        scores = scores.masked_fill(mask[:num_tokens, :num_tokens], -torch.inf)
        scale = keys.shape[-1]**0.5
        attn_weights = torch.softmax(scores / scale, dim=-1)
        assert keys.shape[-1] == self.head_dim

        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, d_out),
        # where d_out = num_heads * head_dim
        merged = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)
        return self.out_proj(merged)  # optional projection
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):
    """Precompute the RoPE cosine and sine tables.

    Args:
        head_dim: Per-head feature size; must be even.
        theta_base: Base of the geometric frequency progression.
        context_length: Number of positions to tabulate.
        freq_config: Optional mapping with ``factor``, ``low_freq_factor``,
            ``high_freq_factor`` and ``original_context_length``. When given,
            the inverse frequencies are rescaled per wavelength band.
        dtype: dtype used for the index ranges (the inverse frequencies
            themselves are cast to float32 via ``.float()``).

    Returns:
        ``(cos, sin)``, each of shape (context_length, head_dim).
    """
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # One inverse frequency per pair of dimensions
    exponents = torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim
    inv_freq = 1.0 / (theta_base ** exponents)

    if freq_config is not None:
        factor = freq_config["factor"]
        low_factor = freq_config["low_freq_factor"]
        high_factor = freq_config["high_freq_factor"]
        orig_len = freq_config["original_context_length"]

        low_freq_wavelen = orig_len / low_factor
        high_freq_wavelen = orig_len / high_factor
        wavelen = 2 * torch.pi / inv_freq

        # Long wavelengths (beyond low_freq_wavelen) are divided by `factor`;
        # short ones are left as they are.
        rescaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)

        # Wavelengths between the two thresholds get a linear blend of the
        # scaled and unscaled frequency.
        smooth_factor = (orig_len / wavelen - low_factor) / (high_factor - low_factor)
        blended = (1 - smooth_factor) * (inv_freq / factor) + smooth_factor * inv_freq
        in_between = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)

        inv_freq = torch.where(in_between, blended, rescaled)

    # Angle for every (position, frequency) pair: (context_length, head_dim // 2)
    positions = torch.arange(context_length, dtype=dtype)
    half_angles = positions[:, None] * inv_freq[None, :]

    # Duplicate along the feature axis to cover the full head_dim
    angles = torch.cat([half_angles, half_angles], dim=1)

    return torch.cos(angles), torch.sin(angles)
def apply_rope(x, cos, sin):
    """Rotate ``x`` with precomputed RoPE tables.

    Args:
        x: Tensor of shape (batch_size, num_heads, seq_len, head_dim) with an
            even ``head_dim``.
        cos: Cosine table; the first ``seq_len`` rows are used.
        sin: Sine table; the first ``seq_len`` rows are used.

    Returns:
        ``x * cos + rotate_half(x) * sin`` cast back to ``x.dtype``, where
        ``rotate_half`` maps the halves ``(x1, x2)`` to ``(-x2, x1)``.
    """
    _, _, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    half = head_dim // 2
    first_half = x[..., :half]
    second_half = x[..., half:]

    # Trim the tables to seq_len and add broadcast axes: (1, 1, seq_len, head_dim)
    cos = cos[:seq_len, :][None, None, :, :]
    sin = sin[:seq_len, :][None, None, :, :]

    swapped = torch.cat((-second_half, first_half), dim=-1)
    result = (x * cos) + (swapped * sin)

    # It's ok to use lower-precision after applying cos and sin rotation
    return result.to(dtype=x.dtype)
- ##########################################
- # Tokenizer
- ##########################################
class Llama3Tokenizer:
    """Thin wrapper around tiktoken that keeps track of Llama-3 special IDs.

    Attributes:
        special: Mapping from special-token string to token ID.
        model: The underlying ``tiktoken.Encoding``.
    """

    def __init__(self, model_path):
        """Load the BPE ranks from ``model_path`` and build the encoding.

        Raises:
            FileNotFoundError: If ``model_path`` is not an existing file.
        """
        if not os.path.isfile(model_path):
            raise FileNotFoundError(model_path)

        mergeable = load_tiktoken_bpe(model_path)

        # hard-coded from Meta's tokenizer.json
        self.special = {
            "<|begin_of_text|>": 128000,
            "<|end_of_text|>": 128001,
            "<|start_header_id|>": 128006,
            "<|end_header_id|>": 128007,
            "<|eot_id|>": 128009,
        }
        # Give every ID in 128002..128257 that is not named above a
        # "<|reserved_i|>" placeholder. The taken IDs are collected once
        # instead of rescanning the dict values for each candidate.
        taken_ids = set(self.special.values())
        self.special.update({
            f"<|reserved_{i}|>": 128002 + i
            for i in range(256)
            if 128002 + i not in taken_ids
        })

        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)"
                    r"|[^\r\n\p{L}\p{N}]?\p{L}+"
                    r"|\p{N}{1,3}"
                    r"| ?[^\s\p{L}\p{N}]+[\r\n]*"
                    r"|\s*[\r\n]+"
                    r"|\s+(?!\S)"
                    r"|\s+",
            mergeable_ranks=mergeable,
            special_tokens=self.special,
        )

    def encode(self, text, bos=False, eos=False, **kwargs):
        """Encode ``text`` into token IDs.

        Args:
            text: String to tokenize.
            bos: Prepend the ``<|begin_of_text|>`` ID when true.
            eos: Append the ``<|end_of_text|>`` ID when true.
            **kwargs: Extra options forwarded to the underlying encoder's
                ``encode`` (for example ``allowed_special``). Options whose
                value is ``None`` are treated as "not specified" and dropped,
                so the encoder's own defaults apply.

        Returns:
            List of token IDs.
        """
        # Previously these options were accepted but never passed on, so a
        # caller-supplied `allowed_special` had no effect. None values are
        # filtered because ChatFormat.encode passes allowed_special=None by
        # default, which the encoder does not accept as a value.
        options = {name: value for name, value in kwargs.items() if value is not None}

        ids = [self.special["<|begin_of_text|>"]] if bos else []
        ids += self.model.encode(text, **options)
        if eos:
            ids.append(self.special["<|end_of_text|>"])
        return ids

    def decode(self, ids):
        """Decode a sequence of token IDs back into text."""
        return self.model.decode(ids)
class ChatFormat:
    """Build chat prompts as token IDs using a tokenizer's special tokens.

    Args:
        tokenizer: Object exposing ``special`` (token string -> ID),
            ``encode`` and ``decode``.
        default_system: System message used when none is given to ``encode``.
    """

    def __init__(self, tokenizer: Llama3Tokenizer, *,
                 default_system="You are a helpful assistant."):
        self.tok = tokenizer
        self.default_system = default_system

    def _header(self, role):
        r"""Encode <|start_header_id|>role<|end_header_id|>\n\n"""
        special = self.tok.special
        return [
            special["<|start_header_id|>"],
            *self.tok.encode(role),
            special["<|end_header_id|>"],
            *self.tok.encode("\n\n"),
        ]

    def encode(self, user_message, system_message=None, allowed_special=None):
        """Return the IDs for a system turn, a user turn and an open assistant header.

        ``allowed_special`` is passed to the tokenizer only when encoding the
        system message; the user message is encoded with the tokenizer's
        defaults.
        """
        special = self.tok.special
        if system_message is None:
            system_message = self.default_system

        return [
            special["<|begin_of_text|>"],
            # system
            *self._header("system"),
            *self.tok.encode(system_message, allowed_special=allowed_special),
            special["<|eot_id|>"],
            # user
            *self._header("user"),
            *self.tok.encode(user_message),
            special["<|eot_id|>"],
            # assistant header (no content yet)
            *self._header("assistant"),
        ]

    def decode(self, ids):
        """Decode token IDs with the wrapped tokenizer."""
        return self.tok.decode(ids)
def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
    """Return the part of ``text`` that follows the first ``header_end``.

    The returned tail has leading and trailing whitespace stripped. When
    ``header_end`` does not occur in ``text``, ``text`` is returned unchanged
    (not stripped).
    """
    start = text.find(header_end)
    if start == -1:
        # Marker absent: hand the input back untouched
        return text

    tail = text[start + len(header_end):]
    return tail.strip()
- ######################################################################
- # Llama 3 fast (alternative code geared towards efficiency)
- ######################################################################
class GroupedQueryAttentionFast(nn.Module):
    """
    Grouped-query attention built on PyTorch's
    scaled_dot_product_attention with ``is_causal=True``.

    Takes the same constructor arguments as GroupedQueryAttention, but
    ``forward`` receives no explicit mask: causality is requested from the
    fused attention call, which may use FlashAttention-style kernels on
    supported GPUs and low-precision dtypes.
    """

    def __init__(self, d_in, d_out, num_heads, num_kv_groups, dtype=None):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.d_out = d_out
        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.head_dim = d_out // num_heads
        self.group_size = num_heads // num_kv_groups

        kv_width = num_kv_groups * self.head_dim
        self.W_key = nn.Linear(d_in, kv_width, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, kv_width, bias=False, dtype=dtype)
        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

    def forward(self, x, cos, sin):
        """Attend causally over ``x`` of shape (b, num_tokens, d_in)."""
        b, num_tokens, _ = x.shape

        def split_heads(projected, n_heads):
            # (b, num_tokens, n_heads * head_dim) -> (b, n_heads, num_tokens, head_dim)
            return projected.view(b, num_tokens, n_heads, self.head_dim).transpose(1, 2)

        # Project to queries, keys, values
        queries = split_heads(self.W_query(x), self.num_heads)
        keys = split_heads(self.W_key(x), self.num_kv_groups)
        values = split_heads(self.W_value(x), self.num_kv_groups)

        # Apply Rotary Positional Embedding
        queries = apply_rope(queries, cos, sin)
        keys = apply_rope(keys, cos, sin)

        # Expand key/value groups to full head count
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)

        # Fused scaled dot-product attention with built-in causal masking
        attended = torch.nn.functional.scaled_dot_product_attention(
            queries, keys, values, is_causal=True
        )

        # Combine heads and project
        merged = attended.transpose(1, 2).reshape(b, num_tokens, self.d_out)
        return self.out_proj(merged)
class TransformerBlockFast(nn.Module):
    """
    Transformer layer with the same structure as TransformerBlock, built on
    GroupedQueryAttentionFast; its forward takes no mask argument.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        self.att = GroupedQueryAttentionFast(
            d_in=emb_dim,
            d_out=emb_dim,
            num_heads=cfg["n_heads"],
            num_kv_groups=cfg["n_kv_groups"],
            dtype=cfg["dtype"],
        )
        self.ff = FeedForward(cfg)
        self.norm1 = nn.RMSNorm(emb_dim, eps=1e-5, dtype=cfg["dtype"])
        self.norm2 = nn.RMSNorm(emb_dim, eps=1e-5, dtype=cfg["dtype"])

    def forward(self, x, cos, sin):
        """Apply attention and feed-forward sub-layers, each with a shortcut."""
        # Attention sub-layer: output shape [batch_size, num_tokens, emb_size]
        attn_out = self.att(self.norm1(x), cos, sin)
        x = attn_out + x  # residual around attention

        # Feed-forward sub-layer
        ff_out = self.ff(self.norm2(x))
        return ff_out + x  # residual around feed-forward
class Llama3ModelFast(nn.Module):
    """
    Variant of Llama3Model that stacks TransformerBlockFast layers (and thus
    GroupedQueryAttentionFast), so no explicit causal mask is built in
    ``forward``.

    Args:
        cfg: Configuration mapping. Keys read here are ``vocab_size``,
            ``emb_dim``, ``dtype``, ``n_layers``, ``n_heads``, ``rope_base``,
            ``context_length`` and ``rope_freq``.
    """

    def __init__(self, cfg):
        super().__init__()
        vocab_size = cfg["vocab_size"]
        emb_dim = cfg["emb_dim"]
        param_dtype = cfg["dtype"]

        # Main model parameters
        self.tok_emb = nn.Embedding(vocab_size, emb_dim, dtype=param_dtype)
        # ModuleList (not Sequential): each block takes `x, cos, sin`,
        # and Sequential can only forward a single input.
        blocks = [TransformerBlockFast(cfg) for _ in range(cfg["n_layers"])]
        self.trf_blocks = nn.ModuleList(blocks)
        self.final_norm = nn.RMSNorm(emb_dim, eps=1e-5, dtype=param_dtype)
        self.out_head = nn.Linear(emb_dim, vocab_size, bias=False, dtype=param_dtype)

        # RoPE cos/sin tables, built once and kept as non-persistent buffers
        rope_cos, rope_sin = compute_rope_params(
            head_dim=emb_dim // cfg["n_heads"],
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"],
            freq_config=cfg["rope_freq"],
        )
        self.register_buffer("cos", rope_cos, persistent=False)
        self.register_buffer("sin", rope_sin, persistent=False)
        self.cfg = cfg

    def forward(self, in_idx):
        """Embed ``in_idx``, run all blocks, normalize and project to logits."""
        x = self.tok_emb(in_idx)

        for block in self.trf_blocks:
            x = block(x, self.cos, self.sin)

        normed = self.final_norm(x)
        return self.out_head(normed.to(self.cfg["dtype"]))
def assign(left, right, tensor_name="unknown"):
    """Wrap ``right`` in a new Parameter after checking it matches ``left``'s shape.

    Args:
        left: Existing tensor/parameter; only its shape is used.
        right: Replacement values. A tensor is cloned and detached; anything
            else is converted with ``torch.tensor``.
        tensor_name: Name used in the error message.

    Returns:
        A new ``torch.nn.Parameter`` holding the values of ``right``.

    Raises:
        ValueError: If the two shapes differ.
    """
    if left.shape != right.shape:
        raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")

    # Copy tensors so the parameter does not share storage with the source
    values = right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right)
    return torch.nn.Parameter(values)
def load_weights_into_llama(model, param_config, params):
    """Copy checkpoint tensors from ``params`` into ``model``.

    Args:
        model: Model exposing ``tok_emb``, ``trf_blocks`` (each with ``att``,
            ``ff``, ``norm1``, ``norm2``), ``final_norm`` and ``out_head``.
        param_config: Mapping; only ``n_layers`` is read.
        params: Mapping from checkpoint key (e.g.
            ``"model.layers.0.self_attn.q_proj.weight"``) to tensor.

    Every tensor is shape-checked by ``assign`` before it replaces the
    corresponding ``.weight``. If ``params`` has no ``"lm_head.weight"``
    entry, the output head is loaded from the embedding matrix and a note is
    printed.
    """

    def load(module, key):
        # Validate against the current weight, then swap in the checkpoint tensor
        module.weight = assign(module.weight, params[key], key)

    load(model.tok_emb, "model.embed_tokens.weight")

    for layer_idx in range(param_config["n_layers"]):
        block = model.trf_blocks[layer_idx]
        prefix = f"model.layers.{layer_idx}"

        # Load attention weights
        load(block.att.W_query, f"{prefix}.self_attn.q_proj.weight")
        load(block.att.W_key, f"{prefix}.self_attn.k_proj.weight")
        load(block.att.W_value, f"{prefix}.self_attn.v_proj.weight")
        load(block.att.out_proj, f"{prefix}.self_attn.o_proj.weight")
        load(block.norm1, f"{prefix}.input_layernorm.weight")

        # Load FeedForward weights
        load(block.ff.fc1, f"{prefix}.mlp.gate_proj.weight")
        load(block.ff.fc2, f"{prefix}.mlp.up_proj.weight")
        load(block.ff.fc3, f"{prefix}.mlp.down_proj.weight")
        load(block.norm2, f"{prefix}.post_attention_layernorm.weight")

    # Load output layer weights
    load(model.final_norm, "model.norm.weight")

    if "lm_head.weight" in params.keys():
        load(model.out_head, "lm_head.weight")
    else:
        # No separate head in the checkpoint: reuse the embedding matrix
        load(model.out_head, "model.embed_tokens.weight")
        print("Model uses weight tying.")
|