# Distributed Training
Train models across multiple GPUs and nodes using PyTorch DistributedDataParallel (DDP).
## Overview
This workflow demonstrates:
- PyTorch DistributedDataParallel (DDP)
- Multi-node, multi-GPU training
- ResNet on CIFAR-10
- Distributed data sampling
- Synchronized gradient updates
- Metrics aggregation
## Configuration
```python
from mixtrain import MixFlow, sandbox
from typing import Literal


class MultiNodeTraining(MixFlow):
    # Multi-node sandbox configuration
    _sandbox = sandbox(
        image="nvcr.io/nvidia/pytorch:24.01-py3",
        gpu="A100",
        memory=40960,
        num_nodes=4,
        timeout=7200,
    )

    def run(
        self,
        model_name: Literal["resnet18", "resnet50", "resnet101"] = "resnet50",
        num_classes: int = 10,
        batch_size: int = 128,
        epochs: int = 100,
        learning_rate: float = 0.1,
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        backend: Literal["nccl", "gloo"] = "nccl",
    ):
        """Execute distributed training.

        Args:
            model_name: Model architecture
            num_classes: Number of output classes
            batch_size: Per-GPU batch size
            epochs: Training epochs
            learning_rate: Initial learning rate
            momentum: SGD momentum
            weight_decay: Weight decay
            backend: Distributed backend
        """
        # Training logic here...
        pass
```

## Key Concepts
### Process Group Initialization
Each process (GPU) has a unique rank and communicates via a backend:
```python
import torch.distributed as dist

# world_size and rank come from the environment variables
# that Mixtrain sets for each process (see below).
dist.init_process_group(
    backend="nccl",
    init_method="env://",
    world_size=world_size,
    rank=rank,
)
```

### Distributed Data Sampling
`DistributedSampler` partitions the dataset across GPUs:
```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
loader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
```
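When shuffling, call `set_epoch` at the start of every epoch so each replica draws a new shuffle that is still consistent across ranks. A minimal sketch, reusing the `dataset`, `world_size`, `rank`, `batch_size`, and `epochs` names from above:

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)

for epoch in range(epochs):
    # Re-seed the sampler so every rank reshuffles identically this epoch.
    sampler.set_epoch(epoch)
    for batch in loader:
        ...
```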
### Model Wrapping

Wrap your model with DDP:
```python
from torch.nn.parallel import DistributedDataParallel as DDP

model = DDP(model, device_ids=[local_rank])
```
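### Metrics Aggregation

DDP synchronizes gradients automatically during `backward()`, but metrics such as loss and accuracy still have to be aggregated across ranks explicitly. A minimal sketch using `dist.all_reduce`; the `loss`, `correct`, `total`, and `local_rank` values are placeholders for quantities computed in your training loop:

```python
import torch
import torch.distributed as dist

# Pack the local metrics from this rank into one tensor on the GPU
# so they can be reduced in a single collective call.
metrics = torch.tensor(
    [loss.item(), correct, total],
    dtype=torch.float64,
    device=f"cuda:{local_rank}",
)

# Sum across all ranks; every rank receives the reduced result.
dist.all_reduce(metrics, op=dist.ReduceOp.SUM)

if dist.get_rank() == 0:
    avg_loss = metrics[0].item() / dist.get_world_size()
    accuracy = metrics[1].item() / metrics[2].item()
    print(f"loss={avg_loss:.4f} acc={accuracy:.4f}")
```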
### Environment Variables

These are set automatically by Mixtrain:
| Variable | Description |
|---|---|
| `RANK` | Global rank (0 to world_size-1) |
| `WORLD_SIZE` | Total number of processes |
| `LOCAL_RANK` | Rank within the node |
| `MASTER_ADDR` | Master node IP |
| `MASTER_PORT` | Communication port |
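A minimal sketch of how a training script might consume these variables to bind each process to its GPU and join the process group (this setup code is illustrative, not part of the Mixtrain API):

```python
import os

import torch
import torch.distributed as dist

# Read the variables Mixtrain sets for each process.
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])

# Bind this process to its GPU before creating the process group.
torch.cuda.set_device(local_rank)

# init_method="env://" reads MASTER_ADDR and MASTER_PORT automatically.
dist.init_process_group(
    backend="nccl",
    init_method="env://",
    world_size=world_size,
    rank=rank,
)
```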
## Running
```bash
mixtrain workflow create multinode_pytorch_training.py \
  --name resnet-distributed \
  --description "Multi-node distributed ResNet training"

mixtrain workflow run resnet-distributed
```

## Example Sandbox Configuration
| Resource | Value |
|---|---|
| GPUs | 4x A100 (40GB each) |
| Memory | 40GB RAM per node |
| Time | ~1-2 hours for 100 epochs |
| Effective batch | 128 × 4 GPUs × 4 nodes = 2048 |
## Troubleshooting
### Training hangs

- Check network connectivity between nodes
- Verify environment variables (see the diagnostic sketch below)
- Try `backend="gloo"` instead of `"nccl"`
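A quick way to check the environment is to print the distributed variables on every rank and enable NCCL's debug logging before the process group is initialized; a diagnostic sketch, assumed to run at the top of the training entry point:

```python
import os

# Print the distributed environment on every rank so a misconfigured
# node stands out in the logs.
for var in ("RANK", "WORLD_SIZE", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT"):
    print(f"{var}={os.environ.get(var, '<unset>')}")

# Ask NCCL to log its initialization and communication steps.
os.environ.setdefault("NCCL_DEBUG", "INFO")
```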
### CUDA OOM
- Reduce `batch_size`
- Enable gradient checkpointing
- Use mixed precision training (see the sketch below)
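A minimal mixed precision sketch with `torch.cuda.amp`; the `model`, `optimizer`, `loader`, and `local_rank` names are assumed to come from the setup above:

```python
import torch
import torch.nn.functional as F

scaler = torch.cuda.amp.GradScaler()

for inputs, targets in loader:
    inputs = inputs.cuda(local_rank, non_blocking=True)
    targets = targets.cuda(local_rank, non_blocking=True)

    optimizer.zero_grad()
    # Run the forward pass in reduced precision where it is safe to do so.
    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, targets)

    # Scale the loss so small gradients do not underflow in fp16,
    # then unscale before the optimizer step.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```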
## Next Steps
- Workflows Guide - Full documentation
- PyTorch DDP Tutorial