DDP-script-torchrun.py

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
# Appendix A: Introduction to PyTorch (Part 3)

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# NEW imports:
import os
import platform
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
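
# Launch with torchrun so that the WORLD_SIZE/RANK/LOCAL_RANK environment
# variables (read in the __main__ block below) are set automatically, e.g.:
#   torchrun --nproc_per_node=2 DDP-script-torchrun.py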


# NEW: function to initialize a distributed process group (1 process / GPU)
# this allows communication among processes
def ddp_setup(rank, world_size):
    """
    Arguments:
        rank: a unique process ID
        world_size: total number of processes in the group
    """
    # Only set MASTER_ADDR and MASTER_PORT if not already defined by torchrun
    if "MASTER_ADDR" not in os.environ:
        os.environ["MASTER_ADDR"] = "localhost"
    if "MASTER_PORT" not in os.environ:
        os.environ["MASTER_PORT"] = "12345"

    # initialize process group
    if platform.system() == "Windows":
        # Disable libuv because PyTorch for Windows isn't built with support
        os.environ["USE_LIBUV"] = "0"
        # Windows users may have to use "gloo" instead of "nccl" as backend
        # gloo: Facebook Collective Communication Library
        init_process_group(backend="gloo", rank=rank, world_size=world_size)
    else:
        # nccl: NVIDIA Collective Communication Library
        init_process_group(backend="nccl", rank=rank, world_size=world_size)

    torch.cuda.set_device(rank)
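    # torch.cuda.set_device(rank) makes GPU `rank` the default CUDA device for
    # this process, so `rank` doubles as the device index throughout the script
    # (model.to(rank), features.to(rank), DDP(device_ids=[rank])).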


class ToyDataset(Dataset):
    def __init__(self, X, y):
        self.features = X
        self.labels = y

    def __getitem__(self, index):
        one_x = self.features[index]
        one_y = self.labels[index]
        return one_x, one_y

    def __len__(self):
        return self.labels.shape[0]


class NeuralNetwork(torch.nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super().__init__()

        self.layers = torch.nn.Sequential(
            # 1st hidden layer
            torch.nn.Linear(num_inputs, 30),
            torch.nn.ReLU(),

            # 2nd hidden layer
            torch.nn.Linear(30, 20),
            torch.nn.ReLU(),

            # output layer
            torch.nn.Linear(20, num_outputs),
        )

    def forward(self, x):
        logits = self.layers(x)
        return logits


def prepare_dataset():
    X_train = torch.tensor([
        [-1.2, 3.1],
        [-0.9, 2.9],
        [-0.5, 2.6],
        [2.3, -1.1],
        [2.7, -1.5]
    ])
    y_train = torch.tensor([0, 0, 0, 1, 1])
    X_test = torch.tensor([
        [-0.8, 2.8],
        [2.6, -1.6],
    ])
    y_test = torch.tensor([0, 1])

    # Uncomment these lines to increase the dataset size to run this script on up to 8 GPUs:
    # factor = 4
    # X_train = torch.cat([X_train + torch.randn_like(X_train) * 0.1 for _ in range(factor)])
    # y_train = y_train.repeat(factor)
    # X_test = torch.cat([X_test + torch.randn_like(X_test) * 0.1 for _ in range(factor)])
    # y_test = y_test.repeat(factor)

    train_ds = ToyDataset(X_train, y_train)
    test_ds = ToyDataset(X_test, y_test)

    train_loader = DataLoader(
        dataset=train_ds,
        batch_size=2,
        shuffle=False,  # NEW: False because of DistributedSampler below
        pin_memory=True,
        drop_last=True,
        # NEW: chunk batches across GPUs without overlapping samples:
        sampler=DistributedSampler(train_ds)  # NEW
    )
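    # DistributedSampler gives each process a distinct 1/world_size slice of the
    # training set each epoch (padding with repeated samples if the dataset size
    # isn't evenly divisible by world_size). The test loader below has no sampler,
    # so every process evaluates the full test set.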
    test_loader = DataLoader(
        dataset=test_ds,
        batch_size=2,
        shuffle=False,
    )
    return train_loader, test_loader


# NEW: wrapper
def main(rank, world_size, num_epochs):
    ddp_setup(rank, world_size)  # NEW: initialize process groups

    train_loader, test_loader = prepare_dataset()
    model = NeuralNetwork(num_inputs=2, num_outputs=2)
    model.to(rank)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

    model = DDP(model, device_ids=[rank])  # NEW: wrap model with DDP
    # the core model is now accessible as model.module
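    # During loss.backward(), DDP all-reduces and averages the gradients across
    # processes, so every rank applies the same parameter update in optimizer.step().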

    for epoch in range(num_epochs):
        # NEW: Set sampler to ensure each epoch has a different shuffle order
        train_loader.sampler.set_epoch(epoch)

        model.train()
        for features, labels in train_loader:
            features, labels = features.to(rank), labels.to(rank)  # NEW: use rank
            logits = model(features)
            loss = F.cross_entropy(logits, labels)  # Loss function

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # LOGGING
            print(f"[GPU{rank}] Epoch: {epoch+1:03d}/{num_epochs:03d}"
                  f" | Batchsize {labels.shape[0]:03d}"
                  f" | Train/Val Loss: {loss:.2f}")

    model.eval()

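    # Note: compute_accuracy iterates over whatever dataloader it is given, so with
    # the DistributedSampler above each rank reports training accuracy for its own
    # shard of the training data only.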
    try:
        train_acc = compute_accuracy(model, train_loader, device=rank)
        print(f"[GPU{rank}] Training accuracy", train_acc)
        test_acc = compute_accuracy(model, test_loader, device=rank)
        print(f"[GPU{rank}] Test accuracy", test_acc)

    ####################################################
    # NEW (not in the book):
    except ZeroDivisionError as e:
        raise ZeroDivisionError(
            f"{e}\n\nThis script is designed for 2 GPUs. You can run it as:\n"
            "torchrun --nproc_per_node=2 DDP-script-torchrun.py\n"
            f"Or, to run it on {torch.cuda.device_count()} GPUs, uncomment the "
            "`factor = 4` block in prepare_dataset() to enlarge the dataset."
        )
    ####################################################

    destroy_process_group()  # NEW: cleanly exit distributed mode


def compute_accuracy(model, dataloader, device):
    model = model.eval()
    correct = 0.0
    total_examples = 0

    for idx, (features, labels) in enumerate(dataloader):
        features, labels = features.to(device), labels.to(device)

        with torch.no_grad():
            logits = model(features)
        predictions = torch.argmax(logits, dim=1)
        compare = labels == predictions
        correct += torch.sum(compare)
        total_examples += len(compare)
    return (correct / total_examples).item()


if __name__ == "__main__":
    # NEW: Use environment variables set by torchrun if available, otherwise default to single-process.
    if "WORLD_SIZE" in os.environ:
        world_size = int(os.environ["WORLD_SIZE"])
    else:
        world_size = 1

    if "LOCAL_RANK" in os.environ:
        rank = int(os.environ["LOCAL_RANK"])
    elif "RANK" in os.environ:
        rank = int(os.environ["RANK"])
    else:
        rank = 0

    # Only print on rank 0 to avoid duplicate prints from each GPU process
    if rank == 0:
        print("PyTorch version:", torch.__version__)
        print("CUDA available:", torch.cuda.is_available())
        print("Number of GPUs available:", torch.cuda.device_count())

    torch.manual_seed(123)
    num_epochs = 3
    main(rank, world_size, num_epochs)