Kaynağa Gözat

Improve DDP on Windows (#376)

* Update DDP-script.py for Windows

* Windows handling

---------

Co-authored-by: Nathan Brown <nathan@nkbrown.us>
Sebastian Raschka 1 yıl önce
ebeveyn
işleme
4caafddb93
1 değiştirilmiş dosya ile 12 ekleme ve 3 silme
  1. 12 3
      appendix-A/01_main-chapter-code/DDP-script.py

+ 12 - 3
appendix-A/01_main-chapter-code/DDP-script.py

@@ -11,6 +11,7 @@ from torch.utils.data import Dataset, DataLoader
 
 # NEW imports:
 import os
+import platform
 import torch.multiprocessing as mp
 from torch.utils.data.distributed import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -30,11 +31,19 @@ def ddp_setup(rank, world_size):
     os.environ["MASTER_ADDR"] = "localhost"
     # any free port on the machine
     os.environ["MASTER_PORT"] = "12345"
+    if platform.system() == "Windows":
+        # Disable libuv because PyTorch for Windows isn't built with support
+        os.environ["USE_LIBUV"] = "0"
 
     # initialize process group
-    # Windows users may have to use "gloo" instead of "nccl" as backend
-    # nccl: NVIDIA Collective Communication Library
-    init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    if platform.system() == "Windows":
+        # Windows users may have to use "gloo" instead of "nccl" as backend
+        # gloo: Facebook Collective Communication Library
+        init_process_group(backend="gloo", rank=rank, world_size=world_size)
+    else:
+        # nccl: NVIDIA Collective Communication Library
+        init_process_group(backend="nccl", rank=rank, world_size=world_size)
+
     torch.cuda.set_device(rank)