# previous_chapters.py

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# This file collects all the relevant code that we covered thus far
# throughout Chapters 2-6.
import json
import os
import urllib.request

import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
from tqdm import tqdm

#####################################
# Chapter 3
#####################################


class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by n_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec
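

# Usage sketch (illustrative, not from the book's text): a minimal smoke test
# for MultiHeadAttention on a toy batch of random embeddings; the dimensions
# below are arbitrary example values.
def _demo_multihead_attention():
    torch.manual_seed(123)
    batch = torch.rand(2, 6, 16)  # (batch_size=2, num_tokens=6, d_in=16)
    mha = MultiHeadAttention(d_in=16, d_out=16, context_length=6, dropout=0.0, num_heads=2)
    context = mha(batch)
    print(context.shape)  # expected: torch.Size([2, 6, 16])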


#####################################
# Chapter 4
#####################################


class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift


class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))
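

# Sanity-check sketch (illustrative, not from the book's text): the GELU class
# above implements the tanh approximation, which PyTorch also provides as
# F.gelu(..., approximate="tanh") in PyTorch >= 1.12, so the two should agree
# up to floating-point tolerance.
def _check_gelu_against_torch():
    import torch.nn.functional as F
    x = torch.linspace(-3.0, 3.0, steps=7)
    print(torch.allclose(GELU()(x), F.gelu(x, approximate="tanh"), atol=1e-5))  # expected: True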


class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            GELU(),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
        )

    def forward(self, x):
        return self.layers(x)


class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)   # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        return x


class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])

        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits
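

# Usage sketch (illustrative, not from the book's text): a forward pass through
# a deliberately tiny GPTModel configuration to verify the output shape; the
# values below are toy settings, much smaller than GPT-2.
def _demo_gpt_model():
    demo_cfg = {
        "vocab_size": 100,
        "context_length": 16,
        "emb_dim": 32,
        "n_heads": 4,
        "n_layers": 2,
        "drop_rate": 0.0,
        "qkv_bias": False,
    }
    model = GPTModel(demo_cfg)
    token_ids = torch.randint(0, demo_cfg["vocab_size"], (2, 8))  # (batch_size=2, seq_len=8)
    logits = model(token_ids)
    print(logits.shape)  # expected: torch.Size([2, 8, 100])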


#####################################
# Chapter 5
#####################################


def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())
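

# Usage sketch (illustrative, not from the book's text): a round trip through
# the two helpers above, assuming the `tiktoken` BPE tokenizer used elsewhere
# in the book is installed (`pip install tiktoken`).
def _demo_tokenizer_round_trip():
    import tiktoken
    tokenizer = tiktoken.get_encoding("gpt2")
    token_ids = text_to_token_ids("Every effort moves you", tokenizer)
    print(token_ids.shape)                          # (1, num_tokens)
    print(token_ids_to_text(token_ids, tokenizer))  # "Every effort moves you"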


def download_and_load_gpt2(model_size, models_dir):
    # Validate model size
    allowed_sizes = ("124M", "355M", "774M", "1558M")
    if model_size not in allowed_sizes:
        raise ValueError(f"Model size not in {allowed_sizes}")

    # Define paths
    model_dir = os.path.join(models_dir, model_size)
    base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
    filenames = [
        "checkpoint", "encoder.json", "hparams.json",
        "model.ckpt.data-00000-of-00001", "model.ckpt.index",
        "model.ckpt.meta", "vocab.bpe"
    ]

    # Download files
    os.makedirs(model_dir, exist_ok=True)
    for filename in filenames:
        file_url = f"{base_url}/{model_size}/{filename}"  # join with "/" so the URL is valid on all platforms
        file_path = os.path.join(model_dir, filename)
        download_file(file_url, file_path)

    # Load settings and params
    tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
    with open(os.path.join(model_dir, "hparams.json"), "r", encoding="utf-8") as settings_file:
        settings = json.load(settings_file)
    params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)

    return settings, params


def download_file(url, destination):
    # Send a GET request to download the file
    with urllib.request.urlopen(url) as response:
        # Get the total file size from headers, defaulting to 0 if not present
        file_size = int(response.headers.get("Content-Length", 0))

        # Check if file exists and has the same size
        if os.path.exists(destination):
            file_size_local = os.path.getsize(destination)
            if file_size == file_size_local:
                print(f"File already exists and is up-to-date: {destination}")
                return

        # Define the block size for reading the file
        block_size = 1024  # 1 Kilobyte

        # Initialize the progress bar with total file size
        progress_bar_description = os.path.basename(url)  # Extract filename from URL
        with tqdm(total=file_size, unit="iB", unit_scale=True, desc=progress_bar_description) as progress_bar:
            # Open the destination file in binary write mode
            with open(destination, "wb") as file:
                # Read the file in chunks and write to destination
                while True:
                    chunk = response.read(block_size)
                    if not chunk:
                        break
                    file.write(chunk)
                    progress_bar.update(len(chunk))  # Update progress bar


def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
    # Initialize parameters dictionary with empty blocks for each layer
    params = {"blocks": [{} for _ in range(settings["n_layer"])]}

    # Iterate over each variable in the checkpoint
    for name, _ in tf.train.list_variables(ckpt_path):
        # Load the variable and remove singleton dimensions
        variable_array = np.squeeze(tf.train.load_variable(ckpt_path, name))

        # Process the variable name to extract relevant parts
        variable_name_parts = name.split("/")[1:]  # Skip the 'model/' prefix

        # Identify the target dictionary for the variable
        target_dict = params
        if variable_name_parts[0].startswith("h"):
            layer_number = int(variable_name_parts[0][1:])
            target_dict = params["blocks"][layer_number]

        # Recursively access or create nested dictionaries
        for key in variable_name_parts[1:-1]:
            target_dict = target_dict.setdefault(key, {})

        # Assign the variable array to the last key
        last_key = variable_name_parts[-1]
        target_dict[last_key] = variable_array

    return params


def assign(left, right):
    if left.shape != right.shape:
        raise ValueError(f"Shape mismatch. Left: {left.shape}, Right: {right.shape}")
    return torch.nn.Parameter(torch.tensor(right))


def load_weights_into_gpt(gpt, params):
    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])

    for b in range(len(params["blocks"])):
        q_w, k_w, v_w = np.split(
            (params["blocks"][b]["attn"]["c_attn"])["w"], 3, axis=-1)
        gpt.trf_blocks[b].att.W_query.weight = assign(
            gpt.trf_blocks[b].att.W_query.weight, q_w.T)
        gpt.trf_blocks[b].att.W_key.weight = assign(
            gpt.trf_blocks[b].att.W_key.weight, k_w.T)
        gpt.trf_blocks[b].att.W_value.weight = assign(
            gpt.trf_blocks[b].att.W_value.weight, v_w.T)

        q_b, k_b, v_b = np.split(
            (params["blocks"][b]["attn"]["c_attn"])["b"], 3, axis=-1)
        gpt.trf_blocks[b].att.W_query.bias = assign(
            gpt.trf_blocks[b].att.W_query.bias, q_b)
        gpt.trf_blocks[b].att.W_key.bias = assign(
            gpt.trf_blocks[b].att.W_key.bias, k_b)
        gpt.trf_blocks[b].att.W_value.bias = assign(
            gpt.trf_blocks[b].att.W_value.bias, v_b)

        gpt.trf_blocks[b].att.out_proj.weight = assign(
            gpt.trf_blocks[b].att.out_proj.weight,
            params["blocks"][b]["attn"]["c_proj"]["w"].T)
        gpt.trf_blocks[b].att.out_proj.bias = assign(
            gpt.trf_blocks[b].att.out_proj.bias,
            params["blocks"][b]["attn"]["c_proj"]["b"])

        gpt.trf_blocks[b].ff.layers[0].weight = assign(
            gpt.trf_blocks[b].ff.layers[0].weight,
            params["blocks"][b]["mlp"]["c_fc"]["w"].T)
        gpt.trf_blocks[b].ff.layers[0].bias = assign(
            gpt.trf_blocks[b].ff.layers[0].bias,
            params["blocks"][b]["mlp"]["c_fc"]["b"])
        gpt.trf_blocks[b].ff.layers[2].weight = assign(
            gpt.trf_blocks[b].ff.layers[2].weight,
            params["blocks"][b]["mlp"]["c_proj"]["w"].T)
        gpt.trf_blocks[b].ff.layers[2].bias = assign(
            gpt.trf_blocks[b].ff.layers[2].bias,
            params["blocks"][b]["mlp"]["c_proj"]["b"])

        gpt.trf_blocks[b].norm1.scale = assign(
            gpt.trf_blocks[b].norm1.scale,
            params["blocks"][b]["ln_1"]["g"])
        gpt.trf_blocks[b].norm1.shift = assign(
            gpt.trf_blocks[b].norm1.shift,
            params["blocks"][b]["ln_1"]["b"])
        gpt.trf_blocks[b].norm2.scale = assign(
            gpt.trf_blocks[b].norm2.scale,
            params["blocks"][b]["ln_2"]["g"])
        gpt.trf_blocks[b].norm2.shift = assign(
            gpt.trf_blocks[b].norm2.shift,
            params["blocks"][b]["ln_2"]["b"])

    gpt.final_norm.scale = assign(gpt.final_norm.scale, params["g"])
    gpt.final_norm.shift = assign(gpt.final_norm.shift, params["b"])
    gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])
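

# Usage sketch (illustrative, not from the book's text): download the GPT-2
# 124M checkpoint, build a GPTModel with matching hyperparameters, and copy
# the pretrained weights over. The configuration mirrors GPT-2 "small";
# drop_rate and the "gpt2" download directory are illustrative choices.
def _demo_load_pretrained_gpt2():
    settings, params = download_and_load_gpt2(model_size="124M", models_dir="gpt2")
    cfg = {
        "vocab_size": 50257,
        "context_length": 1024,
        "emb_dim": 768,
        "n_heads": 12,
        "n_layers": 12,
        "drop_rate": 0.0,
        "qkv_bias": True,  # the original GPT-2 checkpoint includes query/key/value biases
    }
    gpt = GPTModel(cfg)
    load_weights_into_gpt(gpt, params)
    gpt.eval()
    return gpt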


#####################################
# Chapter 6
#####################################


def classify_review(text, model, tokenizer, device, max_length=None, pad_token_id=50256):
    model.eval()

    # Prepare inputs to the model
    input_ids = tokenizer.encode(text)
    supported_context_length = model.pos_emb.weight.shape[0]

    # Fall back to the model's supported context length if max_length is not provided
    if max_length is None:
        max_length = supported_context_length

    # Truncate sequences if they are too long
    input_ids = input_ids[:min(max_length, supported_context_length)]

    # Pad the sequence to max_length
    input_ids += [pad_token_id] * (max_length - len(input_ids))
    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)  # add batch dimension

    # Model inference
    with torch.no_grad():
        logits = model(input_tensor.to(device))[:, -1, :]  # Logits of the last output token
    predicted_label = torch.argmax(logits, dim=-1).item()

    # Return the classified result
    return "spam" if predicted_label == 1 else "not spam"
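

# Usage sketch (illustrative, not from the book's text): classify_review
# expects a GPTModel whose out_head has been replaced by a two-class
# classification head and fine-tuned on the spam dataset (Chapter 6). The
# checkpoint filename and max_length below are assumptions for illustration.
def _demo_classify_review():
    import tiktoken
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = tiktoken.get_encoding("gpt2")

    model = _demo_load_pretrained_gpt2()
    model.out_head = torch.nn.Linear(in_features=768, out_features=2)  # 2 classes: not spam / spam
    # In practice, fine-tuned classifier weights would be loaded here, e.g.:
    # model.load_state_dict(torch.load("review_classifier.pth", map_location=device))
    model.to(device)

    text = "You are a winner! Claim your free prize now."
    print(classify_review(text, model, tokenizer, device, max_length=120))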