Distributed Training

Train models across multiple GPUs and nodes using PyTorch DistributedDataParallel (DDP).

Overview

This workflow demonstrates:

  • PyTorch DistributedDataParallel (DDP)
  • Multi-node, multi-GPU training
  • ResNet on CIFAR-10
  • Distributed data sampling
  • Synchronized gradient updates
  • Metrics aggregation

Configuration

from mixtrain import MixFlow, sandbox
from typing import Literal


class MultiNodeTraining(MixFlow):
    # Multi-node sandbox configuration
    _sandbox = sandbox(
        image="nvcr.io/nvidia/pytorch:24.01-py3",
        gpu="A100",
        memory=40960,
        num_nodes=4,
        timeout=7200,
    )

    def run(
        self,
        model_name: Literal["resnet18", "resnet50", "resnet101"] = "resnet50",
        num_classes: int = 10,
        batch_size: int = 128,
        epochs: int = 100,
        learning_rate: float = 0.1,
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        backend: Literal["nccl", "gloo"] = "nccl",
    ):
        """Execute distributed training.

        Args:
            model_name: Model architecture
            num_classes: Number of output classes
            batch_size: Per-GPU batch size
            epochs: Training epochs
            learning_rate: Initial learning rate
            momentum: SGD momentum
            weight_decay: Weight decay
            backend: Distributed backend
        """
        # Training logic here...
        pass

Key Concepts

Process Group Initialization

Each process (one per GPU) is assigned a unique rank and communicates with the other processes through a backend:

import torch.distributed as dist

dist.init_process_group(
    backend="nccl",
    init_method="env://",
    world_size=world_size,
    rank=rank
)
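
In a complete setup each process also binds itself to its assigned GPU before initialization and tears the group down when training finishes. A minimal sketch, assuming the environment variables described below are set (with init_method="env://", PyTorch reads RANK and WORLD_SIZE from the environment):

import os
import torch
import torch.distributed as dist

local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)      # bind this process to its local GPU

dist.init_process_group(backend="nccl", init_method="env://")

# ... training loop ...

dist.destroy_process_group()           # clean shutdown after training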

Distributed Data Sampling

DistributedSampler partitions the dataset across GPUs:

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
loader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
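
When shuffling, call set_epoch on the sampler at the start of every epoch; otherwise each process sees the same ordering every epoch. A short sketch:

for epoch in range(epochs):
    sampler.set_epoch(epoch)  # vary the shuffle across epochs
    for images, labels in loader:
        ...  # forward/backward as usual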

Model Wrapping

Wrap your model with DDP:

from torch.nn.parallel import DistributedDataParallel as DDP

model = DDP(model, device_ids=[local_rank])
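
The model should already be on this process's GPU before wrapping; once wrapped, loss.backward() all-reduces gradients across processes automatically. A minimal sketch (model and local_rank assumed from the steps above):

model = model.to(local_rank)                 # move to this process's GPU first
model = DDP(model, device_ids=[local_rank])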

Environment Variables

These are set automatically by Mixtrain:

Variable      Description
RANK          Global rank (0 to world_size-1)
WORLD_SIZE    Total number of processes
LOCAL_RANK    Rank within the node
MASTER_ADDR   Master node IP
MASTER_PORT   Communication port
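
Inside the training code these are typically read with os.environ; a minimal sketch of the usual pattern (MASTER_ADDR and MASTER_PORT are consumed by init_method="env://" and rarely need to be read directly):

import os

rank = int(os.environ["RANK"])              # global rank of this process
world_size = int(os.environ["WORLD_SIZE"])  # total number of processes
local_rank = int(os.environ["LOCAL_RANK"])  # GPU index within this node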

Running

mixtrain workflow create multinode_pytorch_training.py \
  --name resnet-distributed \
  --description "Multi-node distributed ResNet training"

mixtrain workflow run resnet-distributed

Example Sandbox Configuration

Resource          Value
GPUs              4x A100 (40GB each)
Memory            40GB RAM per node
Time              ~1-2 hours for 100 epochs
Effective batch   128 × 4 GPUs × 4 nodes = 2048

Troubleshooting

Training hangs

  • Check network connectivity between nodes
  • Verify environment variables
  • Try backend="gloo" instead of "nccl"

CUDA OOM

  • Reduce batch_size
  • Enable gradient checkpointing
  • Use mixed precision training (see the sketch below)
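
For the mixed precision suggestion, a minimal sketch using PyTorch automatic mixed precision; model, optimizer, loader, and local_rank are assumed from the setup above:

import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for images, labels in loader:
    images = images.to(local_rank)
    labels = labels.to(local_rank)
    optimizer.zero_grad()
    with autocast():                         # run the forward pass in reduced precision
        outputs = model(images)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
    scaler.scale(loss).backward()            # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)
    scaler.update()

Running the forward pass in half precision roughly halves activation memory, which is often enough to avoid OOM without reducing batch_size.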

Next Steps
