from mixtrain import MixFlow, sandbox

MixFlow is the base class for building reusable ML workflows that run in Mixtrain's sandboxed environment.
Basic Structure
from mixtrain import MixFlow
class MyWorkflow(MixFlow):
def setup(self):
# Initialize resources (called once)
print("Initializing...")
def run(self, batch_size: int = 32):
"""Main workflow logic.
Args:
batch_size: Batch size for processing
"""
print(f"Processing with batch_size={batch_size}")
return {"status": "completed"}
def cleanup(self):
# Release resources (called after run)
print("Cleaning up...")

Lifecycle Methods
| Method | Required | Description |
|---|---|---|
| setup() | No | Initialize resources. Can receive inputs it declares in signature. Called before run(). |
| run() | Yes | Execute main workflow logic. Inputs defined in method signature. |
| cleanup() | No | Release resources, save artifacts. Called after run(). |
setup()
def setup(self):
"""Initialize resources before the main workflow runs."""
self.data = self._load_data()
self.model = self._load_model()

Or with inputs:
def setup(self, dataset_name: str):
"""Initialize with a specific dataset."""
self.data = self.mix.get_dataset(dataset_name)

run()
def run(self, batch_size: int = 32, epochs: int = 10):
"""Execute main workflow logic.
Args:
batch_size: Batch size for processing
epochs: Number of training epochs
"""
for epoch in range(epochs):
for batch in self.data.batches(batch_size):
result = self.model.process(batch)
self._save_result(result)
return {"status": "completed"}

cleanup()
def cleanup(self):
"""Release resources after workflow completes."""
self._save_final_results()
print("Workflow complete!")

Defining Inputs
Define inputs as parameters in the run() method signature:
from mixtrain import MixFlow
from typing import Literal
class MyWorkflow(MixFlow):
def run(
self,
input_path: str, # Required input (no default)
batch_size: int = 32, # Optional with default
use_gpu: bool = True,
mode: Literal["train", "eval"] = "train", # Dropdown options in UI
):
"""Process data pipeline.
Args:
input_path: Path to input data (required)
batch_size: Batch size for processing
use_gpu: Whether to use GPU acceleration
mode: Processing mode
"""
return {"status": "completed"}

Calling workflows
Both styles work:
workflow = MyWorkflow()
# Keyword arguments
result = workflow.run(input_path="/data", batch_size=64)
# Dict input
result = workflow.run({"input_path": "/data", "batch_size": 64})

Sandbox Configuration
Configure the runtime environment using the sandbox() function:
from mixtrain import MixFlow, sandbox
class GPUWorkflow(MixFlow):
_sandbox = sandbox(
image="nvcr.io/nvidia/pytorch:24.01-py3",
gpu="T4",
memory=8192,
timeout=1800,
)

| Parameter | Type | Description |
|---|---|---|
| image | str | Docker image |
| gpu | str | GPU type: "T4", "A10G", "L4", "A100", "H100", "H200", "B200" |
| cpu | int | Number of CPU cores |
| memory | int | Memory in MB |
| timeout | int | Timeout in seconds |
| idle_timeout | int | Idle timeout in seconds |
| num_nodes | int | Number of nodes for distributed training |
| cloud | str | Cloud provider preference |
| region | str | Region preference |
| ephemeral_disk | int | Ephemeral disk size in MB |
| block_network | bool | Block network access |
| mixtrain_version | str | Mixtrain SDK version (e.g., "0.1.23", ">=0.1.20") |
Accessing MixClient
Every MixFlow has a built-in self.mix client:
class MyWorkflow(MixFlow):
def run(self, model_name: str):
# Access Mixtrain resources
secret = self.mix.get_secret("api_key")
dataset = self.mix.get_dataset("my-dataset")
result = self.mix.run_model(model_name, {"prompt": "hello"})
return {"result": result}

Complete Example
from mixtrain import MixFlow, sandbox
import torch
class TrainingWorkflow(MixFlow):
_sandbox = sandbox(
image="nvcr.io/nvidia/pytorch:24.01-py3",
gpu="T4",
memory=8192,
timeout=3600,
)
def setup(self):
torch.manual_seed(42)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def run(
self,
dataset_name: str,
batch_size: int = 64,
epochs: int = 10,
learning_rate: float = 0.01,
):
"""Train a model on a dataset.
Args:
dataset_name: Dataset to train on
batch_size: Batch size
epochs: Training epochs
learning_rate: Learning rate
"""
dataset = self.mix.get_dataset(dataset_name)
data = self._prepare_data(dataset)
model = self._create_model().to(self.device)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(epochs):
total_loss = 0
for batch in data:
optimizer.zero_grad()
outputs = model(batch["inputs"].to(self.device))
loss = criterion(outputs, batch["labels"].to(self.device))
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(data):.4f}")
torch.save(model.state_dict(), "model.pt")
return {"status": "complete"}
def cleanup(self):
print("Training complete! Model saved.")