01_opt_single_gpu.py

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

import os
import time
import urllib.request

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import tiktoken


#####################################
# Chapter 2
#####################################
class GPTDatasetV1(Dataset):
    def __init__(self, txt, tokenizer, max_length, stride):
        self.input_ids = []
        self.target_ids = []

        # Tokenize the entire text
        token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})

        # Use a sliding window to chunk the book into overlapping sequences of max_length
        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = token_ids[i:i + max_length]
            target_chunk = token_ids[i + 1: i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]


def create_dataloader_v1(txt, batch_size=4, max_length=256,
                         stride=128, shuffle=True, drop_last=True, num_workers=0):
    # Initialize the tokenizer
    tokenizer = tiktoken.get_encoding("gpt2")

    # Create dataset
    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)

    # Create dataloader
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers,
        pin_memory=True
    )

    return dataloader
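
# Usage sketch (illustrative, not executed here): create_dataloader_v1 turns raw
# text into overlapping (input, target) token windows; the sample text and
# parameters below are made up for demonstration.
#
# sample_text = "Every effort moves you forward, " * 100
# loader = create_dataloader_v1(sample_text, batch_size=2, max_length=8, stride=8, shuffle=False)
# inputs, targets = next(iter(loader))
# print(inputs.shape, targets.shape)  # torch.Size([2, 8]) torch.Size([2, 8])
# # targets are the inputs shifted by one position (next-token prediction)
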
#####################################
# Chapter 3
#####################################
class PyTorchMultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, num_heads, dropout=0.0, qkv_bias=False):
        super().__init__()

        assert d_out % num_heads == 0, "d_out is indivisible by num_heads"

        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.d_out = d_out

        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
        self.proj = nn.Linear(d_out, d_out)
        self.dropout = dropout

    def forward(self, x):
        batch_size, num_tokens, embed_dim = x.shape

        # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)
        qkv = self.qkv(x)

        # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)
        qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)

        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)

        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)
        queries, keys, values = qkv

        use_dropout = 0. if not self.training else self.dropout

        context_vec = nn.functional.scaled_dot_product_attention(
            queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)

        context_vec = self.proj(context_vec)

        return context_vec
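
# Shape-check sketch (illustrative, not executed): the module maps
# (batch, num_tokens, d_in) to (batch, num_tokens, d_out) using PyTorch's fused
# scaled_dot_product_attention with a causal mask.
#
# mha = PyTorchMultiHeadAttention(d_in=768, d_out=768, num_heads=12)
# x = torch.randn(2, 16, 768)   # (batch, num_tokens, d_in)
# print(mha(x).shape)           # torch.Size([2, 16, 768])
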
#####################################
# Chapter 4
#####################################
class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            nn.GELU(approximate="tanh"),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
        )

    def forward(self, x):
        return self.layers(x)


class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = PyTorchMultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = nn.LayerNorm(cfg["emb_dim"])
        self.norm2 = nn.LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)   # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        return x


class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])

        self.final_norm = nn.LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits
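
# Forward-pass sketch (illustrative, not executed): a deliberately tiny config
# keeps this cheap; the keys match GPT_CONFIG_124M defined at the bottom of the file.
#
# tiny_cfg = {"vocab_size": 100, "context_length": 16, "emb_dim": 32,
#             "n_heads": 2, "n_layers": 2, "drop_rate": 0.0, "qkv_bias": False}
# tiny_model = GPTModel(tiny_cfg)
# token_ids = torch.randint(0, 100, (2, 16))   # (batch, num_tokens)
# print(tiny_model(token_ids).shape)           # torch.Size([2, 16, 100]) -- per-token logits
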
def generate_text_simple(model, idx, max_new_tokens, context_size):
    # idx is (B, T) array of indices in the current context
    for _ in range(max_new_tokens):

        # Crop current context if it exceeds the supported context size
        # E.g., if LLM supports only 5 tokens, and the context size is 10
        # then only the last 5 tokens are used as context
        idx_cond = idx[:, -context_size:]

        # Get the predictions
        with torch.no_grad():
            logits = model(idx_cond)

        # Focus only on the last time step
        # (batch, n_token, vocab_size) becomes (batch, vocab_size)
        logits = logits[:, -1, :]

        # Get the idx of the vocab entry with the highest logits value
        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)

        # Append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)

    return idx
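
# Greedy-decoding sketch (illustrative, not executed; assumes `model` is a
# GPTModel instance and `idx` lives on the model's device): the function always
# appends the highest-probability next token.
#
# tokenizer = tiktoken.get_encoding("gpt2")
# idx = torch.tensor([tokenizer.encode("Every effort moves you")])   # (1, num_tokens)
# out = generate_text_simple(model, idx, max_new_tokens=10, context_size=1024)
# print(tokenizer.decode(out.squeeze(0).tolist()))
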
#####################################
# Chapter 5
#####################################
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())
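
# Round-trip sketch (illustrative, not executed): these helpers add and remove
# the batch dimension that the model expects.
#
# tokenizer = tiktoken.get_encoding("gpt2")
# ids = text_to_token_ids("Every effort moves you", tokenizer)   # shape (1, num_tokens)
# print(token_ids_to_text(ids, tokenizer))                       # "Every effort moves you"
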
def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss


def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches
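
# Evaluation sketch (illustrative, not executed; assumes `val_loader`, `model`,
# and `device` are already set up as in main() below): calc_loss_loader averages
# the cross-entropy loss over the first num_batches batches of a dataloader.
#
# with torch.no_grad():
#     val_loss = calc_loss_loader(val_loader, model, device, num_batches=5)
# print(f"Validation loss: {val_loss:.3f}")
# # For an untrained model this is roughly ln(vocab_size), i.e., about 10.8
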
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss


def generate_and_print_sample(model, tokenizer, device, start_context):
    model.eval()
    context_size = model.pos_emb.weight.shape[0]
    encoded = text_to_token_ids(start_context, tokenizer).to(device)
    with torch.no_grad():
        token_ids = generate_text_simple(
            model=model, idx=encoded,
            max_new_tokens=50, context_size=context_size
        )
    decoded_text = token_ids_to_text(token_ids, tokenizer)
    print(decoded_text.replace("\n", " "))  # Compact print format
    model.train()
def train_model_simple_with_timing(model, train_loader, val_loader, optimizer, device,
                                   num_epochs, eval_freq, eval_iter, start_context, tokenizer):
    train_losses, val_losses, track_tokens = [], [], []
    total_tokens, global_step, last_tokens = 0, -1, 0

    # Variables for cumulative average tokens/sec
    cumulative_tokens, cumulative_time = 0.0, 0.0

    # CUDA-specific timing setup
    use_cuda = device.type == "cuda"
    if use_cuda:
        t_start = torch.cuda.Event(enable_timing=True)
        t_end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()  # Ensure all prior CUDA operations are done
        t_start.record()          # Start the timer for the first interval
    else:
        t0 = time.time()          # Start the timer for the first interval

    # Main training loop
    for epoch in range(num_epochs):
        model.train()
        for inp_batch, tgt_batch in train_loader:
            optimizer.zero_grad()
            global_step += 1

            # Forward and backward pass
            loss = calc_loss_batch(inp_batch, tgt_batch, model, device)
            loss.backward()
            optimizer.step()

            total_tokens += inp_batch.numel()

            # At evaluation intervals, measure elapsed time and tokens per second
            if global_step % eval_freq == 0:
                # End timing for the current interval
                if use_cuda:
                    t_end.record()
                    torch.cuda.synchronize()  # Wait for all CUDA ops to complete
                    elapsed = t_start.elapsed_time(t_end) / 1000  # Convert ms to seconds
                    t_start.record()  # Reset timer for the next interval
                else:
                    elapsed = time.time() - t0
                    t0 = time.time()  # Reset timer for the next interval

                # Calculate tokens processed in this interval
                tokens_interval = total_tokens - last_tokens
                last_tokens = total_tokens
                tps = tokens_interval / elapsed if elapsed > 0 else 0  # Tokens per second

                # Update cumulative counters (skip the first evaluation interval)
                if global_step:  # This is False only when global_step == 0 (first evaluation)
                    cumulative_tokens += tokens_interval
                    cumulative_time += elapsed

                # Compute cumulative average tokens/sec (excluding the first interval)
                avg_tps = cumulative_tokens / cumulative_time if cumulative_time > 0 else 0

                # Evaluate model performance (this may add overhead)
                train_loss, val_loss = evaluate_model(model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens.append(total_tokens)

                print(f"Ep {epoch+1}, Step {global_step:06d}, "
                      f"Train: {train_loss:.3f}, Val: {val_loss:.3f}, "
                      f"Step tok/sec: {round(tps)}, Avg tok/sec: {round(avg_tps)}")

        generate_and_print_sample(model, tokenizer, device, start_context)

    # Memory stats
    if torch.cuda.is_available():
        device = torch.cuda.current_device()

        allocated = torch.cuda.memory_allocated(device) / 1024**3  # Convert to GB
        reserved = torch.cuda.memory_reserved(device) / 1024**3    # Convert to GB

        print(f"\nAllocated memory: {allocated:.4f} GB")
        print(f"Reserved memory: {reserved:.4f} GB\n")

    return train_losses, val_losses, track_tokens
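
# Timing note (illustrative, not executed): CUDA kernels launch asynchronously,
# so GPU timing is only meaningful after synchronization. The function above
# uses torch.cuda.Event for this; the minimal pattern is:
#
# start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
# start.record()
# ...  # GPU work
# end.record()
# torch.cuda.synchronize()
# elapsed_sec = start.elapsed_time(end) / 1000   # elapsed_time returns milliseconds
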
def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
    fig, ax1 = plt.subplots()

    # Plot training and validation loss against epochs
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")

    # Create a second x-axis for tokens seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(tokens_seen, train_losses, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Tokens seen")

    fig.tight_layout()  # Adjust layout to make room
    # plt.show()
#####################################
# Main function calls
#####################################
def main(gpt_config, settings):

    torch.manual_seed(123)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"PyTorch version: {torch.__version__}")
    print(f"Using {device}")

    if torch.cuda.is_available():
        print(f"CUDA version: {torch.version.cuda}")
        capability = torch.cuda.get_device_capability()
        if capability[0] >= 7:  # Volta (7.0+), Turing (7.5+), Ampere (8.0+), Hopper (9.0+)
            torch.set_float32_matmul_precision("high")
            print("Uses tensor cores")
        else:
            print("Tensor cores not supported on this GPU. Using default precision.")
    print()
    ##############################
    # Download data if necessary
    ##############################
    file_path = "middlemarch.txt"
    url = "https://www.gutenberg.org/cache/epub/145/pg145.txt"

    if not os.path.exists(file_path):
        with urllib.request.urlopen(url) as response:
            text_data = response.read().decode('utf-8')
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)
    else:
        with open(file_path, "r", encoding="utf-8") as file:
            text_data = file.read()

    ##############################
    # Initialize model
    ##############################
    model = GPTModel(gpt_config)
    model = torch.compile(model)         # Compile the model graph for faster execution
    model.to(device).to(torch.bfloat16)  # Cast model weights to bfloat16 precision
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=settings["learning_rate"], weight_decay=settings["weight_decay"],
        fused=True  # Use the fused AdamW implementation for speed
    )

    ##############################
    # Set up dataloaders
    ##############################
    # Train/validation ratio
    train_ratio = 0.90
    split_idx = int(train_ratio * len(text_data))

    train_loader = create_dataloader_v1(
        text_data[:split_idx],
        batch_size=settings["batch_size"],
        max_length=gpt_config["context_length"],
        stride=gpt_config["context_length"],
        drop_last=True,
        shuffle=True,
        num_workers=4
    )

    val_loader = create_dataloader_v1(
        text_data[split_idx:],
        batch_size=settings["batch_size"],
        max_length=gpt_config["context_length"],
        stride=gpt_config["context_length"],
        drop_last=False,
        shuffle=False,
        num_workers=4
    )

    ##############################
    # Train model
    ##############################
    tokenizer = tiktoken.get_encoding("gpt2")

    train_losses, val_losses, tokens_seen = train_model_simple_with_timing(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        device=device,
        num_epochs=settings["num_epochs"],
        eval_freq=10,
        eval_iter=1,
        start_context="Every effort moves you",
        tokenizer=tokenizer
    )

    return train_losses, val_losses, tokens_seen, model
if __name__ == "__main__":

    GPT_CONFIG_124M = {
        "vocab_size": 50304,     # Vocabulary size (GPT-2's 50257 padded up to a multiple of 64 for efficiency)
        "context_length": 1024,  # Input tokens per training example
        "emb_dim": 768,          # Embedding dimension
        "n_heads": 12,           # Number of attention heads
        "n_layers": 12,          # Number of layers
        "drop_rate": 0.1,        # Dropout rate
        "qkv_bias": False        # Query-key-value bias
    }

    OTHER_SETTINGS = {
        "learning_rate": 5e-4,
        "num_epochs": 15,
        "batch_size": 32,
        "weight_decay": 0.1
    }

    ###########################
    # Initiate training
    ###########################
    train_losses, val_losses, tokens_seen, model = main(GPT_CONFIG_124M, OTHER_SETTINGS)

    ###########################
    # After training
    ###########################

    # Plot results
    epochs_tensor = torch.linspace(0, OTHER_SETTINGS["num_epochs"], len(train_losses))
    plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)
    plt.savefig("loss.pdf")

    # Save and load model
    #
    # compiled = hasattr(model, "_orig_mod")
    # if compiled:
    #     torch.save(model._orig_mod.state_dict(), "model.pth")
    # else:
    #     torch.save(model.state_dict(), "model.pth")
    #
    # model = GPTModel(GPT_CONFIG_124M)
    # model.load_state_dict(torch.load("model.pth", weights_only=True))
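
    # Inference sketch (illustrative, assumes the commented save/load block
    # above has been run so `model` holds the reloaded weights):
    #
    # tokenizer = tiktoken.get_encoding("gpt2")
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model.to(device)
    # model.eval()
    # idx = text_to_token_ids("Every effort moves you", tokenizer).to(device)
    # out = generate_text_simple(model, idx, max_new_tokens=25,
    #                            context_size=GPT_CONFIG_124M["context_length"])
    # print(token_ids_to_text(out, tokenizer))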