# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# This file collects all the relevant code that we covered thus far
# throughout Chapters 2-5.

import json
import os
import urllib.request

import numpy as np
import tensorflow as tf
import tiktoken
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


#####################################
# Chapter 2
#####################################


class GPTDatasetV1(Dataset):
    def __init__(self, txt, tokenizer, max_length, stride):
        self.input_ids = []
        self.target_ids = []

        # Tokenize the entire text
        token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})

        # Use a sliding window to chunk the book into overlapping sequences of max_length
        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = token_ids[i:i + max_length]
            target_chunk = token_ids[i + 1: i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]


def create_dataloader_v1(txt, batch_size=4, max_length=256,
                         stride=128, shuffle=True, drop_last=True, num_workers=0):
    # Initialize the tokenizer
    tokenizer = tiktoken.get_encoding("gpt2")

    # Create dataset
    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)

    # Create dataloader
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)

    return dataloader
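

# Example usage (illustrative sketch, not part of the original listing; assumes
# `raw_text` holds the training corpus as a plain Python string):
#
# dataloader = create_dataloader_v1(raw_text, batch_size=8, max_length=4, stride=4)
# inputs, targets = next(iter(dataloader))
# # inputs and targets each have shape (8, 4); targets are the inputs shifted by one token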


#####################################
# Chapter 3
#####################################


class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec
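

# Example usage (illustrative sketch, not part of the original listing):
#
# mha = MultiHeadAttention(d_in=768, d_out=768, context_length=1024, dropout=0.1, num_heads=12)
# x = torch.randn(2, 6, 768)   # (batch_size, num_tokens, d_in)
# context = mha(x)             # (2, 6, 768), i.e. (batch_size, num_tokens, d_out)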


#####################################
# Chapter 4
#####################################


class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift


class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))


class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            GELU(),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
        )

    def forward(self, x):
        return self.layers(x)


class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)   # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        return x


class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])

        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits
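

# Example usage (illustrative sketch, not part of the original listing; the config
# values mirror the GPT-2 124M settings used in the book):
#
# GPT_CONFIG_124M = {
#     "vocab_size": 50257,     # Vocabulary size
#     "context_length": 1024,  # Context length
#     "emb_dim": 768,          # Embedding dimension
#     "n_heads": 12,           # Number of attention heads
#     "n_layers": 12,          # Number of layers
#     "drop_rate": 0.1,        # Dropout rate
#     "qkv_bias": False        # Query-Key-Value bias
# }
#
# model = GPTModel(GPT_CONFIG_124M)
# logits = model(torch.randint(0, 50257, (2, 8)))  # shape (2, 8, 50257)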


def generate_text_simple(model, idx, max_new_tokens, context_size):
    # idx is (B, T) array of indices in the current context
    for _ in range(max_new_tokens):

        # Crop current context if it exceeds the supported context size
        # E.g., if LLM supports only 5 tokens, and the context size is 10
        # then only the last 5 tokens are used as context
        idx_cond = idx[:, -context_size:]

        # Get the predictions
        with torch.no_grad():
            logits = model(idx_cond)

        # Focus only on the last time step
        # (batch, n_token, vocab_size) becomes (batch, vocab_size)
        logits = logits[:, -1, :]

        # Get the idx of the vocab entry with the highest logits value
        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)

        # Append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)

    return idx
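

# Example usage (illustrative sketch, not part of the original listing; assumes
# `model` is a GPTModel instance as sketched above):
#
# tokenizer = tiktoken.get_encoding("gpt2")
# start = torch.tensor(tokenizer.encode("Every effort moves you")).unsqueeze(0)
# model.eval()  # disable dropout for inference
# out = generate_text_simple(model, start, max_new_tokens=10, context_size=1024)
# print(tokenizer.decode(out.squeeze(0).tolist()))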


#####################################
# Chapter 5
#####################################


def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())
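

# Example usage (illustrative sketch, not part of the original listing):
#
# tokenizer = tiktoken.get_encoding("gpt2")
# ids = text_to_token_ids("Every effort moves you", tokenizer)  # shape (1, num_tokens)
# text = token_ids_to_text(ids, tokenizer)                      # round-trips to the input string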


def download_and_load_gpt2(model_size, models_dir):
    # Validate model size
    allowed_sizes = ("124M", "355M", "774M", "1558M")
    if model_size not in allowed_sizes:
        raise ValueError(f"Model size not in {allowed_sizes}")

    # Define paths
    model_dir = os.path.join(models_dir, model_size)
    base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
    filenames = [
        "checkpoint", "encoder.json", "hparams.json",
        "model.ckpt.data-00000-of-00001", "model.ckpt.index",
        "model.ckpt.meta", "vocab.bpe"
    ]

    # Download files
    os.makedirs(model_dir, exist_ok=True)
    for filename in filenames:
        file_url = os.path.join(base_url, model_size, filename)
        file_path = os.path.join(model_dir, filename)
        download_file(file_url, file_path)

    # Load settings and params
    tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
    settings = json.load(open(os.path.join(model_dir, "hparams.json")))
    params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)

    return settings, params
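

# Example usage (illustrative sketch, not part of the original listing; downloads
# the OpenAI GPT-2 checkpoint files into a local "gpt2" directory):
#
# settings, params = download_and_load_gpt2(model_size="124M", models_dir="gpt2")
# # settings holds the hparams.json fields, e.g. settings["n_layer"]
# # params is a nested dict of NumPy arrays (see load_gpt2_params_from_tf_ckpt below)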


def download_file(url, destination):
    # Send a GET request to download the file
    with urllib.request.urlopen(url) as response:
        # Get the total file size from headers, defaulting to 0 if not present
        file_size = int(response.headers.get("Content-Length", 0))

        # Check if file exists and has the same size
        if os.path.exists(destination):
            file_size_local = os.path.getsize(destination)
            if file_size == file_size_local:
                print(f"File already exists and is up-to-date: {destination}")
                return

        # Define the block size for reading the file
        block_size = 1024  # 1 Kilobyte

        # Initialize the progress bar with total file size
        progress_bar_description = os.path.basename(url)  # Extract filename from URL
        with tqdm(total=file_size, unit="iB", unit_scale=True, desc=progress_bar_description) as progress_bar:
            # Open the destination file in binary write mode
            with open(destination, "wb") as file:
                # Read the file in chunks and write to destination
                while True:
                    chunk = response.read(block_size)
                    if not chunk:
                        break
                    file.write(chunk)
                    progress_bar.update(len(chunk))  # Update progress bar


def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
    # Initialize parameters dictionary with empty blocks for each layer
    params = {"blocks": [{} for _ in range(settings["n_layer"])]}

    # Iterate over each variable in the checkpoint
    for name, _ in tf.train.list_variables(ckpt_path):
        # Load the variable and remove singleton dimensions
        variable_array = np.squeeze(tf.train.load_variable(ckpt_path, name))

        # Process the variable name to extract relevant parts
        variable_name_parts = name.split("/")[1:]  # Skip the 'model/' prefix

        # Identify the target dictionary for the variable
        target_dict = params
        if variable_name_parts[0].startswith("h"):
            layer_number = int(variable_name_parts[0][1:])
            target_dict = params["blocks"][layer_number]

        # Recursively access or create nested dictionaries
        for key in variable_name_parts[1:-1]:
            target_dict = target_dict.setdefault(key, {})

        # Assign the variable array to the last key
        last_key = variable_name_parts[-1]
        target_dict[last_key] = variable_array

    return params
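

# Resulting structure (illustrative sketch, inferred from how load_weights_into_gpt
# below indexes the dictionary):
#
# params["wte"], params["wpe"]                # token and positional embedding matrices
# params["g"], params["b"]                    # final layer-norm scale and shift
# params["blocks"][0]["attn"]["c_attn"]["w"]  # fused query/key/value weights of block 0
# params["blocks"][0]["mlp"]["c_fc"]["w"]     # first feed-forward weight of block 0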


def assign(left, right):
    if left.shape != right.shape:
        raise ValueError(f"Shape mismatch. Left: {left.shape}, Right: {right.shape}")
    return torch.nn.Parameter(torch.tensor(right))


def load_weights_into_gpt(gpt, params):
    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])

    for b in range(len(params["blocks"])):
        q_w, k_w, v_w = np.split(
            (params["blocks"][b]["attn"]["c_attn"])["w"], 3, axis=-1)
        gpt.trf_blocks[b].att.W_query.weight = assign(
            gpt.trf_blocks[b].att.W_query.weight, q_w.T)
        gpt.trf_blocks[b].att.W_key.weight = assign(
            gpt.trf_blocks[b].att.W_key.weight, k_w.T)
        gpt.trf_blocks[b].att.W_value.weight = assign(
            gpt.trf_blocks[b].att.W_value.weight, v_w.T)

        q_b, k_b, v_b = np.split(
            (params["blocks"][b]["attn"]["c_attn"])["b"], 3, axis=-1)
        gpt.trf_blocks[b].att.W_query.bias = assign(
            gpt.trf_blocks[b].att.W_query.bias, q_b)
        gpt.trf_blocks[b].att.W_key.bias = assign(
            gpt.trf_blocks[b].att.W_key.bias, k_b)
        gpt.trf_blocks[b].att.W_value.bias = assign(
            gpt.trf_blocks[b].att.W_value.bias, v_b)

        gpt.trf_blocks[b].att.out_proj.weight = assign(
            gpt.trf_blocks[b].att.out_proj.weight,
            params["blocks"][b]["attn"]["c_proj"]["w"].T)
        gpt.trf_blocks[b].att.out_proj.bias = assign(
            gpt.trf_blocks[b].att.out_proj.bias,
            params["blocks"][b]["attn"]["c_proj"]["b"])

        gpt.trf_blocks[b].ff.layers[0].weight = assign(
            gpt.trf_blocks[b].ff.layers[0].weight,
            params["blocks"][b]["mlp"]["c_fc"]["w"].T)
        gpt.trf_blocks[b].ff.layers[0].bias = assign(
            gpt.trf_blocks[b].ff.layers[0].bias,
            params["blocks"][b]["mlp"]["c_fc"]["b"])
        gpt.trf_blocks[b].ff.layers[2].weight = assign(
            gpt.trf_blocks[b].ff.layers[2].weight,
            params["blocks"][b]["mlp"]["c_proj"]["w"].T)
        gpt.trf_blocks[b].ff.layers[2].bias = assign(
            gpt.trf_blocks[b].ff.layers[2].bias,
            params["blocks"][b]["mlp"]["c_proj"]["b"])

        gpt.trf_blocks[b].norm1.scale = assign(
            gpt.trf_blocks[b].norm1.scale,
            params["blocks"][b]["ln_1"]["g"])
        gpt.trf_blocks[b].norm1.shift = assign(
            gpt.trf_blocks[b].norm1.shift,
            params["blocks"][b]["ln_1"]["b"])
        gpt.trf_blocks[b].norm2.scale = assign(
            gpt.trf_blocks[b].norm2.scale,
            params["blocks"][b]["ln_2"]["g"])
        gpt.trf_blocks[b].norm2.shift = assign(
            gpt.trf_blocks[b].norm2.shift,
            params["blocks"][b]["ln_2"]["b"])

    gpt.final_norm.scale = assign(gpt.final_norm.scale, params["g"])
    gpt.final_norm.shift = assign(gpt.final_norm.shift, params["b"])
    gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])
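

# Example usage (illustrative sketch, not part of the original listing; assumes the
# GPT_CONFIG_124M dict sketched above and the download helper defined earlier):
#
# settings, params = download_and_load_gpt2(model_size="124M", models_dir="gpt2")
# cfg = dict(GPT_CONFIG_124M, qkv_bias=True)  # the pretrained GPT-2 weights include query/key/value biases
# gpt = GPTModel(cfg)
# load_weights_into_gpt(gpt, params)
# gpt.eval()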


def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):
    # For-loop is the same as before: Get logits, and only focus on last time step
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]

        # New: Filter logits with top_k sampling
        if top_k is not None:
            # Keep only top_k values
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)

        # New: Apply temperature scaling
        if temperature > 0.0:
            logits = logits / temperature

            # Apply softmax to get probabilities
            probs = torch.softmax(logits, dim=-1)  # (batch_size, vocab_size)

            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

        # Otherwise same as before: get idx of the vocab entry with the highest logits value
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified
            break

        # Same as before: append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

    return idx
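

# Example usage (illustrative sketch, not part of the original listing; assumes the
# pretrained `gpt` model and a tiktoken `tokenizer` from the sketches above):
#
# torch.manual_seed(123)
# out = generate(
#     model=gpt,
#     idx=text_to_token_ids("Every effort moves you", tokenizer),
#     max_new_tokens=15,
#     context_size=1024,
#     top_k=25,
#     temperature=1.4
# )
# print(token_ids_to_text(out, tokenizer))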