# PyTorch Training
Train a simple Convolutional Neural Network on the MNIST dataset.
## Overview
This workflow demonstrates:
- Loading data from Mixtrain Dataset or torchvision
- Using `Dataset.to_torch()` for zero-copy PyTorch integration
- Training and evaluation
- Model checkpointing
- Training metrics logging
## Configuration
```python
from mixtrain import Dataset, MixFlow, sandbox

class SimplePyTorchTraining(MixFlow):
    # Sandbox configuration
    _sandbox = sandbox(
        image="nvcr.io/nvidia/pytorch:24.01-py3",
        gpu="T4",
        memory=8192,
        timeout=1800,
    )

    def run(
        self,
        dataset: Dataset | None = None,  # Uses torchvision MNIST if not provided
        batch_size: int = 64,
        epochs: int = 10,
        learning_rate: float = 0.01,
        seed: int = 42,
    ):
        import torch
        # ... training logic
```
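The body of `run()` is elided above. For orientation, a typical opening would pin the seed and pick a device before building the model and loaders. The sketch below is illustrative standard PyTorch, not Mixtrain API:

```python
import torch

# Illustrative sketch of how run() might begin (assumed, not Mixtrain API):
# pin the seed for reproducibility, then pick a device.
torch.manual_seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# SimpleCNN is defined in the Model Architecture section below.
model = SimpleCNN().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
```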
## Loading Data

### From Mixtrain Dataset
Use `Dataset.to_torch()` for zero-copy tensor conversion:
```python
from mixtrain import Dataset

# Load dataset and split
ds = Dataset("my-training-data")
splits = ds.shuffle(42).train_test_split(test_size=0.2)

# Get PyTorch DataLoaders with zero-copy tensors
train_loader = splits["train"].to_torch(batch_size=64)
test_loader = splits["test"].to_torch(batch_size=1000)

# Training loop - batches are dicts of tensors
for batch in train_loader:
    images = batch["image"].to(device)  # tensor
    labels = batch["label"].to(device)  # tensor
    # ... training step
```
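Depending on how images are stored in your Dataset, the zero-copy tensors may need casting and reshaping before they match the CNN input defined below (`N x 1 x 28 x 28`, normalized floats). A hedged sketch, assuming uint8 images of shape `(N, 28, 28)`:

```python
# Assumption: batch["image"] arrives as uint8 with shape (N, 28, 28).
# Adjust to however your Dataset actually stores images.
images = batch["image"].to(device).float().div(255.0)  # scale to [0, 1]
images = images.unsqueeze(1)                           # (N, 1, 28, 28) for Conv2d
images = (images - 0.1307) / 0.3081                    # MNIST mean/std (see torchvision below)
```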
### From torchvision

Standard PyTorch approach:
```python
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Normalize with MNIST's canonical mean and std
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
```

## Model Architecture
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)  # 64 channels x 12 x 12 after pooling
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        return F.log_softmax(self.fc2(x), dim=1)
```
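To tie the pieces together, here is a minimal sketch of a training and evaluation loop with per-epoch checkpointing, in plain PyTorch. It assumes the dict-style batches from `Dataset.to_torch()` shown above (with torchvision `DataLoader`s, unpack `(images, labels)` tuples instead) and logs metrics with `print`; wiring these into Mixtrain's metrics logging is not shown here.

```python
import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # learning_rate default

for epoch in range(1, 10 + 1):  # epochs default
    # -- Training --
    model.train()
    for batch in train_loader:
        images, labels = batch["image"].to(device), batch["label"].to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(images), labels)  # NLL loss pairs with log_softmax output
        loss.backward()
        optimizer.step()

    # -- Evaluation --
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch in test_loader:
            images, labels = batch["image"].to(device), batch["label"].to(device)
            correct += (model(images).argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    print(f"epoch {epoch}: loss={loss.item():.4f} accuracy={correct / total:.4f}")

    # -- Checkpointing: save enough state to resume training --
    torch.save(
        {"epoch": epoch, "model": model.state_dict(), "optimizer": optimizer.state_dict()},
        f"checkpoint_epoch_{epoch}.pt",
    )
```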
## Running

### Create the workflow

```bash
mixtrain workflow create simple_pytorch_training.py \
    --name mnist-training \
    --description "Simple CNN training on MNIST"
```

### Run with torchvision MNIST
```bash
mixtrain workflow run mnist-training
```

### Run with Mixtrain Dataset
```bash
mixtrain workflow run mnist-training \
    --input '{"dataset": "my-mnist-data", "epochs": 20}'
```

### Custom parameters
```bash
mixtrain workflow run mnist-training \
    --input '{"epochs": 20, "batch_size": 128, "learning_rate": 0.001}'
```

## Example Sandbox Configuration
| Resource | Value |
|---|---|
| GPU | 1x T4 (16GB VRAM) |
| Memory | 8GB RAM |
| Training time | ~10-15 minutes for 10 epochs |
| Cost | ~$0.50/hour |

At ~$0.50/hour, a full 10-epoch run costs roughly $0.08-$0.13.
## Next Steps
- Distributed Training - Scale to multiple GPUs
- Datasets Guide - Dataset SDK documentation
- Workflows Guide - Full workflow documentation