Mixtrain Docs
from mixtrain import MixFlow, sandbox

MixFlow is the base class for building reusable ML workflows that run in Mixtrain's sandboxed environment.

Basic Structure

from mixtrain import MixFlow

class MyWorkflow(MixFlow):
    """Minimal MixFlow showing the three lifecycle hooks: setup, run, cleanup."""

    def setup(self):
        # Initialize resources (called once)
        print("Initializing...")

    def run(self, batch_size: int = 32):
        """Main workflow logic.

        Args:
            batch_size: Batch size for processing
        """
        print(f"Processing with batch_size={batch_size}")
        return {"status": "completed"}

    def cleanup(self):
        # Release resources (called after run)
        print("Cleaning up...")

Lifecycle Methods

| Method | Required | Description |
|---|---|---|
| `setup()` | No | Initialize resources. Can receive inputs it declares in its signature. Called before `run()`. |
| `run()` | Yes | Execute main workflow logic. Inputs defined in method signature. |
| `cleanup()` | No | Release resources, save artifacts. Called after `run()`. |

setup()

def setup(self):
    """Initialize resources before the main workflow runs."""
    # Cache data and model on the instance so run() can use them.
    self.data = self._load_data()
    self.model = self._load_model()

Or with inputs:

def setup(self, dataset_name: str):
    """Initialize with a specific dataset.

    Args:
        dataset_name: Name of the dataset fetched via the built-in client.
    """
    # self.mix is the built-in MixClient every MixFlow provides.
    self.data = self.mix.get_dataset(dataset_name)

run()

def run(self, batch_size: int = 32, epochs: int = 10):
    """Execute main workflow logic.

    Args:
        batch_size: Batch size for processing
        epochs: Number of training epochs

    Returns:
        A status dict handed back to the caller when the workflow finishes.
    """
    # self.data and self.model are expected to have been set in setup().
    for epoch in range(epochs):
        for batch in self.data.batches(batch_size):
            result = self.model.process(batch)
            self._save_result(result)
    return {"status": "completed"}

cleanup()

def cleanup(self):
    """Release resources after workflow completes."""
    # Persist results accumulated during run() before the workflow ends.
    self._save_final_results()
    print("Workflow complete!")

Defining Inputs

Define inputs as parameters in the run() method signature:

from mixtrain import MixFlow
from typing import Literal

class MyWorkflow(MixFlow):
    """Example of declaring workflow inputs as run() parameters."""

    def run(
        self,
        input_path: str,  # Required input (no default)
        batch_size: int = 32,  # Optional with default
        use_gpu: bool = True,
        mode: Literal["train", "eval"] = "train",  # Dropdown options in UI
    ):
        """Process data pipeline.

        Args:
            input_path: Path to input data (required)
            batch_size: Batch size for processing
            use_gpu: Whether to use GPU acceleration
            mode: Processing mode
        """
        return {"status": "completed"}

Calling Workflows

Both styles work:

workflow = MyWorkflow()

# Keyword arguments
result = workflow.run(input_path="/data", batch_size=64)

# Dict input
result = workflow.run({"input_path": "/data", "batch_size": 64})

Sandbox Configuration

Configure the runtime environment using the sandbox() function:

from mixtrain import MixFlow, sandbox

class GPUWorkflow(MixFlow):
    """Workflow whose sandbox requests an NVIDIA PyTorch image and a T4 GPU."""

    # Class-level sandbox configuration picked up by the Mixtrain runtime.
    _sandbox = sandbox(
        image="nvcr.io/nvidia/pytorch:24.01-py3",
        gpu="T4",
        memory=8192,  # MB
        timeout=1800,  # seconds
    )
| Parameter | Type | Description |
|---|---|---|
| `image` | str | Docker image |
| `gpu` | str | GPU type: "T4", "A10G", "L4", "A100", "H100", "H200", "B200" |
| `cpu` | int | Number of CPU cores |
| `memory` | int | Memory in MB |
| `timeout` | int | Timeout in seconds |
| `idle_timeout` | int | Idle timeout in seconds |
| `num_nodes` | int | Number of nodes for distributed training |
| `cloud` | str | Cloud provider preference |
| `region` | str | Region preference |
| `ephemeral_disk` | int | Ephemeral disk size in MB |
| `block_network` | bool | Block network access |
| `mixtrain_version` | str | Mixtrain SDK version (e.g., "0.1.23", ">=0.1.20") |

Accessing MixClient

Every MixFlow has a built-in self.mix client:

class MyWorkflow(MixFlow):
    """Example using the built-in self.mix client to reach Mixtrain resources."""

    def run(self, model_name: str):
        """Fetch a secret and a dataset, then invoke a model by name.

        Args:
            model_name: Name of the model passed to self.mix.run_model().
        """
        # Access Mixtrain resources
        secret = self.mix.get_secret("api_key")
        dataset = self.mix.get_dataset("my-dataset")
        result = self.mix.run_model(model_name, {"prompt": "hello"})
        return {"result": result}

Complete Example

from mixtrain import MixFlow, sandbox
import torch

class TrainingWorkflow(MixFlow):
    """End-to-end example: an SGD training loop over a Mixtrain dataset on a T4 GPU."""

    # Sandbox: NVIDIA PyTorch image, T4 GPU, 8192 MB memory, 1 h timeout.
    _sandbox = sandbox(
        image="nvcr.io/nvidia/pytorch:24.01-py3",
        gpu="T4",
        memory=8192,
        timeout=3600,
    )

    def setup(self):
        # Fixed seed for reproducible runs; prefer GPU when available.
        torch.manual_seed(42)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def run(
        self,
        dataset_name: str,
        batch_size: int = 64,
        epochs: int = 10,
        learning_rate: float = 0.01,
    ):
        """Train a model on a dataset.

        Args:
            dataset_name: Dataset to train on
            batch_size: Batch size
            epochs: Training epochs
            learning_rate: Learning rate
        """
        dataset = self.mix.get_dataset(dataset_name)
        # NOTE(review): batch_size is accepted but never used below —
        # presumably _prepare_data should receive it; confirm its signature.
        data = self._prepare_data(dataset)

        model = self._create_model().to(self.device)
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
        criterion = torch.nn.CrossEntropyLoss()

        for epoch in range(epochs):
            total_loss = 0
            for batch in data:
                optimizer.zero_grad()
                # assumes each batch is a dict with "inputs"/"labels" tensors — TODO confirm
                outputs = model(batch["inputs"].to(self.device))
                loss = criterion(outputs, batch["labels"].to(self.device))
                loss.backward()
                optimizer.step()
                # .item() extracts the Python scalar so the graph is not retained.
                total_loss += loss.item()

            # Mean loss per batch; requires data to support len().
            print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(data):.4f}")

        torch.save(model.state_dict(), "model.pt")
        return {"status": "complete"}

    def cleanup(self):
        print("Training complete! Model saved.")

On this page