# previous_chapters.py
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# This file collects all the relevant code that we covered thus far
# throughout Chapters 2-6.
# This file can be run as a standalone script.

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import numpy as np
import tiktoken
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


#####################################
# Chapter 2
#####################################

class GPTDatasetV1(Dataset):
    def __init__(self, txt, tokenizer, max_length, stride):
        self.tokenizer = tokenizer
        self.input_ids = []
        self.target_ids = []

        # Tokenize the entire text
        token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})

        # Use a sliding window to chunk the book into overlapping sequences of max_length
        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = token_ids[i:i + max_length]
            target_chunk = token_ids[i + 1: i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]


def create_dataloader_v1(txt, batch_size=4, max_length=256,
                         stride=128, shuffle=True, drop_last=True, num_workers=0):
    # Initialize the tokenizer
    tokenizer = tiktoken.get_encoding("gpt2")

    # Create dataset
    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)

    # Create dataloader
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)

    return dataloader
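

# A minimal usage sketch for the Chapter 2 data pipeline (not part of the
# original chapter code); the toy text below is an illustrative placeholder.
def _demo_dataloader():
    raw_text = "In the beginning was the word. " * 20  # placeholder corpus
    loader = create_dataloader_v1(
        raw_text, batch_size=2, max_length=8, stride=4, shuffle=False)
    inputs, targets = next(iter(loader))
    # Inputs and targets both have shape (2, 8); targets are inputs shifted by one token
    print(inputs.shape, targets.shape)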


#####################################
# Chapter 3
#####################################

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by n_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec
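

# A minimal shape-check sketch for MultiHeadAttention (not part of the original
# chapter code; the tensor sizes below are illustrative only).
def _demo_multihead_attention():
    torch.manual_seed(123)
    x = torch.randn(2, 6, 16)  # (batch, num_tokens, d_in)
    mha = MultiHeadAttention(
        d_in=16, d_out=16, context_length=6, dropout=0.0, num_heads=4)
    context_vec = mha(x)
    print(context_vec.shape)  # torch.Size([2, 6, 16])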


#####################################
# Chapter 4
#####################################

class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift


class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))


class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            GELU(),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
        )

    def forward(self, x):
        return self.layers(x)
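

# A minimal sanity-check sketch for LayerNorm and FeedForward (not part of the
# original chapter code; the shapes and config values are illustrative only).
def _demo_layernorm_and_ffn():
    torch.manual_seed(123)
    x = torch.randn(2, 4, 8)  # (batch, num_tokens, emb_dim)
    out = LayerNorm(emb_dim=8)(x)
    # Per-token mean should be ~0 and (biased) variance ~1 after normalization
    print(out.mean(dim=-1).abs().max(), out.var(dim=-1, unbiased=False).mean())
    ffn = FeedForward({"emb_dim": 8})  # only the "emb_dim" key is needed here
    print(ffn(x).shape)  # torch.Size([2, 4, 8]) -- the shape is preserved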


class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_resid = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)   # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_resid(x)
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_resid(x)
        x = x + shortcut  # Add the original input back

        return x


class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])

        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits


def generate_text_simple(model, idx, max_new_tokens, context_size):
    # idx is (B, T) array of indices in the current context
    for _ in range(max_new_tokens):

        # Crop current context if it exceeds the supported context size
        # E.g., if LLM supports only 5 tokens, and the context size is 10
        # then only the last 5 tokens are used as context
        idx_cond = idx[:, -context_size:]

        # Get the predictions
        with torch.no_grad():
            logits = model(idx_cond)

        # Focus only on the last time step
        # (batch, n_token, vocab_size) becomes (batch, vocab_size)
        logits = logits[:, -1, :]

        # Get the idx of the vocab entry with the highest logits value
        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)

        # Append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)

    return idx
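

# A minimal end-to-end sketch: build a small, untrained GPTModel and generate a
# few tokens with greedy decoding (not part of the original chapter code; the
# config values below are illustrative placeholders, not the book's GPT-2 settings).
def _demo_generate_text_simple():
    torch.manual_seed(123)
    demo_cfg = {
        "vocab_size": 50257,    # GPT-2 BPE vocabulary size
        "context_length": 32,   # shortened context for the demo
        "emb_dim": 64,
        "n_heads": 4,
        "n_layers": 2,
        "drop_rate": 0.0,
        "qkv_bias": False,
    }
    model = GPTModel(demo_cfg)
    model.eval()
    tokenizer = tiktoken.get_encoding("gpt2")
    idx = torch.tensor(tokenizer.encode("Every effort moves you")).unsqueeze(0)
    out = generate_text_simple(
        model=model, idx=idx, max_new_tokens=5,
        context_size=demo_cfg["context_length"])
    print(tokenizer.decode(out.squeeze(0).tolist()))  # untrained model -> gibberish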


#####################################
# Chapter 5
#####################################

def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):

    # For-loop is the same as before: Get logits, and only focus on last time step
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]

        # New: Filter logits with top_k sampling
        if top_k is not None:
            # Keep only top_k values
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)

        # New: Apply temperature scaling
        if temperature > 0.0:
            logits = logits / temperature

            # Apply softmax to get probabilities
            probs = torch.softmax(logits, dim=-1)  # (batch_size, context_len)

            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

        # Otherwise same as before: get idx of the vocab entry with the highest logits value
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified
            break

        # Same as before: append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

    return idx
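

# A minimal sketch of the sampling-based generate() function (not part of the
# original chapter code); the temperature and top_k values are illustrative,
# and the untrained model produces random text.
def _demo_generate_sampling():
    torch.manual_seed(123)
    demo_cfg = {
        "vocab_size": 50257, "context_length": 32, "emb_dim": 64,
        "n_heads": 4, "n_layers": 2, "drop_rate": 0.0, "qkv_bias": False,
    }
    model = GPTModel(demo_cfg)
    model.eval()
    tokenizer = tiktoken.get_encoding("gpt2")
    idx = torch.tensor(tokenizer.encode("Every effort moves you")).unsqueeze(0)
    out = generate(
        model=model, idx=idx, max_new_tokens=10,
        context_size=demo_cfg["context_length"],
        temperature=1.4, top_k=25)  # higher temperature + top-k -> more diverse samples
    print(tokenizer.decode(out.squeeze(0).tolist()))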


def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                       eval_freq, eval_iter, start_context, tokenizer):
    # Initialize lists to track losses and tokens seen
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen, global_step = 0, -1

    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()  # Reset loss gradients from previous batch iteration
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()  # Calculate loss gradients
            optimizer.step()  # Update model weights using loss gradients
            tokens_seen += input_batch.numel()
            global_step += 1

            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # Print a sample text after each epoch
        generate_and_print_sample(
            model, tokenizer, device, start_context
        )

    return train_losses, val_losses, track_tokens_seen


def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss


def generate_and_print_sample(model, tokenizer, device, start_context):
    model.eval()
    context_size = model.pos_emb.weight.shape[0]
    encoded = text_to_token_ids(start_context, tokenizer).to(device)
    with torch.no_grad():
        token_ids = generate_text_simple(
            model=model, idx=encoded,
            max_new_tokens=50, context_size=context_size
        )
        decoded_text = token_ids_to_text(token_ids, tokenizer)
        print(decoded_text.replace("\n", " "))  # Compact print format
    model.train()


def assign(left, right):
    if left.shape != right.shape:
        raise ValueError(f"Shape mismatch. Left: {left.shape}, Right: {right.shape}")
    return torch.nn.Parameter(torch.tensor(right))


def load_weights_into_gpt(gpt, params):
    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])

    for b in range(len(params["blocks"])):
        q_w, k_w, v_w = np.split(
            (params["blocks"][b]["attn"]["c_attn"])["w"], 3, axis=-1)
        gpt.trf_blocks[b].att.W_query.weight = assign(
            gpt.trf_blocks[b].att.W_query.weight, q_w.T)
        gpt.trf_blocks[b].att.W_key.weight = assign(
            gpt.trf_blocks[b].att.W_key.weight, k_w.T)
        gpt.trf_blocks[b].att.W_value.weight = assign(
            gpt.trf_blocks[b].att.W_value.weight, v_w.T)

        q_b, k_b, v_b = np.split(
            (params["blocks"][b]["attn"]["c_attn"])["b"], 3, axis=-1)
        gpt.trf_blocks[b].att.W_query.bias = assign(
            gpt.trf_blocks[b].att.W_query.bias, q_b)
        gpt.trf_blocks[b].att.W_key.bias = assign(
            gpt.trf_blocks[b].att.W_key.bias, k_b)
        gpt.trf_blocks[b].att.W_value.bias = assign(
            gpt.trf_blocks[b].att.W_value.bias, v_b)

        gpt.trf_blocks[b].att.out_proj.weight = assign(
            gpt.trf_blocks[b].att.out_proj.weight,
            params["blocks"][b]["attn"]["c_proj"]["w"].T)
        gpt.trf_blocks[b].att.out_proj.bias = assign(
            gpt.trf_blocks[b].att.out_proj.bias,
            params["blocks"][b]["attn"]["c_proj"]["b"])

        gpt.trf_blocks[b].ff.layers[0].weight = assign(
            gpt.trf_blocks[b].ff.layers[0].weight,
            params["blocks"][b]["mlp"]["c_fc"]["w"].T)
        gpt.trf_blocks[b].ff.layers[0].bias = assign(
            gpt.trf_blocks[b].ff.layers[0].bias,
            params["blocks"][b]["mlp"]["c_fc"]["b"])
        gpt.trf_blocks[b].ff.layers[2].weight = assign(
            gpt.trf_blocks[b].ff.layers[2].weight,
            params["blocks"][b]["mlp"]["c_proj"]["w"].T)
        gpt.trf_blocks[b].ff.layers[2].bias = assign(
            gpt.trf_blocks[b].ff.layers[2].bias,
            params["blocks"][b]["mlp"]["c_proj"]["b"])

        gpt.trf_blocks[b].norm1.scale = assign(
            gpt.trf_blocks[b].norm1.scale,
            params["blocks"][b]["ln_1"]["g"])
        gpt.trf_blocks[b].norm1.shift = assign(
            gpt.trf_blocks[b].norm1.shift,
            params["blocks"][b]["ln_1"]["b"])
        gpt.trf_blocks[b].norm2.scale = assign(
            gpt.trf_blocks[b].norm2.scale,
            params["blocks"][b]["ln_2"]["g"])
        gpt.trf_blocks[b].norm2.shift = assign(
            gpt.trf_blocks[b].norm2.shift,
            params["blocks"][b]["ln_2"]["b"])

    gpt.final_norm.scale = assign(gpt.final_norm.scale, params["g"])
    gpt.final_norm.shift = assign(gpt.final_norm.shift, params["b"])
    gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])


def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())
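

# A minimal round-trip sketch for the tokenizer helpers (not part of the
# original chapter code; the example string is illustrative only).
def _demo_token_round_trip():
    tokenizer = tiktoken.get_encoding("gpt2")
    ids = text_to_token_ids("Every effort moves you", tokenizer)
    print(ids.shape)                          # (1, num_tokens)
    print(token_ids_to_text(ids, tokenizer))  # "Every effort moves you"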


def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss


def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Reduce the number of batches to match the total number of batches in the data loader
        # if num_batches exceeds the number of batches in the data loader
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches


def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
    fig, ax1 = plt.subplots(figsize=(5, 3))

    # Plot training and validation loss against epochs
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")
    ax1.xaxis.set_major_locator(MaxNLocator(integer=True))  # only show integer labels on x-axis

    # Create a second x-axis for tokens seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(tokens_seen, train_losses, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Tokens seen")

    fig.tight_layout()  # Adjust layout to make room
    plt.savefig("loss-plot.pdf")
    plt.show()
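

# A minimal training-run sketch that ties the pieces above together (not part of
# the original chapter code). The tiny corpus, model config, and hyperparameters
# are illustrative placeholders chosen so the demo finishes quickly on a CPU.
def _demo_training_run():
    torch.manual_seed(123)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    demo_cfg = {
        "vocab_size": 50257, "context_length": 32, "emb_dim": 64,
        "n_heads": 4, "n_layers": 2, "drop_rate": 0.1, "qkv_bias": False,
    }
    model = GPTModel(demo_cfg).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4, weight_decay=0.1)
    tokenizer = tiktoken.get_encoding("gpt2")

    text = "Every effort moves you forward, step by step. " * 50  # placeholder corpus
    split_idx = int(0.9 * len(text))
    train_loader = create_dataloader_v1(
        text[:split_idx], batch_size=2, max_length=demo_cfg["context_length"],
        stride=demo_cfg["context_length"], shuffle=True)
    val_loader = create_dataloader_v1(
        text[split_idx:], batch_size=2, max_length=demo_cfg["context_length"],
        stride=demo_cfg["context_length"], shuffle=False, drop_last=False)

    num_epochs = 2
    train_losses, val_losses, tokens_seen = train_model_simple(
        model, train_loader, val_loader, optimizer, device,
        num_epochs=num_epochs, eval_freq=5, eval_iter=1,
        start_context="Every effort moves you", tokenizer=tokenizer)
    epochs_seen = torch.linspace(0, num_epochs, len(train_losses))
    plot_losses(epochs_seen, tokens_seen, train_losses, val_losses)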