# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

import os
import time
import urllib.request

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import tiktoken


#####################################
# Chapter 2
#####################################

class GPTDatasetV1(Dataset):
    def __init__(self, txt, tokenizer, max_length, stride):
        self.input_ids = []
        self.target_ids = []

        # Tokenize the entire text
        token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})

        # Use a sliding window to chunk the book into overlapping sequences of max_length
        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = token_ids[i:i + max_length]
            target_chunk = token_ids[i + 1: i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]


def create_dataloader_v1(txt, batch_size=4, max_length=256,
                         stride=128, shuffle=True, drop_last=True, num_workers=0):
    # Initialize the tokenizer
    tokenizer = tiktoken.get_encoding("gpt2")

    # Create dataset
    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)

    # Create dataloader
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)

    return dataloader
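
# Example usage (a minimal sketch; the toy text and sizes below are illustrative only):
#
#   sample_text = "In the beginning was the word, and the word was with the model."
#   loader = create_dataloader_v1(sample_text, batch_size=2, max_length=4, stride=4, shuffle=False)
#   inputs, targets = next(iter(loader))
#   # inputs.shape == targets.shape == torch.Size([2, 4]);
#   # targets are the inputs shifted one token to the right.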

#####################################
# Chapter 3
#####################################

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec
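
# Example usage (a minimal sketch with arbitrary sizes):
#
#   torch.manual_seed(123)
#   mha = MultiHeadAttention(d_in=768, d_out=768, context_length=1024, dropout=0.0, num_heads=12)
#   x = torch.randn(2, 6, 768)   # (batch, num_tokens, d_in)
#   out = mha(x)
#   # out.shape == torch.Size([2, 6, 768]); each position attends only to itself
#   # and earlier positions because of the causal mask.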

#####################################
# Chapter 4
#####################################

class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift
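
# Quick sanity check (a sketch; values are illustrative):
#
#   ln = LayerNorm(emb_dim=5)
#   x = torch.randn(2, 5)
#   out = ln(x)
#   # out.mean(dim=-1) is approximately 0 and out.var(dim=-1, unbiased=False) is
#   # approximately 1 per row, as long as scale/shift remain at their initial 1/0 values.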

class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))


class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            GELU(),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
        )

    def forward(self, x):
        return self.layers(x)


class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)   # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        return x


class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])

        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits
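
# Example forward pass (a minimal sketch; the small config below is hypothetical and
# only meant to keep the example cheap to run):
#
#   tiny_cfg = {"vocab_size": 50257, "context_length": 32, "emb_dim": 64,
#               "n_heads": 4, "n_layers": 2, "drop_rate": 0.0, "qkv_bias": False}
#   model = GPTModel(tiny_cfg)
#   batch = torch.randint(0, 50257, (2, 8))   # (batch_size, num_tokens)
#   logits = model(batch)
#   # logits.shape == torch.Size([2, 8, 50257]): one next-token distribution per position.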

def generate_text_simple(model, idx, max_new_tokens, context_size):
    # idx is (B, T) array of indices in the current context
    for _ in range(max_new_tokens):

        # Crop current context if it exceeds the supported context size
        # E.g., if LLM supports only 5 tokens, and the context size is 10
        # then only the last 5 tokens are used as context
        idx_cond = idx[:, -context_size:]

        # Get the predictions
        with torch.no_grad():
            logits = model(idx_cond)

        # Focus only on the last time step
        # (batch, n_tokens, vocab_size) becomes (batch, vocab_size)
        logits = logits[:, -1, :]

        # Get the idx of the vocab entry with the highest logits value
        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)

        # Append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)

    return idx
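
# Example usage (a sketch; assumes a trained or freshly initialized `model` is available):
#
#   tokenizer = tiktoken.get_encoding("gpt2")
#   start = torch.tensor(tokenizer.encode("Every effort moves you")).unsqueeze(0)
#   out_ids = generate_text_simple(model, start, max_new_tokens=10, context_size=1024)
#   print(tokenizer.decode(out_ids.squeeze(0).tolist()))
#
# Greedy decoding (argmax) is deterministic; temperature scaling and top-k sampling
# are discussed as alternatives in chapter 5 of the book.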

#####################################
# Chapter 5
#####################################

def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())


def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss


def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches
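
# Note on the shapes in calc_loss_batch (numbers are illustrative, using the
# 124M config and batch_size=8):
#
#   logits:        (batch_size, num_tokens, vocab_size), e.g. (8, 1024, 50257)
#   flatten(0, 1): (batch_size * num_tokens, vocab_size), e.g. (8192, 50257)
#   targets:       (batch_size * num_tokens,),             e.g. (8192,)
#
# cross_entropy then averages the per-token negative log-likelihood over all
# batch_size * num_tokens positions.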

def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss


def generate_and_print_sample(model, tokenizer, device, start_context):
    model.eval()
    context_size = model.pos_emb.weight.shape[0]
    encoded = text_to_token_ids(start_context, tokenizer).to(device)
    with torch.no_grad():
        token_ids = generate_text_simple(
            model=model, idx=encoded,
            max_new_tokens=50, context_size=context_size
        )
    decoded_text = token_ids_to_text(token_ids, tokenizer)
    print(decoded_text.replace("\n", " "))  # Compact print format
    model.train()

def train_model_simple_with_timing(model, train_loader, val_loader, optimizer, device,
                                   num_epochs, eval_freq, eval_iter, start_context, tokenizer):
    train_losses, val_losses, track_tokens = [], [], []
    total_tokens, global_step, last_tokens = 0, -1, 0

    # Variables for cumulative average tokens/sec
    cumulative_tokens, cumulative_time = 0.0, 0.0

    # CUDA-specific timing setup
    use_cuda = device.type == "cuda"
    if use_cuda:
        t_start = torch.cuda.Event(enable_timing=True)
        t_end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()  # Ensure all prior CUDA operations are done
        t_start.record()          # Start the timer for the first interval
    else:
        t0 = time.time()          # Start the timer for the first interval

    # Main training loop
    for epoch in range(num_epochs):
        model.train()
        for inp_batch, tgt_batch in train_loader:
            optimizer.zero_grad()
            global_step += 1

            # Forward and backward pass
            loss = calc_loss_batch(inp_batch, tgt_batch, model, device)
            loss.backward()
            optimizer.step()

            total_tokens += inp_batch.numel()

            # At evaluation intervals, measure elapsed time and tokens per second
            if global_step % eval_freq == 0:
                # End timing for the current interval
                if use_cuda:
                    t_end.record()
                    torch.cuda.synchronize()  # Wait for all CUDA ops to complete
                    elapsed = t_start.elapsed_time(t_end) / 1000  # Convert ms to seconds
                    t_start.record()  # Reset timer for the next interval
                else:
                    elapsed = time.time() - t0
                    t0 = time.time()  # Reset timer for the next interval

                # Calculate tokens processed in this interval
                tokens_interval = total_tokens - last_tokens
                last_tokens = total_tokens
                tps = tokens_interval / elapsed if elapsed > 0 else 0  # Tokens per second

                # Update cumulative counters (skip the first evaluation interval)
                if global_step:  # This is False only when global_step == 0 (first evaluation)
                    cumulative_tokens += tokens_interval
                    cumulative_time += elapsed

                # Compute cumulative average tokens/sec (excluding the first interval)
                avg_tps = cumulative_tokens / cumulative_time if cumulative_time > 0 else 0

                # Evaluate model performance (this may add overhead)
                train_loss, val_loss = evaluate_model(model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens.append(total_tokens)

                print(f"Ep {epoch+1}, Step {global_step:06d}, "
                      f"Train: {train_loss:.3f}, Val: {val_loss:.3f}, "
                      f"Step tok/sec: {round(tps)}, Avg tok/sec: {round(avg_tps)}")

        generate_and_print_sample(model, tokenizer, device, start_context)

        # Memory stats
        if torch.cuda.is_available():
            device = torch.cuda.current_device()

            allocated = torch.cuda.memory_allocated(device) / 1024**3  # Convert to GB
            reserved = torch.cuda.memory_reserved(device) / 1024**3    # Convert to GB

            print(f"\nAllocated memory: {allocated:.4f} GB")
            print(f"Reserved memory: {reserved:.4f} GB\n")

    return train_losses, val_losses, track_tokens
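
# The CUDA event timing pattern used above, in isolation (a sketch; only meaningful
# on a CUDA device):
#
#   start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
#   torch.cuda.synchronize()
#   start.record()
#   ...                        # GPU work to be timed
#   end.record()
#   torch.cuda.synchronize()   # elapsed_time is only valid after both events completed
#   seconds = start.elapsed_time(end) / 1000  # elapsed_time returns milliseconds
#
# Events are preferred over time.time() on the GPU because kernel launches are
# asynchronous; without synchronization the host-side clock would stop before the
# GPU had finished its work.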

def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
    fig, ax1 = plt.subplots()

    # Plot training and validation loss against epochs
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")

    # Create a second x-axis for tokens seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(tokens_seen, train_losses, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Tokens seen")

    fig.tight_layout()  # Adjust layout to make room
    # plt.show()

#####################################
# Main function calls
#####################################

def main(gpt_config, settings):

    torch.manual_seed(123)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"PyTorch version: {torch.__version__}")
    print(f"Using {device}")
    if torch.cuda.is_available():
        print(f"CUDA version: {torch.version.cuda}")
    print()

    ##############################
    # Download data if necessary
    ##############################

    file_path = "middlemarch.txt"
    url = "https://www.gutenberg.org/cache/epub/145/pg145.txt"

    if not os.path.exists(file_path):
        with urllib.request.urlopen(url) as response:
            text_data = response.read().decode('utf-8')
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)
    else:
        with open(file_path, "r", encoding="utf-8") as file:
            text_data = file.read()

    ##############################
    # Initialize model
    ##############################

    model = GPTModel(gpt_config)
    model.to(device)  # .to(device) works in place for nn.Module instances; no reassignment needed
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=settings["learning_rate"], weight_decay=settings["weight_decay"]
    )

    ##############################
    # Set up dataloaders
    ##############################

    # Train/validation ratio
    train_ratio = 0.90
    split_idx = int(train_ratio * len(text_data))

    train_loader = create_dataloader_v1(
        text_data[:split_idx],
        batch_size=settings["batch_size"],
        max_length=gpt_config["context_length"],
        stride=gpt_config["context_length"],
        drop_last=True,
        shuffle=True,
        num_workers=4
    )

    val_loader = create_dataloader_v1(
        text_data[split_idx:],
        batch_size=settings["batch_size"],
        max_length=gpt_config["context_length"],
        stride=gpt_config["context_length"],
        drop_last=False,
        shuffle=False,
        num_workers=4
    )

    ##############################
    # Train model
    ##############################

    tokenizer = tiktoken.get_encoding("gpt2")

    train_losses, val_losses, tokens_seen = train_model_simple_with_timing(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        device=device,
        num_epochs=settings["num_epochs"],
        eval_freq=15,
        eval_iter=1,
        start_context="Every effort moves you",
        tokenizer=tokenizer
    )

    return train_losses, val_losses, tokens_seen, model

if __name__ == "__main__":

    GPT_CONFIG_124M = {
        "vocab_size": 50257,     # Vocabulary size
        "context_length": 1024,  # Input tokens per training example
        "emb_dim": 768,          # Embedding dimension
        "n_heads": 12,           # Number of attention heads
        "n_layers": 12,          # Number of layers
        "drop_rate": 0.1,        # Dropout rate
        "qkv_bias": False        # Query-key-value bias
    }

    OTHER_SETTINGS = {
        "learning_rate": 5e-4,
        "num_epochs": 15,
        "batch_size": 8,
        "weight_decay": 0.1
    }

    ###########################
    # Initiate training
    ###########################

    train_losses, val_losses, tokens_seen, model = main(GPT_CONFIG_124M, OTHER_SETTINGS)

    ###########################
    # After training
    ###########################

    # Plot results
    epochs_tensor = torch.linspace(0, OTHER_SETTINGS["num_epochs"], len(train_losses))
    plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)
    plt.savefig("loss.pdf")

    # Save and load model
    # torch.save(model.state_dict(), "model.pth")
    # model = GPTModel(GPT_CONFIG_124M)
    # model.load_state_dict(torch.load("model.pth", weights_only=True))
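
    # A fuller checkpointing sketch (hypothetical filename "checkpoint.pth"; assumes the
    # optimizer is also in scope here, which would require main() to return it). Saving
    # the optimizer state lets training resume later with AdamW's moment estimates intact.
    #
    #   torch.save({"model_state_dict": model.state_dict(),
    #               "optimizer_state_dict": optimizer.state_dict()},
    #              "checkpoint.pth")
    #
    #   checkpoint = torch.load("checkpoint.pth", weights_only=True)
    #   model = GPTModel(GPT_CONFIG_124M)
    #   model.load_state_dict(checkpoint["model_state_dict"])
    #   optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.1)
    #   optimizer.load_state_dict(checkpoint["optimizer_state_dict"])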