PyTorch Training

Train a simple convolutional neural network (CNN) to classify handwritten digits from the MNIST dataset.

Overview

This workflow demonstrates:

  • Loading data from Mixtrain Dataset or torchvision
  • Using Dataset.to_torch() for zero-copy PyTorch integration
  • Training and evaluation
  • Model checkpointing
  • Training metrics logging

Configuration

from mixtrain import Dataset, MixFlow, sandbox


class SimplePyTorchTraining(MixFlow):
    # Sandbox configuration
    _sandbox = sandbox(
        image="nvcr.io/nvidia/pytorch:24.01-py3",
        gpu="T4",
        memory=8192,
        timeout=1800,
    )

    def run(
        self,
        dataset: Dataset | None = None,  # Uses torchvision MNIST if not provided
        batch_size: int = 64,
        epochs: int = 10,
        learning_rate: float = 0.01,
        seed: int = 42,
    ):
        import torch
        # ... training logic

Loading Data

From Mixtrain Dataset

Use Dataset.to_torch() for zero-copy tensor conversion:

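import torch
from mixtrain import Dataset

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")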

# Load dataset and split
ds = Dataset("my-training-data")
splits = ds.shuffle(42).train_test_split(test_size=0.2)

# Get PyTorch DataLoaders with zero-copy tensors
train_loader = splits["train"].to_torch(batch_size=64)
test_loader = splits["test"].to_torch(batch_size=1000)

# Training loop - batches are dicts of tensors
for batch in train_loader:
    images = batch["image"].to(device)  # tensor
    labels = batch["label"].to(device)  # tensor
    # ... training step
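The training step itself is elided in the source. The following is a minimal sketch of one way to write it, assuming batches are dicts with "image" and "label" keys as in the loop above and that the model returns log-probabilities (for example via log_softmax). The helper names train_one_epoch, evaluate, and save_checkpoint are hypothetical and not part of Mixtrain.

```python
import torch
import torch.nn.functional as F


def train_one_epoch(model, loader, optimizer, device):
    """Run one pass over loader and return the mean training loss."""
    model.train()
    total_loss, n_samples = 0.0, 0
    for batch in loader:
        images = batch["image"].to(device)
        labels = batch["label"].to(device)
        optimizer.zero_grad()
        # nll_loss expects log-probabilities from the model
        loss = F.nll_loss(model(images), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * labels.size(0)
        n_samples += labels.size(0)
    return total_loss / n_samples


@torch.no_grad()
def evaluate(model, loader, device):
    """Return classification accuracy over loader."""
    model.eval()
    correct, n_samples = 0, 0
    for batch in loader:
        images = batch["image"].to(device)
        labels = batch["label"].to(device)
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        n_samples += labels.size(0)
    return correct / n_samples


def save_checkpoint(model, optimizer, epoch, path):
    """Save model and optimizer state so training can resume later."""
    torch.save(
        {
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        path,
    )
```

A typical epoch would call train_one_epoch, then evaluate on the test loader, log both numbers, and call save_checkpoint with a path of your choosing.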

From torchvision

Alternatively, load MNIST with the standard torchvision approach:

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

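# Convert images to tensors, then normalize with the MNIST mean and std
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Download the training split on first use and shuffle it every epoch
train_dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)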
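To make the normalization concrete: transforms.ToTensor() scales pixel values to the range [0, 1], and transforms.Normalize((0.1307,), (0.3081,)) then applies (x - mean) / std per channel, using the commonly cited MNIST mean and standard deviation. A plain-Python sketch of that arithmetic follows; the helper name normalize_pixel is illustrative only.

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize_pixel(raw: int) -> float:
    """Map a raw 0-255 pixel value to its normalized value."""
    scaled = raw / 255.0                      # what ToTensor does
    return (scaled - MNIST_MEAN) / MNIST_STD  # what Normalize does


black = normalize_pixel(0)
white = normalize_pixel(255)
print(f"black -> {black:.4f}, white -> {white:.4f}")
```

Background pixels end up slightly below zero and ink pixels well above it, which roughly centers the inputs around zero for training.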

Model Architecture

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)   # 1x28x28 -> 32x26x26
        self.conv2 = nn.Conv2d(32, 64, 3, 1)  # 32x26x26 -> 64x24x24
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)       # 9216 = 64 * 12 * 12 after pooling
        self.fc2 = nn.Linear(128, 10)         # 10 digit classes

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        return F.log_softmax(self.fc2(x), dim=1)
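The input size of fc1 (9216) follows from the layer shapes: two unpadded 3x3 convolutions with stride 1 followed by a 2x2 max pool. The sketch below reproduces that arithmetic in plain Python using the standard output-size formula; the helper conv_out is illustrative and not part of PyTorch.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a square convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                                   # MNIST images are 28x28
side = conv_out(side, kernel=3)             # conv1 -> 26
side = conv_out(side, kernel=3)             # conv2 -> 24
side = conv_out(side, kernel=2, stride=2)   # max_pool2d(x, 2) -> 12
flat_features = 64 * side * side            # 64 output channels from conv2
print(flat_features)                        # 9216
```

If the input resolution or any kernel size changes, the first argument of nn.Linear in fc1 has to be recomputed the same way.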

Running

Create the workflow

mixtrain workflow create simple_pytorch_training.py \
  --name mnist-training \
  --description "Simple CNN training on MNIST"

Run with torchvision MNIST

mixtrain workflow run mnist-training

Run with Mixtrain Dataset

mixtrain workflow run mnist-training \
  --input '{"dataset": "my-mnist-data", "epochs": 20}'

Custom parameters

mixtrain workflow run mnist-training \
  --input '{"epochs": 20, "batch_size": 128, "learning_rate": 0.001}'
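The keys in the --input JSON appear to correspond to the keyword arguments of run() shown under Configuration. When generating these commands from a script, building the payload with json.dumps and quoting it with shlex.quote avoids hand-escaping. This is only a sketch: the CLI invocation is taken from the examples above, while the build_run_command helper is hypothetical.

```python
import json
import shlex


def build_run_command(workflow: str, **params) -> str:
    """Return a shell-safe `mixtrain workflow run` command line."""
    payload = json.dumps(params)
    return (
        f"mixtrain workflow run {shlex.quote(workflow)} "
        f"--input {shlex.quote(payload)}"
    )


cmd = build_run_command(
    "mnist-training", epochs=20, batch_size=128, learning_rate=0.001
)
print(cmd)
```

The resulting string can be pasted into a shell or split with shlex.split before being passed to subprocess.run.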

Example Sandbox Configuration

Resource    Value
GPU         1x T4 (16 GB VRAM)
Memory      8 GB RAM
Time        ~10-15 minutes for 10 epochs
Cost        ~$0.50/hour
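Combining the Time and Cost rows gives a rough per-run estimate. The following back-of-the-envelope sketch assumes billing is proportional to wall-clock time; the source does not state the billing granularity, so treat the result as an approximation.

```python
HOURLY_RATE_USD = 0.50   # from the Cost row above
RUN_MINUTES = (10, 15)   # estimated range for 10 epochs
TIMEOUT_SECONDS = 1800   # timeout from the sandbox configuration

low, high = (HOURLY_RATE_USD * minutes / 60 for minutes in RUN_MINUTES)
headroom_minutes = TIMEOUT_SECONDS / 60 - RUN_MINUTES[1]

print(f"Estimated cost per run: ${low:.3f} to ${high:.3f}")
print(f"Timeout headroom over the slow estimate: {headroom_minutes:.0f} minutes")
```

Under these assumptions a default 10-epoch run costs roughly 8 to 13 cents, and the 1800-second timeout leaves about 15 minutes of slack over the slower estimate.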
