# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

import os
import time
import urllib.request

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import tiktoken

# NEW imports (see Appendix A):
import platform
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
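
# Usage note: this script is meant to be launched with one process per GPU via
# torchrun, which sets the MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK, and
# LOCAL_RANK environment variables relied on by ddp_setup() and the __main__
# block below, e.g.:
#
#   torchrun --nproc_per_node=2 02_opt_multi_gpu_ddp.py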


# NEW: function to initialize a distributed process group (1 process / GPU)
# this allows communication among processes
# (see Appendix A):
def ddp_setup(rank, world_size):
    """
    Arguments:
        rank: a unique process ID
        world_size: total number of processes in the group
    """
    # Only set MASTER_ADDR and MASTER_PORT if not already defined by torchrun
    if "MASTER_ADDR" not in os.environ:
        os.environ["MASTER_ADDR"] = "localhost"
    if "MASTER_PORT" not in os.environ:
        os.environ["MASTER_PORT"] = "12345"

    # initialize process group
    if platform.system() == "Windows":
        # Disable libuv because PyTorch for Windows isn't built with support
        os.environ["USE_LIBUV"] = "0"
        # Windows users may have to use "gloo" instead of "nccl" as backend
        # gloo: Facebook Collective Communication Library
        init_process_group(backend="gloo", rank=rank, world_size=world_size)
    else:
        # nccl: NVIDIA Collective Communication Library
        init_process_group(backend="nccl", rank=rank, world_size=world_size)

    torch.cuda.set_device(rank)
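

# Alternative launch sketch (an illustration only; this script instead reads its rank
# and world size from the torchrun environment variables in the __main__ block below):
# on a single machine, torch.multiprocessing.spawn can start one process per GPU and
# passes the process index as the first argument of the target function:
#
#   import torch.multiprocessing as mp
#
#   def run(rank, world_size):  # hypothetical wrapper; reorders the arguments for main()
#       main(GPT_CONFIG_124M, OTHER_SETTINGS, rank, world_size)
#
#   mp.spawn(run, args=(torch.cuda.device_count(),), nprocs=torch.cuda.device_count())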


#####################################
# Chapter 2
#####################################


class GPTDatasetV1(Dataset):
    def __init__(self, txt, tokenizer, max_length, stride):
        self.input_ids = []
        self.target_ids = []

        # Tokenize the entire text
        token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})

        # Use a sliding window to chunk the book into overlapping sequences of max_length
        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = token_ids[i:i + max_length]
            target_chunk = token_ids[i + 1: i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]


# NEW: Modify to set shuffle=False and use a sampler
# (See Appendix A):
def create_dataloader_v1(txt, batch_size=4, max_length=256,
                         stride=128, drop_last=True, num_workers=0):
    # Initialize the tokenizer
    tokenizer = tiktoken.get_encoding("gpt2")

    # Create dataset
    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)

    # Create dataloader
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,  # NEW: False because of DistributedSampler below
        drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=True,
        # NEW: chunk batches across GPUs without overlapping samples:
        sampler=DistributedSampler(dataset)  # NEW
    )
    return dataloader
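

# Note: DistributedSampler splits the dataset indices across the participating processes,
# so with N GPUs each rank iterates over a distinct ~1/N slice of the batches per epoch
# (no overlapping samples). Its shuffle order is re-seeded each epoch via
# sampler.set_epoch(epoch) in the training loop further below.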


#####################################
# Chapter 3
#####################################


class PyTorchMultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, num_heads, dropout=0.0, qkv_bias=False):
        super().__init__()

        assert d_out % num_heads == 0, "d_out is indivisible by num_heads"

        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.d_out = d_out

        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
        self.proj = nn.Linear(d_out, d_out)
        self.dropout = dropout

    def forward(self, x):
        batch_size, num_tokens, embed_dim = x.shape

        # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)
        qkv = self.qkv(x)

        # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)
        qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)

        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)

        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)
        queries, keys, values = qkv

        use_dropout = 0. if not self.training else self.dropout

        context_vec = nn.functional.scaled_dot_product_attention(
            queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)

        context_vec = self.proj(context_vec)
        return context_vec
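

# Note: nn.functional.scaled_dot_product_attention can dispatch to fused attention kernels
# (e.g., FlashAttention) when the device, dtype, and masking configuration allow it, which
# is typically faster and more memory-efficient than materializing the full attention matrix.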


#####################################
# Chapter 4
#####################################


class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            nn.GELU(approximate="tanh"),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
        )

    def forward(self, x):
        return self.layers(x)


class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = PyTorchMultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = nn.LayerNorm(cfg["emb_dim"])
        self.norm2 = nn.LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)   # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        return x


class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])

        self.final_norm = nn.LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits


def generate_text_simple(model, idx, max_new_tokens, context_size):
    # idx is (B, T) array of indices in the current context
    for _ in range(max_new_tokens):

        # Crop current context if it exceeds the supported context size
        # E.g., if LLM supports only 5 tokens, and the context size is 10
        # then only the last 5 tokens are used as context
        idx_cond = idx[:, -context_size:]

        # Get the predictions
        with torch.no_grad():
            logits = model(idx_cond)

        # Focus only on the last time step
        # (batch, n_token, vocab_size) becomes (batch, vocab_size)
        logits = logits[:, -1, :]

        # Get the idx of the vocab entry with the highest logits value
        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)

        # Append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)

    return idx


#####################################
# Chapter 5
#####################################


def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())


def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss


def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches


def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss
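

# Note: under DDP, each rank calls evaluate_model on its own DistributedSampler shard of the
# train/val data, so the reported losses are per-rank estimates rather than globally averaged values.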


def generate_and_print_sample(model, device, start_context):
    model.eval()
    # NEW: Modify for DDP
    context_size = model.module.pos_emb.weight.shape[0] if isinstance(model, DDP) else model.pos_emb.weight.shape[0]
    encoded = text_to_token_ids(start_context, tiktoken.get_encoding("gpt2")).to(device)
    with torch.no_grad():
        token_ids = generate_text_simple(
            model=model, idx=encoded,
            max_new_tokens=50, context_size=context_size
        )
    decoded_text = token_ids_to_text(token_ids, tiktoken.get_encoding("gpt2"))
    print(decoded_text.replace("\n", " "))  # Compact print format
    model.train()


def train_model_simple_with_timing(model, train_loader, val_loader, optimizer, device,
                                   num_epochs, eval_freq, eval_iter, start_context):
    train_losses, val_losses, track_tokens = [], [], []
    total_tokens, global_step, last_tokens = 0, -1, 0

    # NEW: Determine the current rank (default to 0 if not distributed)
    rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
    # world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1

    # Variables for cumulative average tokens/sec
    cumulative_tokens, cumulative_time = 0.0, 0.0

    # CUDA-specific timing setup
    use_cuda = device.type == "cuda"
    if use_cuda:
        t_start = torch.cuda.Event(enable_timing=True)
        t_end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()  # Ensure all prior CUDA operations are done
        t_start.record()          # Start the timer for the first interval
    else:
        t0 = time.time()          # Start the timer for the first interval

    # Main training loop
    for epoch in range(num_epochs):
        # NEW: set epoch for DistributedSampler so each process gets a unique shuffle order
        if isinstance(train_loader.sampler, DistributedSampler):
            train_loader.sampler.set_epoch(epoch)

        model.train()
        for inp_batch, tgt_batch in train_loader:
            optimizer.zero_grad()
            global_step += 1

            # Forward and backward pass
            loss = calc_loss_batch(inp_batch, tgt_batch, model, device)
            loss.backward()
            optimizer.step()

            total_tokens += inp_batch.numel()

            # At evaluation intervals, measure elapsed time and tokens per second
            if global_step % eval_freq == 0:
                # End timing for the current interval
                if use_cuda:
                    t_end.record()
                    torch.cuda.synchronize()  # Wait for all CUDA ops to complete.
                    elapsed = t_start.elapsed_time(t_end) / 1000  # Convert ms to seconds
                    t_start.record()  # Reset timer for the next interval
                else:
                    elapsed = time.time() - t0
                    t0 = time.time()  # Reset timer for the next interval

                # Calculate local tokens processed during this interval
                local_interval = total_tokens - last_tokens
                last_tokens = total_tokens

                # Aggregate the tokens processed over all devices
                local_tensor = torch.tensor([local_interval], device=device, dtype=torch.float)
                global_tensor = local_tensor.clone()
                torch.distributed.all_reduce(global_tensor, op=torch.distributed.ReduceOp.SUM)
                global_interval = global_tensor.item()

                # Global tokens per second for this interval
                global_tps = global_interval / elapsed if elapsed > 0 else 0

                # Update cumulative tokens (local) and aggregate globally
                cumulative_tokens += local_interval
                local_cum_tensor = torch.tensor([cumulative_tokens], device=device, dtype=torch.float)
                global_cum_tensor = local_cum_tensor.clone()
                torch.distributed.all_reduce(global_cum_tensor, op=torch.distributed.ReduceOp.SUM)
                global_cumulative_tokens = global_cum_tensor.item()

                cumulative_time += elapsed
                global_avg_tps = global_cumulative_tokens / cumulative_time if cumulative_time > 0 else 0

                # Evaluate model performance (this may add overhead)
                train_loss, val_loss = evaluate_model(model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens.append(total_tokens)

                # NEW: Only print logs on one GPU (rank 0)
                if rank == 0:
                    print(f"Ep {epoch+1}, Step {global_step:06d}, "
                          f"Train: {train_loss:.3f}, Val: {val_loss:.3f}, "
                          f"Step tok/sec: {round(global_tps)}, Global avg tok/sec: {round(global_avg_tps)}")

        # NEW: Only rank 0 prints the generated sample and memory usage stats
        if rank == 0 and epoch % 5 == 0:
            generate_and_print_sample(model, device, start_context)

            # Memory stats
            if torch.cuda.is_available():
                current_device = torch.cuda.current_device()

                allocated = torch.cuda.memory_allocated(current_device) / 1024**3  # Convert to GB
                reserved = torch.cuda.memory_reserved(current_device) / 1024**3    # Convert to GB

                print(f"\nAllocated memory: {allocated:.4f} GB")
                print(f"Reserved memory: {reserved:.4f} GB\n")

    return train_losses, val_losses, track_tokens
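

# Note on DDP behavior: because the model passed into the training loop above is wrapped in
# DistributedDataParallel, every loss.backward() call also all-reduces (averages) the gradients
# across ranks, so each optimizer.step() applies the same synchronized update on every GPU.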


def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
    fig, ax1 = plt.subplots()

    # Plot training and validation loss against epochs
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")

    # Create a second x-axis for tokens seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(tokens_seen, train_losses, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Tokens seen")

    fig.tight_layout()  # Adjust layout to make room
    # plt.show()


#####################################
# Main function calls
#####################################


# NEW: Add rank and world_size
def main(gpt_config, settings, rank, world_size):

    ddp_setup(rank, world_size)  # NEW: initialize process groups
    device = torch.device("cuda", rank)
    torch.manual_seed(123)

    # NEW: Print info only on 1 GPU
    if rank == 0:
        print(f"PyTorch version: {torch.__version__}")
        if torch.cuda.is_available():
            print(f"CUDA version: {torch.version.cuda}")

            capability = torch.cuda.get_device_capability()
            if capability[0] >= 7:  # Volta (7.0+), Turing (7.5+), Ampere (8.0+), Hopper (9.0+)
                torch.set_float32_matmul_precision("high")
                print("Uses tensor cores")
            else:
                print("Tensor cores not supported on this GPU. Using default precision.")
        print()

    ##############################
    # Download data if necessary
    ##############################

    file_path = "middlemarch.txt"
    url = "https://www.gutenberg.org/cache/epub/145/pg145.txt"

    # NEW: Only download 1 time
    if rank == 0:
        if not os.path.exists(file_path):
            with urllib.request.urlopen(url) as response:
                text_data = response.read().decode('utf-8')
            with open(file_path, "w", encoding="utf-8") as file:
                file.write(text_data)

    # NEW: All processes wait until rank 0 is done, using the GPU index.
    torch.distributed.barrier(device_ids=[device.index])

    with open(file_path, "r", encoding="utf-8") as file:
        text_data = file.read()

    ##############################
    # Initialize model
    ##############################

    model = GPTModel(gpt_config)
    model = torch.compile(model)
    model = model.to(device)
    model = model.to(torch.bfloat16)

    # NEW: Wrap model with DDP
    model = DDP(model, device_ids=[rank])

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=settings["learning_rate"], weight_decay=settings["weight_decay"],
        fused=True
    )

    ##############################
    # Set up dataloaders
    ##############################

    # Train/validation ratio
    train_ratio = 0.90
    split_idx = int(train_ratio * len(text_data))

    train_loader = create_dataloader_v1(
        text_data[:split_idx],
        batch_size=settings["batch_size"],
        max_length=gpt_config["context_length"],
        stride=gpt_config["context_length"],
        drop_last=True,
        num_workers=4
    )

    val_loader = create_dataloader_v1(
        text_data[split_idx:],
        batch_size=settings["batch_size"],
        max_length=gpt_config["context_length"],
        stride=gpt_config["context_length"],
        drop_last=False,
        num_workers=4
    )

    ##############################
    # Train model
    ##############################

    train_losses, val_losses, tokens_seen = train_model_simple_with_timing(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        device=device,
        num_epochs=settings["num_epochs"],
        eval_freq=5,
        eval_iter=1,
        start_context="Every effort moves you",
    )

    # NEW: Clean up distributed processes
    destroy_process_group()

    return train_losses, val_losses, tokens_seen, model


if __name__ == "__main__":

    # NEW: Extract rank and world size from environment variables
    if "WORLD_SIZE" in os.environ:
        world_size = int(os.environ["WORLD_SIZE"])
    else:
        world_size = 1

    if "LOCAL_RANK" in os.environ:
        rank = int(os.environ["LOCAL_RANK"])
    elif "RANK" in os.environ:
        rank = int(os.environ["RANK"])
    else:
        rank = 0

    GPT_CONFIG_124M = {
        "vocab_size": 50304,     # Vocabulary size
        "context_length": 1024,  # Input tokens per training example
        "emb_dim": 768,          # Embedding dimension
        "n_heads": 12,           # Number of attention heads
        "n_layers": 12,          # Number of layers
        "drop_rate": 0.1,        # Dropout rate
        "qkv_bias": False        # Query-key-value bias
    }

    OTHER_SETTINGS = {
        "learning_rate": 5e-4,  # * world_size,  # NEW: Increase learning rate to account for multiple GPUs
        "num_epochs": 50,
        "batch_size": 32,
        "weight_decay": 0.1
    }

    ###########################
    # Initiate training
    ###########################

    train_losses, val_losses, tokens_seen, model = main(
        GPT_CONFIG_124M, OTHER_SETTINGS,
        rank, world_size  # NEW
    )

    ###########################
    # After training
    ###########################

    # NEW: Only create 1 plot
    if rank == 0:
        # Plot results
        epochs_tensor = torch.linspace(0, OTHER_SETTINGS["num_epochs"], len(train_losses))
        plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)
        plt.savefig("loss.pdf")

    # Save and load model (on rank 0 only; after training, all ranks hold identical weights)
    #
    # The trained `model` returned by main() is wrapped first by torch.compile and then by DDP,
    # so unwrap it to obtain a plain GPTModel state_dict before saving:
    #
    # if rank == 0:
    #     raw_model = model.module  # unwrap DDP
    #     if hasattr(raw_model, "_orig_mod"):
    #         raw_model = raw_model._orig_mod  # unwrap torch.compile
    #     torch.save(raw_model.state_dict(), "model.pth")
    #
    # model = GPTModel(GPT_CONFIG_124M)
    # model.load_state_dict(torch.load("model.pth", weights_only=True))