llama3.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
  2. # Source for "Build a Large Language Model From Scratch"
  3. # - https://www.manning.com/books/build-a-large-language-model-from-scratch
  4. # Code: https://github.com/rasbt/LLMs-from-scratch
  5. import os
  6. from pathlib import Path
  7. import torch
  8. import torch.nn as nn
  9. import tiktoken
  10. from tiktoken.load import load_tiktoken_bpe
  11. LLAMA32_CONFIG_1B = {
  12. "vocab_size": 128_256, # Vocabulary size
  13. "context_length": 131_072, # Context length that was used to train the model
  14. "emb_dim": 2048, # Embedding dimension
  15. "n_heads": 32, # Number of attention heads
  16. "n_layers": 16, # Number of layers
  17. "hidden_dim": 8192, # Size of the intermediate dimension in FeedForward
  18. "n_kv_groups": 8, # Key-Value groups for grouped-query attention
  19. "rope_base": 500_000.0, # The base in RoPE's "theta"
  20. "dtype": torch.bfloat16, # Lower-precision dtype to reduce memory usage
  21. "rope_freq": { # RoPE frequency scaling
  22. "factor": 32.0,
  23. "low_freq_factor": 1.0,
  24. "high_freq_factor": 4.0,
  25. "original_context_length": 8192,
  26. }
  27. }
  28. LLAMA32_CONFIG_3B = {
  29. "vocab_size": 128_256, # Vocabulary size
  30. "context_length": 131_072, # Context length that was used to train the model
  31. "emb_dim": 3072, # Embedding dimension
  32. "n_heads": 24, # Number of attention heads
  33. "n_layers": 28, # Number of layers
  34. "hidden_dim": 8192, # Size of the intermediate dimension in FeedForward
  35. "n_kv_groups": 8, # Key-Value groups for grouped-query attention
  36. "rope_base": 500_000.0, # The base in RoPE's "theta"
  37. "dtype": torch.bfloat16, # Lower-precision dtype to reduce memory usage
  38. "rope_freq": { # RoPE frequency scaling
  39. "factor": 32.0,
  40. "low_freq_factor": 1.0,
  41. "high_freq_factor": 4.0,
  42. "original_context_length": 8192,
  43. }
  44. }
  45. class Llama3Model(nn.Module):
  46. def __init__(self, cfg):
  47. super().__init__()
  48. # Main model parameters
  49. self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
  50. self.trf_blocks = nn.ModuleList( # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`
  51. [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
  52. )
  53. self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
  54. self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
  55. # Reusuable utilities
  56. cos, sin = compute_rope_params(
  57. head_dim=cfg["emb_dim"] // cfg["n_heads"],
  58. theta_base=cfg["rope_base"],
  59. context_length=cfg["context_length"],
  60. freq_config=cfg["rope_freq"]
  61. )
  62. self.register_buffer("cos", cos, persistent=False)
  63. self.register_buffer("sin", sin, persistent=False)
  64. self.cfg = cfg
  65. def forward(self, in_idx):
  66. tok_embeds = self.tok_emb(in_idx)
  67. x = tok_embeds
  68. num_tokens = x.shape[1]
  69. mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)
  70. for block in self.trf_blocks:
  71. x = block(x, mask, self.cos, self.sin)
  72. x = self.final_norm(x)
  73. logits = self.out_head(x.to(self.cfg["dtype"]))
  74. return logits
  75. class TransformerBlock(nn.Module):
  76. def __init__(self, cfg):
  77. super().__init__()
  78. self.att = GroupedQueryAttention(
  79. d_in=cfg["emb_dim"],
  80. d_out=cfg["emb_dim"],
  81. num_heads=cfg["n_heads"],
  82. num_kv_groups=cfg["n_kv_groups"],
  83. dtype=cfg["dtype"]
  84. )
  85. self.ff = FeedForward(cfg)
  86. self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
  87. self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
  88. def forward(self, x, mask, cos, sin):
  89. # Shortcut connection for attention block
  90. shortcut = x
  91. x = self.norm1(x)
  92. x = self.att(x, mask, cos, sin) # Shape [batch_size, num_tokens, emb_size]
  93. x = x + shortcut # Add the original input back
  94. # Shortcut connection for feed-forward block
  95. shortcut = x
  96. x = self.norm2(x)
  97. x = self.ff(x)
  98. x = x + shortcut # Add the original input back
  99. return x
  100. class FeedForward(nn.Module):
  101. def __init__(self, cfg):
  102. super().__init__()
  103. self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
  104. self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
  105. self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)
  106. def forward(self, x):
  107. x_fc1 = self.fc1(x)
  108. x_fc2 = self.fc2(x)
  109. x = nn.functional.silu(x_fc1) * x_fc2
  110. return self.fc3(x)
  111. class GroupedQueryAttention(nn.Module):
  112. def __init__(
  113. self, d_in, d_out, num_heads, num_kv_groups, dtype=None
  114. ):
  115. super().__init__()
  116. assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
  117. assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
  118. self.d_out = d_out
  119. self.num_heads = num_heads
  120. self.head_dim = d_out // num_heads
  121. self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
  122. self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
  123. self.num_kv_groups = num_kv_groups
  124. self.group_size = num_heads // num_kv_groups
  125. self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
  126. self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)
  127. def forward(self, x, mask, cos, sin):
  128. b, num_tokens, d_in = x.shape
  129. queries = self.W_query(x) # Shape: (b, num_tokens, d_out)
  130. keys = self.W_key(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)
  131. values = self.W_value(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)
  132. # Reshape queries, keys, and values
  133. queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
  134. keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)
  135. values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)
  136. # Transpose keys, values, and queries
  137. keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
  138. values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
  139. queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)
  140. # Apply RoPE
  141. keys = apply_rope(keys, cos, sin)
  142. queries = apply_rope(queries, cos, sin)
  143. # Expand keys and values to match the number of heads
  144. # Shape: (b, num_heads, num_tokens, head_dim)
  145. keys = keys.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)
  146. values = values.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)
  147. # For example, before repeat_interleave along dim=1 (query groups):
  148. # [K1, K2]
  149. # After repeat_interleave (each query group is repeated group_size times):
  150. # [K1, K1, K2, K2]
  151. # If we used regular repeat instead of repeat_interleave, we'd get:
  152. # [K1, K2, K1, K2]
  153. # Compute scaled dot-product attention (aka self-attention) with a causal mask
  154. # Shape: (b, num_heads, num_tokens, num_tokens)
  155. attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
  156. # Use the mask to fill attention scores
  157. attn_scores = attn_scores.masked_fill(mask[:num_tokens, :num_tokens], -torch.inf)
  158. attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
  159. assert keys.shape[-1] == self.head_dim
  160. # Shape: (b, num_tokens, num_heads, head_dim)
  161. context_vec = (attn_weights @ values).transpose(1, 2)
  162. # Combine heads, where self.d_out = self.num_heads * self.head_dim
  163. context_vec = context_vec.reshape(b, num_tokens, self.d_out)
  164. context_vec = self.out_proj(context_vec) # optional projection
  165. return context_vec
  166. def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):
  167. assert head_dim % 2 == 0, "Embedding dimension must be even"
  168. # Compute the inverse frequencies
  169. inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))
  170. # Frequency adjustments
  171. if freq_config is not None:
  172. low_freq_wavelen = freq_config["original_context_length"] / freq_config["low_freq_factor"]
  173. high_freq_wavelen = freq_config["original_context_length"] / freq_config["high_freq_factor"]
  174. wavelen = 2 * torch.pi / inv_freq
  175. inv_freq_llama = torch.where(
  176. wavelen > low_freq_wavelen, inv_freq / freq_config["factor"], inv_freq
  177. )
  178. smooth_factor = (freq_config["original_context_length"] / wavelen - freq_config["low_freq_factor"]) / (
  179. freq_config["high_freq_factor"] - freq_config["low_freq_factor"]
  180. )
  181. smoothed_inv_freq = (
  182. (1 - smooth_factor) * (inv_freq / freq_config["factor"]) + smooth_factor * inv_freq
  183. )
  184. is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
  185. inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
  186. inv_freq = inv_freq_llama
  187. # Generate position indices
  188. positions = torch.arange(context_length, dtype=dtype)
  189. # Compute the angles
  190. angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)
  191. # Expand angles to match the head_dim
  192. angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)
  193. # Precompute sine and cosine
  194. cos = torch.cos(angles)
  195. sin = torch.sin(angles)
  196. return cos, sin
  197. def apply_rope(x, cos, sin):
  198. # x: (batch_size, num_heads, seq_len, head_dim)
  199. batch_size, num_heads, seq_len, head_dim = x.shape
  200. assert head_dim % 2 == 0, "Head dimension must be even"
  201. # Split x into first half and second half
  202. x1 = x[..., : head_dim // 2] # First half
  203. x2 = x[..., head_dim // 2:] # Second half
  204. # Adjust sin and cos shapes
  205. cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
  206. sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
  207. # Apply the rotary transformation
  208. rotated = torch.cat((-x2, x1), dim=-1)
  209. x_rotated = (x * cos) + (rotated * sin)
  210. # It's ok to use lower-precision after applying cos and sin rotation
  211. return x_rotated.to(dtype=x.dtype)
  212. ##########################################
  213. # Tokenizer
  214. ##########################################
  215. class Llama3Tokenizer:
  216. """Thin wrapper around tiktoken that keeps track of Llama-3 special IDs."""
  217. def __init__(self, model_path):
  218. if not os.path.isfile(model_path):
  219. raise FileNotFoundError(model_path)
  220. mergeable = load_tiktoken_bpe(model_path)
  221. # hard-coded from Meta's tokenizer.json
  222. self.special = {
  223. "<|begin_of_text|>": 128000,
  224. "<|end_of_text|>": 128001,
  225. "<|start_header_id|>": 128006,
  226. "<|end_header_id|>": 128007,
  227. "<|eot_id|>": 128009,
  228. }
  229. self.special.update({f"<|reserved_{i}|>": 128002 + i
  230. for i in range(256)
  231. if 128002 + i not in self.special.values()})
  232. self.model = tiktoken.Encoding(
  233. name=Path(model_path).name,
  234. pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)"
  235. r"|[^\r\n\p{L}\p{N}]?\p{L}+"
  236. r"|\p{N}{1,3}"
  237. r"| ?[^\s\p{L}\p{N}]+[\r\n]*"
  238. r"|\s*[\r\n]+"
  239. r"|\s+(?!\S)"
  240. r"|\s+",
  241. mergeable_ranks=mergeable,
  242. special_tokens=self.special,
  243. )
  244. def encode(self, text, bos=False, eos=False):
  245. ids = ([self.special["<|begin_of_text|>"]] if bos else []) \
  246. + self.model.encode(text)
  247. if eos:
  248. ids.append(self.special["<|end_of_text|>"])
  249. return ids
  250. def decode(self, ids):
  251. return self.model.decode(ids)
  252. class ChatFormat:
  253. def __init__(self, tokenizer: Llama3Tokenizer, *,
  254. default_system="You are a helpful assistant."):
  255. self.tok = tokenizer
  256. self.default_system = default_system
  257. def _header(self, role):
  258. """Encode <|start_header_id|>role<|end_header_id|>\n\n"""
  259. return (
  260. [self.tok.special["<|start_header_id|>"]]
  261. + self.tok.encode(role)
  262. + [self.tok.special["<|end_header_id|>"]]
  263. + self.tok.encode("\n\n")
  264. )
  265. def encode(self, user_message, system_message=None, allowed_special=None):
  266. sys_msg = system_message if system_message is not None else self.default_system
  267. ids = [self.tok.special["<|begin_of_text|>"]]
  268. # system
  269. ids += self._header("system")
  270. ids += self.tok.encode(sys_msg, allowed_special=allowed_special)
  271. ids += [self.tok.special["<|eot_id|>"]]
  272. # user
  273. ids += self._header("user")
  274. ids += self.tok.encode(user_message)
  275. ids += [self.tok.special["<|eot_id|>"]]
  276. # assistant header (no content yet)
  277. ids += self._header("assistant")
  278. return ids
  279. def decode(self, ids):
  280. return self.tok.decode(ids)
  281. def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
  282. # Find the index of the first occurrence of "<|end_header_id|>"
  283. index = text.find(header_end)
  284. if index != -1:
  285. # Return the substring starting after "<|end_header_id|>"
  286. return text[index + len(header_end):].strip() # Strip removes leading/trailing whitespace
  287. else:
  288. # If the token is not found, return the original text
  289. return text
  290. ######################################################################
  291. # Llama 3 fast (alternative code geared towards efficiency)
  292. ######################################################################
  293. class GroupedQueryAttentionFast(nn.Module):
  294. """
  295. Drop-in replacement for GroupedQueryAttention but using PyTorch's
  296. scaled_dot_product_attention, which uses FlashAttention if run
  297. on an Ampere GPU (like A100) or newer and uses float16/bfloat16 or lower.
  298. """
  299. def __init__(self, d_in, d_out, num_heads, num_kv_groups, dtype=None):
  300. super().__init__()
  301. assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
  302. assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
  303. self.d_out = d_out
  304. self.num_heads = num_heads
  305. self.head_dim = d_out // num_heads
  306. self.num_kv_groups = num_kv_groups
  307. self.group_size = num_heads // num_kv_groups
  308. self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
  309. self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
  310. self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
  311. self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)
  312. def forward(self, x, cos, sin):
  313. b, num_tokens, _ = x.shape
  314. # Project to queries, keys, values
  315. q = self.W_query(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
  316. k = self.W_key(x).view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
  317. v = self.W_value(x).view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
  318. # Apply Rotary Positional Embedding
  319. q = apply_rope(q, cos, sin)
  320. k = apply_rope(k, cos, sin)
  321. # Expand key/value groups to full head count
  322. k = k.repeat_interleave(self.group_size, dim=1)
  323. v = v.repeat_interleave(self.group_size, dim=1)
  324. # Efficient scaled dot-product attention
  325. attn_output = torch.nn.functional.scaled_dot_product_attention(
  326. q, k, v,
  327. is_causal=True # Enables Flash/FlexAttention kernels
  328. )
  329. # Combine heads and project
  330. attn_output = attn_output.transpose(1, 2).reshape(b, num_tokens, self.d_out)
  331. return self.out_proj(attn_output)
  332. class TransformerBlockFast(nn.Module):
  333. """
  334. Same as original TransformerBlock but uses
  335. GroupedQueryAttentionFast instead of GroupedQueryAttention.
  336. """
  337. def __init__(self, cfg):
  338. super().__init__()
  339. self.att = GroupedQueryAttentionFast(
  340. d_in=cfg["emb_dim"],
  341. d_out=cfg["emb_dim"],
  342. num_heads=cfg["n_heads"],
  343. num_kv_groups=cfg["n_kv_groups"],
  344. dtype=cfg["dtype"]
  345. )
  346. self.ff = FeedForward(cfg)
  347. self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
  348. self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
  349. def forward(self, x, cos, sin):
  350. # Shortcut connection for attention block
  351. shortcut = x
  352. x = self.norm1(x)
  353. x = self.att(x, cos, sin) # Shape [batch_size, num_tokens, emb_size]
  354. x = x + shortcut # Add the original input back
  355. # Shortcut connection for feed-forward block
  356. shortcut = x
  357. x = self.norm2(x)
  358. x = self.ff(x)
  359. x = x + shortcut # Add the original input back
  360. return x
  361. class Llama3ModelFast(nn.Module):
  362. """
  363. Same as original Llama3Model but uses TransformerBlockFast
  364. instead of TransformerBlock, which in turn uses
  365. GroupedQueryAttentionFast instead of GroupedQueryAttention.
  366. """
  367. def __init__(self, cfg):
  368. super().__init__()
  369. # Main model parameters
  370. self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
  371. self.trf_blocks = nn.ModuleList( # ModuleList since Sequential can only accept one input, and we need `x, cos, sin`
  372. [TransformerBlockFast(cfg) for _ in range(cfg["n_layers"])]
  373. )
  374. self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
  375. self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
  376. cos, sin = compute_rope_params(
  377. head_dim=cfg["emb_dim"] // cfg["n_heads"],
  378. theta_base=cfg["rope_base"],
  379. context_length=cfg["context_length"],
  380. freq_config=cfg["rope_freq"]
  381. )
  382. self.register_buffer("cos", cos, persistent=False)
  383. self.register_buffer("sin", sin, persistent=False)
  384. self.cfg = cfg
  385. def forward(self, in_idx):
  386. tok_embeds = self.tok_emb(in_idx)
  387. x = tok_embeds
  388. for block in self.trf_blocks:
  389. x = block(x, self.cos, self.sin)
  390. x = self.final_norm(x)
  391. logits = self.out_head(x.to(self.cfg["dtype"]))
  392. return logits