Mixtrain Docs
from mixtrain import MixModel, sandbox

MixModel is the base class for creating models that run in Mixtrain.

Basic Structure

from mixtrain import MixModel

class MyModel(MixModel):
    """Minimal MixModel showing all three lifecycle hooks."""

    def setup(self):
        # One-time initialization: called once, not per request.
        self.model = self._load_model()

    def run(self, text: str, temperature: float = 0.7):
        """Handle one request.

        Args:
            text: Text to process
            temperature: Sampling temperature
        """
        processed = self.model.process(text, temperature)
        return {"result": processed}

    def run_batch(self, batch):
        # Optional override (for efficiency); this version just forwards
        # each input dict to run() as keyword arguments.
        return [self.run(**inputs) for inputs in batch]

Lifecycle Methods

| Method | Required | Description |
|---|---|---|
| `setup()` | No | Initialize resources. Can receive inputs it declares in its signature. Called once. |
| `run()` | Yes | Process a single request. Inputs are defined in the method signature. |
| `run_batch(batch)` | No | Process multiple requests. The default calls `run()` for each item. |

setup()

def setup(self):
    """Initialize resources. Called once when the model container starts.

    Attributes assigned here (``self.model``, ``self.tokenizer``) stay on
    the instance, so later ``run()`` calls can use them without reloading.
    """
    self.model = load_my_model()
    self.tokenizer = load_tokenizer()

Or with inputs:

def setup(self, model_path: str):
    """Initialize with a specific model path.

    ``model_path`` is an input declared in setup()'s own signature;
    setup() can receive such inputs and, like the no-argument form,
    is called once.
    """
    self.model = load_model(model_path)

run()

def run(self, prompt: str, max_tokens: int = 512) -> dict:
    """Handle a single request.

    Args:
        prompt: Text prompt to process
        max_tokens: Maximum output tokens

    Returns:
        dict holding the generated output under the "text" key
    """
    return {"text": self.model.generate(prompt, max_tokens=max_tokens)}

run_batch()

def run_batch(self, batch: list[dict]) -> list[dict]:
    """Handle several requests in one call.

    This is the default behavior: every input dict is forwarded to
    ``run()`` as keyword arguments and the outputs are kept in order.

    Args:
        batch: list of input dicts

    Returns:
        list of output dicts, one per input
    """
    outputs = []
    for inputs in batch:
        outputs.append(self.run(**inputs))
    return outputs

Defining Inputs

Define inputs as parameters in the run() method signature:

from mixtrain import MixModel
from typing import Literal

class MyModel(MixModel):
    def run(
        self,
        prompt: str,  # Required: has no default value
        temperature: float = 0.7,  # Optional: falls back to its default
        style: Literal["normal", "fancy"] = "normal",  # Shown as dropdown options in the UI
    ):
        """Generate text from a prompt.

        Args:
            prompt: Text prompt to process
            temperature: Sampling temperature (0.0 to 1.0)
            style: Output style
        """
        generated = self.model.generate(prompt, temperature, style)
        return {"result": generated}

Calling models

Both calling styles work — pass inputs as keyword arguments or as a single dict:

model = MyModel()

# Keyword arguments: each input is passed by name
result = model.run(prompt="hello", temperature=0.8)

# Dict input: the same inputs passed together as one dict
result = model.run({"prompt": "hello", "temperature": 0.8})

Sandbox Configuration

Configure the runtime environment using the sandbox() function:

from mixtrain import MixModel, sandbox

class GPUModel(MixModel):
    # Runtime environment for this model (parameters are listed in the table below).
    _sandbox = sandbox(
        image="nvcr.io/nvidia/pytorch:24.01-py3",  # Docker image
        gpu="T4",  # GPU type
        memory=8192,  # in MB
        timeout=300,  # in seconds
    )

| Parameter | Type | Description |
|---|---|---|
| `image` | `str` | Docker image |
| `gpu` | `str` | GPU type: `"T4"`, `"A10G"`, `"L4"`, `"A100"`, `"H100"`, `"H200"`, `"B200"` |
| `cpu` | `int` | Number of CPU cores |
| `memory` | `int` | Memory in MB |
| `timeout` | `int` | Timeout in seconds |
| `idle_timeout` | `int` | Idle timeout in seconds |
| `cloud` | `str` | Cloud provider preference |
| `region` | `str` | Region preference |
| `ephemeral_disk` | `int` | Ephemeral disk size in MB |
| `block_network` | `bool` | Block network access |
| `mixtrain_version` | `str` | Mixtrain SDK version (e.g., `"0.1.23"`, `">=0.1.20"`) |

Accessing MixClient

Every MixModel has a built-in `self.mix` client (a MixClient) for accessing other Mixtrain resources:

class MyModel(MixModel):
    def run(self, prompt: str):
        # self.mix is the built-in client for other Mixtrain resources,
        # e.g. stored secrets and datasets.
        api_key = self.mix.get_secret("api_key")
        ds = self.mix.get_dataset("my-dataset")
        return {"result": "..."}

Complete Example

from mixtrain import MixModel, sandbox
import torch

class ImageClassifier(MixModel):
    """Image classifier whose weights are loaded lazily on the first request.

    ``self._download_and_preprocess(url)`` (not shown; implement it yourself)
    must return one *unbatched* image tensor (e.g. C x H x W): ``run()`` adds
    the batch dimension and ``run_batch()`` stacks several of them. Optionally
    set a ``classes`` sequence (class index -> label name) for readable labels.
    """

    _sandbox = sandbox(
        image="nvcr.io/nvidia/pytorch:24.01-py3",
        gpu="T4",
        memory=4096,
        timeout=60,
    )

    def setup(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None  # Loaded on the first run() / run_batch() call

    def _ensure_model(self, model_path: str):
        """Load the model from ``model_path`` once and cache it on ``self``."""
        if self.model is None:
            # weights_only=False is required to unpickle a full nn.Module
            # (PyTorch 2.6+ defaults to weights_only=True). Unpickling can run
            # arbitrary code, so only load model paths you trust.
            # map_location lets GPU-saved weights load on a CPU-only host.
            model = torch.load(
                model_path, map_location=self.device, weights_only=False
            ).to(self.device)
            model.eval()
            self.model = model  # only cache once loading fully succeeded
        return self.model

    def _format_result(self, pred_index: int, confidence: float, threshold: float) -> dict:
        """Build the response dict shared by run() and run_batch()."""
        classes = getattr(self, "classes", None)
        # Fall back to the class index as a string when no label names are set.
        label = classes[pred_index] if classes is not None else str(pred_index)
        return {
            "label": label,
            "confidence": confidence,
            "above_threshold": confidence >= threshold,
        }

    def run(
        self,
        image_url: str,
        model_path: str,
        threshold: float = 0.5
    ):
        """Classify an image.

        Args:
            image_url: Image URL to classify
            model_path: Path to model weights (only read on the first call)
            threshold: Confidence threshold

        Returns:
            dict with "label", "confidence" and "above_threshold"
            (True when confidence >= threshold).
        """
        model = self._ensure_model(model_path)
        image = self._download_and_preprocess(image_url)

        with torch.no_grad():
            # unsqueeze(0) adds the leading batch dimension so a single image
            # has the same layout as the stacked tensor in run_batch().
            output = model(image.unsqueeze(0).to(self.device))
            probs = torch.softmax(output, dim=1)
            conf, pred = probs.max(1)

        return self._format_result(pred.item(), conf.item(), threshold)

    def run_batch(self, batch):
        """Classify several images in one forward pass (more efficient).

        Args:
            batch: list of input dicts with the same keys as run()'s
                parameters. If the model is not loaded yet it is loaded from
                the first item's "model_path"; other items' paths are ignored.
                A missing "threshold" defaults to 0.5, as in run().

        Returns:
            list of result dicts, in the same order as ``batch``.
        """
        if not batch:
            return []  # torch.stack() rejects an empty list

        model = self._ensure_model(batch[0]["model_path"])
        images = [self._download_and_preprocess(item["image_url"]) for item in batch]
        batch_tensor = torch.stack(images).to(self.device)

        with torch.no_grad():
            outputs = model(batch_tensor)
            probs = torch.softmax(outputs, dim=1)
            confs, preds = probs.max(1)

        return [
            self._format_result(p.item(), c.item(), item.get("threshold", 0.5))
            for item, p, c in zip(batch, preds, confs)
        ]

On this page