# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# This file collects all the relevant code that we covered thus far
# throughout Chapters 2-6.
# This file can be run as a standalone script.

import os
from pathlib import Path
import urllib.request
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tiktoken
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


#####################################
# Chapter 2
#####################################


class GPTDatasetV1(Dataset):
    def __init__(self, txt, tokenizer, max_length, stride):
        self.input_ids = []
        self.target_ids = []

        # Tokenize the entire text
        token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})

        # Use a sliding window to chunk the book into overlapping sequences of max_length
        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = token_ids[i:i + max_length]
            target_chunk = token_ids[i + 1: i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]


def create_dataloader_v1(txt, batch_size=4, max_length=256,
                         stride=128, shuffle=True, drop_last=True):
    # Initialize the tokenizer
    tokenizer = tiktoken.get_encoding("gpt2")

    # Create dataset
    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)

    # Create dataloader
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)

    return dataloader
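

# Example usage of `create_dataloader_v1` (a minimal sketch; the sample text,
# batch size, and sequence length below are illustrative assumptions, not
# values prescribed by the book):
#
# raw_text = "In the heart of the city stood the old library, a relic from a bygone era."
# dataloader = create_dataloader_v1(raw_text, batch_size=2, max_length=4, stride=4, shuffle=False)
# inputs, targets = next(iter(dataloader))
# print(inputs.shape)   # torch.Size([2, 4]); `targets` is `inputs` shifted by one token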


#####################################
# Chapter 3
#####################################


class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec
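

# Example usage of `MultiHeadAttention` (a minimal sketch; the dimensions below
# are arbitrary assumptions chosen only to illustrate the tensor shapes):
#
# torch.manual_seed(123)
# x = torch.rand(2, 6, 768)   # (batch_size, num_tokens, d_in)
# mha = MultiHeadAttention(d_in=768, d_out=768, context_length=1024,
#                          dropout=0.0, num_heads=12)
# context_vec = mha(x)
# print(context_vec.shape)    # torch.Size([2, 6, 768])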


#####################################
# Chapter 4
#####################################


class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift


class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))
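

# Sanity check (a sketch, not from the book): the class above implements the tanh
# approximation of GELU, so its output should closely match PyTorch's built-in
# `torch.nn.functional.gelu(..., approximate="tanh")` in recent PyTorch versions:
#
# x = torch.linspace(-3, 3, steps=7)
# print(GELU()(x))
# print(torch.nn.functional.gelu(x, approximate="tanh"))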


class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            GELU(),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
        )

    def forward(self, x):
        return self.layers(x)


class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_resid = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)   # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_resid(x)
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_resid(x)
        x = x + shortcut  # Add the original input back

        return x


class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])

        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits


def generate_text_simple(model, idx, max_new_tokens, context_size):
    # idx is (B, T) array of indices in the current context
    for _ in range(max_new_tokens):

        # Crop current context if it exceeds the supported context size
        # E.g., if LLM supports only 5 tokens, and the context size is 10
        # then only the last 5 tokens are used as context
        idx_cond = idx[:, -context_size:]

        # Get the predictions
        with torch.no_grad():
            logits = model(idx_cond)

        # Focus only on the last time step
        # (batch, n_token, vocab_size) becomes (batch, vocab_size)
        logits = logits[:, -1, :]

        # Get the idx of the vocab entry with the highest logits value
        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)

        # Append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)

    return idx
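

# Example usage of `GPTModel` and `generate_text_simple` (a sketch; the config
# below mirrors the 124M-parameter GPT-2 settings used in the book, but treat the
# exact values and the sample prompt as assumptions):
#
# GPT_CONFIG_124M = {
#     "vocab_size": 50257, "context_length": 1024, "emb_dim": 768,
#     "n_heads": 12, "n_layers": 12, "drop_rate": 0.1, "qkv_bias": False
# }
# torch.manual_seed(123)
# model = GPTModel(GPT_CONFIG_124M)
# model.eval()
# tokenizer = tiktoken.get_encoding("gpt2")
# idx = torch.tensor(tokenizer.encode("Every effort moves you")).unsqueeze(0)
# out = generate_text_simple(model, idx, max_new_tokens=10,
#                            context_size=GPT_CONFIG_124M["context_length"])
# print(tokenizer.decode(out.squeeze(0).tolist()))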


#####################################
# Chapter 5
#####################################


def assign(left, right):
    if left.shape != right.shape:
        raise ValueError(f"Shape mismatch. Left: {left.shape}, Right: {right.shape}")
    return torch.nn.Parameter(torch.tensor(right))


def load_weights_into_gpt(gpt, params):
    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])

    for b in range(len(params["blocks"])):
        q_w, k_w, v_w = np.split(
            (params["blocks"][b]["attn"]["c_attn"])["w"], 3, axis=-1)
        gpt.trf_blocks[b].att.W_query.weight = assign(
            gpt.trf_blocks[b].att.W_query.weight, q_w.T)
        gpt.trf_blocks[b].att.W_key.weight = assign(
            gpt.trf_blocks[b].att.W_key.weight, k_w.T)
        gpt.trf_blocks[b].att.W_value.weight = assign(
            gpt.trf_blocks[b].att.W_value.weight, v_w.T)

        q_b, k_b, v_b = np.split(
            (params["blocks"][b]["attn"]["c_attn"])["b"], 3, axis=-1)
        gpt.trf_blocks[b].att.W_query.bias = assign(
            gpt.trf_blocks[b].att.W_query.bias, q_b)
        gpt.trf_blocks[b].att.W_key.bias = assign(
            gpt.trf_blocks[b].att.W_key.bias, k_b)
        gpt.trf_blocks[b].att.W_value.bias = assign(
            gpt.trf_blocks[b].att.W_value.bias, v_b)

        gpt.trf_blocks[b].att.out_proj.weight = assign(
            gpt.trf_blocks[b].att.out_proj.weight,
            params["blocks"][b]["attn"]["c_proj"]["w"].T)
        gpt.trf_blocks[b].att.out_proj.bias = assign(
            gpt.trf_blocks[b].att.out_proj.bias,
            params["blocks"][b]["attn"]["c_proj"]["b"])

        gpt.trf_blocks[b].ff.layers[0].weight = assign(
            gpt.trf_blocks[b].ff.layers[0].weight,
            params["blocks"][b]["mlp"]["c_fc"]["w"].T)
        gpt.trf_blocks[b].ff.layers[0].bias = assign(
            gpt.trf_blocks[b].ff.layers[0].bias,
            params["blocks"][b]["mlp"]["c_fc"]["b"])
        gpt.trf_blocks[b].ff.layers[2].weight = assign(
            gpt.trf_blocks[b].ff.layers[2].weight,
            params["blocks"][b]["mlp"]["c_proj"]["w"].T)
        gpt.trf_blocks[b].ff.layers[2].bias = assign(
            gpt.trf_blocks[b].ff.layers[2].bias,
            params["blocks"][b]["mlp"]["c_proj"]["b"])

        gpt.trf_blocks[b].norm1.scale = assign(
            gpt.trf_blocks[b].norm1.scale,
            params["blocks"][b]["ln_1"]["g"])
        gpt.trf_blocks[b].norm1.shift = assign(
            gpt.trf_blocks[b].norm1.shift,
            params["blocks"][b]["ln_1"]["b"])
        gpt.trf_blocks[b].norm2.scale = assign(
            gpt.trf_blocks[b].norm2.scale,
            params["blocks"][b]["ln_2"]["g"])
        gpt.trf_blocks[b].norm2.shift = assign(
            gpt.trf_blocks[b].norm2.shift,
            params["blocks"][b]["ln_2"]["b"])

    gpt.final_norm.scale = assign(gpt.final_norm.scale, params["g"])
    gpt.final_norm.shift = assign(gpt.final_norm.shift, params["b"])
    gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])
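

# Example usage of `load_weights_into_gpt` (a sketch; it assumes the book's
# companion `gpt_download.py` module, whose `download_and_load_gpt2` helper is
# assumed here to return the OpenAI GPT-2 weights as a settings/params pair,
# plus the `GPT_CONFIG_124M` dict from the earlier example):
#
# from gpt_download import download_and_load_gpt2
# settings, params = download_and_load_gpt2(model_size="124M", models_dir="gpt2")
# model = GPTModel({**GPT_CONFIG_124M, "qkv_bias": True})  # pretrained GPT-2 uses query/key/value biases
# load_weights_into_gpt(model, params)
# model.eval()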


def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())
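

# Example round trip (a sketch; the sample string is an arbitrary assumption):
#
# tokenizer = tiktoken.get_encoding("gpt2")
# ids = text_to_token_ids("Every effort moves you", tokenizer)   # shape (1, num_tokens)
# print(token_ids_to_text(ids, tokenizer))                       # "Every effort moves you"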


def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Reduce the number of batches to match the total number of batches in the data loader
        # if num_batches exceeds the number of batches in the data loader
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches


def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss
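

# Example usage of `evaluate_model` (a sketch; it assumes `model`, `train_loader`,
# and `val_loader` objects like the ones constructed elsewhere in these chapters):
#
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)
# train_loss, val_loss = evaluate_model(model, train_loader, val_loader, device, eval_iter=5)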


#####################################
# Chapter 6
#####################################


def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path):
    if data_file_path.exists():
        print(f"{data_file_path} already exists. Skipping download and extraction.")
        return

    # Downloading the file
    with urllib.request.urlopen(url) as response:
        with open(zip_path, "wb") as out_file:
            out_file.write(response.read())

    # Unzipping the file
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)

    # Add .tsv file extension
    original_file_path = Path(extracted_path) / "SMSSpamCollection"
    os.rename(original_file_path, data_file_path)
    print(f"File downloaded and saved as {data_file_path}")
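

# Example usage of `download_and_unzip_spam_data` (a sketch; the URL is assumed to
# point to the UCI SMS Spam Collection archive used in chapter 6 -- verify it before
# relying on it, and the local paths are illustrative choices):
#
# url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
# zip_path = Path("sms_spam_collection.zip")
# extracted_path = Path("sms_spam_collection")
# data_file_path = extracted_path / "SMSSpamCollection.tsv"
# download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)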


def create_balanced_dataset(df):
    # Count the instances of "spam"
    num_spam = df[df["Label"] == "spam"].shape[0]

    # Randomly sample "ham" instances to match the number of "spam" instances
    ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)

    # Combine the "ham" subset with all "spam" instances
    balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])
    return balanced_df


def random_split(df, train_frac, validation_frac):
    # Shuffle the entire DataFrame
    df = df.sample(frac=1, random_state=123).reset_index(drop=True)

    # Calculate split indices
    train_end = int(len(df) * train_frac)
    validation_end = train_end + int(len(df) * validation_frac)

    # Split the DataFrame
    train_df = df[:train_end]
    validation_df = df[train_end:validation_end]
    test_df = df[validation_end:]

    return train_df, validation_df, test_df
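

# Example usage (a sketch; the `Label`/`Text` column names match what the functions
# above expect, while the split fractions and output file names are illustrative
# assumptions):
#
# df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
# balanced_df = create_balanced_dataset(df)
# balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})  # encode labels as integers
# train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)
# train_df.to_csv("train.csv", index=None)
# validation_df.to_csv("validation.csv", index=None)
# test_df.to_csv("test.csv", index=None)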


class SpamDataset(Dataset):
    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
        self.data = pd.read_csv(csv_file)

        # Pre-tokenize texts
        self.encoded_texts = [
            tokenizer.encode(text) for text in self.data["Text"]
        ]

        if max_length is None:
            self.max_length = self._longest_encoded_length()
        else:
            self.max_length = max_length
            # Truncate sequences if they are longer than max_length
            self.encoded_texts = [
                encoded_text[:self.max_length]
                for encoded_text in self.encoded_texts
            ]

        # Pad sequences to the longest sequence
        self.encoded_texts = [
            encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
            for encoded_text in self.encoded_texts
        ]

    def __getitem__(self, index):
        encoded = self.encoded_texts[index]
        label = self.data.iloc[index]["Label"]
        return torch.tensor(encoded, dtype=torch.long), torch.tensor(label, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def _longest_encoded_length(self):
        max_length = 0
        for encoded_text in self.encoded_texts:
            encoded_length = len(encoded_text)
            if encoded_length > max_length:
                max_length = encoded_length
        return max_length
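

# Example usage of `SpamDataset` (a sketch; the CSV file names and batch size are
# assumptions matching the split example above):
#
# tokenizer = tiktoken.get_encoding("gpt2")
# train_dataset = SpamDataset("train.csv", tokenizer)
# val_dataset = SpamDataset("validation.csv", tokenizer, max_length=train_dataset.max_length)
# train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=True)
# val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, drop_last=False)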


@torch.no_grad()  # Disable gradient tracking for efficiency
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
    model.eval()
    correct_predictions, num_examples = 0, 0

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            input_batch, target_batch = input_batch.to(device), target_batch.to(device)
            logits = model(input_batch)[:, -1, :]  # Logits of last output token
            predicted_labels = torch.argmax(logits, dim=-1)

            num_examples += predicted_labels.shape[0]
            correct_predictions += (predicted_labels == target_batch).sum().item()
        else:
            break
    return correct_predictions / num_examples


def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)[:, -1, :]  # Logits of last output token
    loss = torch.nn.functional.cross_entropy(logits, target_batch)
    return loss


# Overall the same as `train_model_simple` in chapter 5
def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                            eval_freq, eval_iter, tokenizer):
    # Initialize lists to track losses and examples seen
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    examples_seen, global_step = 0, -1

    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()  # Reset loss gradients from previous batch iteration
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()  # Calculate loss gradients
            optimizer.step()  # Update model weights using loss gradients
            examples_seen += input_batch.shape[0]  # New: track examples instead of tokens
            global_step += 1

            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # Calculate accuracy after each epoch
        train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
        val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
        print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
        print(f"Validation accuracy: {val_accuracy*100:.2f}%")
        train_accs.append(train_accuracy)
        val_accs.append(val_accuracy)

    return train_losses, val_losses, train_accs, val_accs, examples_seen
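

# Example usage of `train_classifier_simple` (a sketch; the optimizer settings and
# epoch count are illustrative assumptions, and `model`, `device`, `tokenizer`, and
# the data loaders are assumed to come from the earlier examples, with the model
# adapted with a classification head as in chapter 6):
#
# optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
# train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
#     model, train_loader, val_loader, optimizer, device,
#     num_epochs=5, eval_freq=50, eval_iter=5, tokenizer=tokenizer)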


def plot_values(epochs_seen, examples_seen, train_values, val_values, label="loss"):
    fig, ax1 = plt.subplots(figsize=(5, 3))

    # Plot training and validation values against epochs
    ax1.plot(epochs_seen, train_values, label=f"Training {label}")
    ax1.plot(epochs_seen, val_values, linestyle="-.", label=f"Validation {label}")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel(label.capitalize())
    ax1.legend()

    # Create a second x-axis for examples seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(examples_seen, train_values, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Examples seen")

    fig.tight_layout()  # Adjust layout to make room
    plt.savefig(f"{label}-plot.pdf")
    plt.show()
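

# Example usage of `plot_values` (a sketch; it assumes the lists returned by
# `train_classifier_simple` above and simply spaces the recorded values evenly
# over the epochs and examples axes):
#
# epochs_tensor = torch.linspace(0, 5, len(train_losses))
# examples_seen_tensor = torch.linspace(0, examples_seen, len(train_losses))
# plot_values(epochs_tensor, examples_seen_tensor, train_losses, val_losses, label="loss")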