
Distributed Training

Train models across multiple GPUs and nodes using PyTorch DistributedDataParallel (DDP).

Overview

This workflow demonstrates:

  • PyTorch DistributedDataParallel (DDP)
  • Multi-node, multi-GPU training
  • Distributed data sampling
  • Synchronized gradient updates
  • Metrics aggregation

Configuration

from mixtrain import MixFlow, Sandbox
from typing import Literal


class MultiNodeTraining(MixFlow):
    # Multi-node sandbox configuration
    _sandbox = Sandbox(
        image="nvcr.io/nvidia/pytorch:24.01-py3",
        gpu="A100",
        memory=40960,   # MB of RAM per node (40 GB)
        num_nodes=4,
        timeout=7200,   # seconds (2 hours)
    )

    def run(
        self,
        dataset: str = "my-training-data",
        batch_size: int = 128,
        epochs: int = 100,
        learning_rate: float = 0.1,
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        backend: Literal["nccl", "gloo"] = "nccl",
    ):
        # Training logic here...
        pass

Key Concepts

Process Group Initialization

Each process (one per GPU) is assigned a unique rank and communicates with the others through a backend:

import torch.distributed as dist

dist.init_process_group(
    backend="nccl",
    init_method="env://",
    world_size=world_size,
    rank=rank
)
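
With init_method="env://", PyTorch reads MASTER_ADDR and MASTER_PORT from the environment, and falls back to the RANK and WORLD_SIZE variables when the world_size and rank arguments are omitted. Since Mixtrain sets all of these (see Environment Variables below), a minimal setup/teardown sketch can lean on them entirely; pair every init with a destroy at shutdown:

import os

import torch
import torch.distributed as dist


def setup():
    # Mixtrain exports these; see the Environment Variables table below.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    torch.cuda.set_device(local_rank)   # bind this process to its GPU
    dist.init_process_group("nccl", init_method="env://")
    return rank, world_size, local_rank


def cleanup():
    dist.destroy_process_group()   # release communicator resources at shutdown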

Distributed Data Sampling

DistributedSampler partitions the dataset across processes so that each rank trains on a distinct shard:

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
loader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
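
An easy detail to miss: call set_epoch on the sampler at the start of every epoch; otherwise each epoch reuses the same shuffle order on every rank:

for epoch in range(epochs):
    sampler.set_epoch(epoch)   # reseed so each epoch sees a fresh ordering
    for batch in loader:
        ...  # training step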

Model Wrapping

Wrap your model with DDP:

from torch.nn.parallel import DistributedDataParallel as DDP

model = model.to(local_rank)   # parameters must be on this rank's GPU first
model = DDP(model, device_ids=[local_rank])
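
Gradient synchronization then happens inside loss.backward() with no extra code. A minimal per-epoch loop follows, assuming the model, sampler, loader, optimizer, criterion, and local_rank from the snippets above, and using dist.all_reduce for the metrics aggregation mentioned in the overview:

import torch.distributed as dist

for epoch in range(epochs):
    sampler.set_epoch(epoch)
    for inputs, targets in loader:
        inputs = inputs.to(local_rank, non_blocking=True)
        targets = targets.to(local_rank, non_blocking=True)

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()   # DDP all-reduces gradients across ranks here
        optimizer.step()

    # Metrics aggregation: average the final batch loss across ranks for logging.
    avg_loss = loss.detach().clone()
    dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
    avg_loss /= dist.get_world_size()
    if dist.get_rank() == 0:
        print(f"epoch {epoch}: loss {avg_loss.item():.4f}")

Checkpoint from rank 0 only (for example, torch.save(model.module.state_dict(), ...) guarded by dist.get_rank() == 0) so multiple ranks never write the same file concurrently.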

Multi-GPU (Single Node)

For multi-GPU training on a single node, use gpu_per_node to request multiple GPUs:

import os
from mixtrain import MixFlow, Sandbox

class MultiGPUTraining(MixFlow):
    _sandbox = Sandbox(
        image="nvcr.io/nvidia/pytorch:25.12-py3",
        gpu="A100",
        gpu_per_node=4,   # 4 GPUs on one node
        memory=40960,
        timeout=3600,
    )

    def run(self):
        import torch.multiprocessing as mp

        gpu_per_node = int(os.environ.get("GPU_PER_NODE", "1"))
        mp.spawn(self._train_worker, nprocs=gpu_per_node, args=(gpu_per_node,))

    @staticmethod
    def _train_worker(local_rank, gpu_per_node):
        import torch
        import torch.distributed as dist

        node_rank = int(os.environ.get("RANK", "0"))
        num_nodes = int(os.environ.get("WORLD_SIZE", "1"))

        # Promote node-level env vars to per-GPU-process level
        os.environ["RANK"] = str(node_rank * gpu_per_node + local_rank)
        os.environ["LOCAL_RANK"] = str(local_rank)
        os.environ["WORLD_SIZE"] = str(num_nodes * gpu_per_node)

        torch.cuda.set_device(local_rank)
        dist.init_process_group("nccl")
        # ... training loop with DDP

The platform sets GPU_PER_NODE automatically. You handle multi-process spawning using standard tools (torch.multiprocessing.spawn, torchrun, etc.).
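
For instance, a torchrun launch under the node-level RANK/WORLD_SIZE semantics shown above might look like this sketch (train.py stands in for your training entry point); torchrun then exports per-process RANK, LOCAL_RANK, and WORLD_SIZE itself, so no manual promotion is needed:

torchrun \
  --nproc_per_node="$GPU_PER_NODE" \
  --nnodes="$WORLD_SIZE" \
  --node_rank="$RANK" \
  --master_addr="$MASTER_ADDR" \
  --master_port="$MASTER_PORT" \
  train.py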

Environment Variables

These are set automatically by Mixtrain:

Variable        Description
RANK            Global rank (0 to world_size-1)
WORLD_SIZE      Total number of processes
LOCAL_RANK      Rank within the node
MASTER_ADDR     Master node IP
MASTER_PORT     Communication port
GPU_PER_NODE    Number of GPUs per node (set when GPU is specified)

Running

mixtrain workflow create multinode_training.py \
  --name distributed-training \
  --description "Multi-node distributed training"

mixtrain workflow run distributed-training

Example Sandbox Configuration

Resource          Value
GPUs              4× A100 (40GB each) per node
Memory            40GB RAM per node
Time              ~1-2 hours for 100 epochs
Effective batch   128 × 4 GPUs × 4 nodes = 2048

Troubleshooting

Training hangs

  • Check network connectivity between nodes (NCCL debug logging helps; see the sketch below)
  • Verify the environment variables listed above
  • Try backend="gloo" instead of "nccl"
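
NCCL's own logging often pinpoints a stalled collective. NCCL_DEBUG is a standard NCCL environment variable and must be set before the process group is initialized:

import os

os.environ["NCCL_DEBUG"] = "INFO"   # print NCCL setup and collective activity
# Optional: pin the network interface if autodetection picks the wrong one
# (the interface name here is only an example).
# os.environ["NCCL_SOCKET_IFNAME"] = "eth0"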

CUDA OOM

  • Reduce batch_size
  • Enable gradient checkpointing
  • Use mixed precision training (see the sketch below)
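
For mixed precision, a minimal sketch using PyTorch automatic mixed precision (model, loader, optimizer, and criterion assumed from your training loop):

import torch

scaler = torch.cuda.amp.GradScaler()

for inputs, targets in loader:
    optimizer.zero_grad()
    # Run the forward pass in float16 where safe to reduce activation memory.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()   # scale loss to avoid float16 gradient underflow
    scaler.step(optimizer)
    scaler.update()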
