PyTorch Training

Train a PyTorch model on a Mixtrain dataset with GPU support and checkpointing.

Overview

This workflow demonstrates:

  • Loading data from a Mixtrain Dataset with Dataset.to_torch()
  • Training on a GPU with checkpointing
  • Training metrics logging

Workflow

from mixtrain import Dataset, MixFlow, Sandbox
import torch
import torch.nn as nn


class PyTorchTraining(MixFlow):
    _sandbox = Sandbox(
        image="nvcr.io/nvidia/pytorch:24.01-py3",
        gpu="T4",
        memory=8192,
        timeout=1800,
    )

    def run(
        self,
        dataset: str = "my-training-data",
        batch_size: int = 64,
        epochs: int = 10,
        learning_rate: float = 0.01,
        seed: int = 42,
    ):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load dataset and split
        ds = Dataset(dataset)
        splits = ds.shuffle(seed).train_test_split(test_size=0.2)
        train_loader = splits["train"].to_torch(batch_size=batch_size)
        test_loader = splits["test"].to_torch(batch_size=1000)

        # Placeholder model: replace YourModel with any nn.Module that maps
        # batch["image"] to class logits
        model = YourModel().to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

        # Training loop
        for epoch in range(epochs):
            model.train()
            for batch in train_loader:
                images = batch["image"].to(device)
                labels = batch["label"].to(device)
                optimizer.zero_grad()
                loss = torch.nn.functional.cross_entropy(model(images), labels)
                loss.backward()
                optimizer.step()
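
The workflow builds a test_loader but the loop shown does not use it. One way to evaluate after each epoch and log metrics is sketched below in plain PyTorch. ToyDictDataset and the small linear model are hypothetical stand-ins, assuming batches are dicts with "image" and "label" keys as in the training loop; this is not a Mixtrain logging API, just print-based logging.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset as TorchDataset


class ToyDictDataset(TorchDataset):
    """Hypothetical stand-in for a Mixtrain split: yields {"image", "label"} dicts."""

    def __init__(self, n: int = 64, num_classes: int = 3, seed: int = 0):
        g = torch.Generator().manual_seed(seed)
        self.images = torch.randn(n, 1, 8, 8, generator=g)
        self.labels = torch.randint(0, num_classes, (n,), generator=g)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {"image": self.images[idx], "label": self.labels[idx]}


@torch.no_grad()
def evaluate(model, loader, device):
    """Return mean cross-entropy loss and accuracy over every batch in loader."""
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    for batch in loader:
        images = batch["image"].to(device)
        labels = batch["label"].to(device)
        logits = model(images)
        # Sum per-sample losses so the mean is exact even if the last batch is smaller
        total_loss += F.cross_entropy(logits, labels, reduction="sum").item()
        correct += (logits.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
    return {"loss": total_loss / seen, "accuracy": correct / seen}


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Sequential(nn.Flatten(), nn.Linear(64, 3)).to(device)  # toy model
test_loader = DataLoader(ToyDictDataset(), batch_size=16)

metrics = evaluate(model, test_loader, device)
print(f"test_loss={metrics['loss']:.4f} test_accuracy={metrics['accuracy']:.4f}")
```

Inside run(), a call such as evaluate(model, test_loader, device) could go at the end of each epoch, after the inner batch loop.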

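The overview mentions checkpointing, which the loop above does not show. A minimal sketch using torch.save and torch.load follows. The helper names save_checkpoint and load_checkpoint, the dictionary keys, and the checkpoint path are all assumptions for illustration; where a Mixtrain sandbox should persist files is not covered here.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path, model, optimizer, epoch):
    """Write model and optimizer state plus the last finished epoch to path."""
    torch.save(
        {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
        },
        path,
    )


def load_checkpoint(path, model, optimizer, device):
    """Restore state in place and return the epoch to resume from."""
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    return ckpt["epoch"] + 1


# Round trip with a toy model (hypothetical; any nn.Module works the same way)
device = torch.device("cpu")
model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
ckpt_path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
save_checkpoint(ckpt_path, model, optimizer, epoch=3)

restored = nn.Linear(4, 2)
restored_opt = torch.optim.AdamW(restored.parameters(), lr=0.01)
start_epoch = load_checkpoint(ckpt_path, restored, restored_opt, device)
```

In the training loop, save_checkpoint would typically be called once per epoch, and a resumed run would iterate over range(start_epoch, epochs).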
Running

Create the workflow

mixtrain workflow create simple_pytorch_training.py \
  --name pytorch-training \
  --description "Simple PyTorch training workflow"

Run training

mixtrain workflow run pytorch-training

Custom parameters

mixtrain workflow run pytorch-training \
  --input '{"dataset": "my-new-training-data", "epochs": 20, "batch_size": 128, "learning_rate": 0.001}'
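
Hand-written JSON inside shell quotes is easy to get wrong. As a sketch, the same command can be assembled in Python with json.dumps and shlex.quote; only the command and the --input flag shown above are used, and the params dictionary simply mirrors the run() arguments.

```python
import json
import shlex

# Overrides for the run() parameters; keys must match the argument names
params = {
    "dataset": "my-new-training-data",
    "epochs": 20,
    "batch_size": 128,
    "learning_rate": 0.001,
}

payload = json.dumps(params)
# shlex.quote wraps the JSON so the shell passes it as a single argument
command = "mixtrain workflow run pytorch-training --input " + shlex.quote(payload)
print(command)
```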

Sandbox Configuration

Resource   Value
GPU        1x T4 (16GB VRAM)
Memory     8GB RAM
Timeout    30 minutes
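
These values correspond to the Sandbox arguments in the workflow. Comparing the two suggests that memory is given in megabytes and timeout in seconds; that reading is an inference from this page, not a documented guarantee. The arithmetic is:

```python
# Values copied from the Sandbox(...) call in the workflow
sandbox_args = {"gpu": "T4", "memory": 8192, "timeout": 1800}

# Assumed units: memory in MB (1024 MB per GB), timeout in seconds
memory_gb = sandbox_args["memory"] / 1024
timeout_min = sandbox_args["timeout"] / 60

print(f"{memory_gb:g} GB RAM, {timeout_min:g} minutes")
```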
