@@ -11,6 +11,7 @@ from torch.utils.data import Dataset, DataLoader
 # NEW imports:
 import os
+import platform
 import torch.multiprocessing as mp
 from torch.utils.data.distributed import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed import init_process_group, destroy_process_group
@@ -30,11 +31,19 @@ def ddp_setup(rank, world_size):
     os.environ["MASTER_ADDR"] = "localhost"
     # any free port on the machine
     os.environ["MASTER_PORT"] = "12345"
+    if platform.system() == "Windows":
+        # Disable libuv because PyTorch for Windows isn't built with libuv support
+        os.environ["USE_LIBUV"] = "0"

     # initialize process group
-    # Windows users may have to use "gloo" instead of "nccl" as backend
-    # nccl: NVIDIA Collective Communication Library
-    init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    if platform.system() == "Windows":
+        # Windows users may have to use "gloo" instead of "nccl" as backend
+        # gloo: Facebook Collective Communication Library
+        init_process_group(backend="gloo", rank=rank, world_size=world_size)
+    else:
+        # nccl: NVIDIA Collective Communication Library
+        init_process_group(backend="nccl", rank=rank, world_size=world_size)
+
     torch.cuda.set_device(rank)
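
For context, here is a minimal sketch of how the patched ddp_setup is typically driven. It is not part of the diff: the main function and the one-process-per-GPU layout are illustrative assumptions, while mp.spawn, torch.cuda.device_count, and destroy_process_group are standard PyTorch APIs (mp.spawn passes the process rank as the first argument to the target function).

# Hypothetical launcher sketch; assumes ddp_setup (as patched above) is in scope.
import torch
import torch.multiprocessing as mp
from torch.distributed import destroy_process_group

def main(rank, world_size):
    # rank is supplied automatically by mp.spawn; world_size comes via args
    ddp_setup(rank, world_size)  # picks "gloo" on Windows, "nccl" elsewhere
    # ... build the model, wrap it in DDP(model, device_ids=[rank]), train ...
    destroy_process_group()  # tear down the process group cleanly

if __name__ == "__main__":
    world_size = torch.cuda.device_count()  # one process per available GPU
    mp.spawn(main, args=(world_size,), nprocs=world_size)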