# DDP-script.py
  1. # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
  2. # Source for "Build a Large Language Model From Scratch"
  3. # - https://www.manning.com/books/build-a-large-language-model-from-scratch
  4. # Code: https://github.com/rasbt/LLMs-from-scratch
  5. # Appendix A: Introduction to PyTorch (Part 3)
  6. import torch
  7. import torch.nn.functional as F
  8. from torch.utils.data import Dataset, DataLoader
  9. # NEW imports:
  10. import os
  11. import platform
  12. import torch.multiprocessing as mp
  13. from torch.utils.data.distributed import DistributedSampler
  14. from torch.nn.parallel import DistributedDataParallel as DDP
  15. from torch.distributed import init_process_group, destroy_process_group
  16. # NEW: function to initialize a distributed process group (1 process / GPU)
  17. # this allows communication among processes
  18. def ddp_setup(rank, world_size):
  19. """
  20. Arguments:
  21. rank: a unique process ID
  22. world_size: total number of processes in the group
  23. """
  24. # rank of machine running rank:0 process
  25. # here, we assume all GPUs are on the same machine
  26. os.environ["MASTER_ADDR"] = "localhost"
  27. # any free port on the machine
  28. os.environ["MASTER_PORT"] = "12345"
  29. if platform.system() == "Windows":
  30. # Disable libuv because PyTorch for Windows isn't built with support
  31. os.environ["USE_LIBUV"] = "0"
  32. # initialize process group
  33. if platform.system() == "Windows":
  34. # Windows users may have to use "gloo" instead of "nccl" as backend
  35. # gloo: Facebook Collective Communication Library
  36. init_process_group(backend="gloo", rank=rank, world_size=world_size)
  37. else:
  38. # nccl: NVIDIA Collective Communication Library
  39. init_process_group(backend="nccl", rank=rank, world_size=world_size)
  40. torch.cuda.set_device(rank)
  41. class ToyDataset(Dataset):
  42. def __init__(self, X, y):
  43. self.features = X
  44. self.labels = y
  45. def __getitem__(self, index):
  46. one_x = self.features[index]
  47. one_y = self.labels[index]
  48. return one_x, one_y
  49. def __len__(self):
  50. return self.labels.shape[0]
  51. class NeuralNetwork(torch.nn.Module):
  52. def __init__(self, num_inputs, num_outputs):
  53. super().__init__()
  54. self.layers = torch.nn.Sequential(
  55. # 1st hidden layer
  56. torch.nn.Linear(num_inputs, 30),
  57. torch.nn.ReLU(),
  58. # 2nd hidden layer
  59. torch.nn.Linear(30, 20),
  60. torch.nn.ReLU(),
  61. # output layer
  62. torch.nn.Linear(20, num_outputs),
  63. )
  64. def forward(self, x):
  65. logits = self.layers(x)
  66. return logits
  67. def prepare_dataset():
  68. X_train = torch.tensor([
  69. [-1.2, 3.1],
  70. [-0.9, 2.9],
  71. [-0.5, 2.6],
  72. [2.3, -1.1],
  73. [2.7, -1.5]
  74. ])
  75. y_train = torch.tensor([0, 0, 0, 1, 1])
  76. X_test = torch.tensor([
  77. [-0.8, 2.8],
  78. [2.6, -1.6],
  79. ])
  80. y_test = torch.tensor([0, 1])
  81. train_ds = ToyDataset(X_train, y_train)
  82. test_ds = ToyDataset(X_test, y_test)
  83. train_loader = DataLoader(
  84. dataset=train_ds,
  85. batch_size=2,
  86. shuffle=False, # NEW: False because of DistributedSampler below
  87. pin_memory=True,
  88. drop_last=True,
  89. # NEW: chunk batches across GPUs without overlapping samples:
  90. sampler=DistributedSampler(train_ds) # NEW
  91. )
  92. test_loader = DataLoader(
  93. dataset=test_ds,
  94. batch_size=2,
  95. shuffle=False,
  96. )
  97. return train_loader, test_loader
# NEW: wrapper
def main(rank, world_size, num_epochs):
    """Per-process training entry point, launched once per GPU by mp.spawn.

    Arguments:
        rank: this process's unique ID (also used as its GPU index)
        world_size: total number of spawned processes
        num_epochs: number of passes over the training data
    """
    ddp_setup(rank, world_size)  # NEW: initialize process groups
    train_loader, test_loader = prepare_dataset()
    model = NeuralNetwork(num_inputs=2, num_outputs=2)
    model.to(rank)  # move parameters onto this rank's GPU before wrapping
    optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
    model = DDP(model, device_ids=[rank])  # NEW: wrap model with DDP
    # the core model is now accessible as model.module
    # NOTE(review): sampler.set_epoch(epoch) is never called, so the
    # DistributedSampler's shuffle order repeats every epoch — confirm intended.
    for epoch in range(num_epochs):
        model.train()
        for features, labels in train_loader:
            features, labels = features.to(rank), labels.to(rank)  # New: use rank
            logits = model(features)
            loss = F.cross_entropy(logits, labels)  # Loss function
            optimizer.zero_grad()
            # backward() also triggers DDP's cross-rank gradient synchronization
            loss.backward()
            optimizer.step()

            # LOGGING
            print(f"[GPU{rank}] Epoch: {epoch+1:03d}/{num_epochs:03d}"
                  f" | Batchsize {labels.shape[0]:03d}"
                  f" | Train/Val Loss: {loss:.2f}")

    model.eval()
    # Each rank evaluates (and prints) independently on its own shard/copy.
    train_acc = compute_accuracy(model, train_loader, device=rank)
    print(f"[GPU{rank}] Training accuracy", train_acc)
    test_acc = compute_accuracy(model, test_loader, device=rank)
    print(f"[GPU{rank}] Test accuracy", test_acc)
    destroy_process_group()  # NEW: cleanly exit distributed mode
  126. def compute_accuracy(model, dataloader, device):
  127. model = model.eval()
  128. correct = 0.0
  129. total_examples = 0
  130. for idx, (features, labels) in enumerate(dataloader):
  131. features, labels = features.to(device), labels.to(device)
  132. with torch.no_grad():
  133. logits = model(features)
  134. predictions = torch.argmax(logits, dim=1)
  135. compare = labels == predictions
  136. correct += torch.sum(compare)
  137. total_examples += len(compare)
  138. return (correct / total_examples).item()
  139. if __name__ == "__main__":
  140. print("PyTorch version:", torch.__version__)
  141. print("CUDA available:", torch.cuda.is_available())
  142. print("Number of GPUs available:", torch.cuda.device_count())
  143. torch.manual_seed(123)
  144. # NEW: spawn new processes
  145. # note that spawn will automatically pass the rank
  146. num_epochs = 3
  147. world_size = torch.cuda.device_count()
  148. mp.spawn(main, args=(world_size, num_epochs), nprocs=world_size)
  149. # nprocs=world_size spawns one process per GPU