```python
from mixtrain import MixModel, sandbox
```

MixModel is the base class for creating models that run in Mixtrain.

## Basic Structure

```python
from mixtrain import MixModel

class MyModel(MixModel):
    def setup(self):
        # Initialize resources (called once)
        self.model = self._load_model()

    def run(self, text: str, temperature: float = 0.7):
        """Process a single request.

        Args:
            text: Text to process
            temperature: Sampling temperature
        """
        return {"result": self.model.process(text, temperature)}

    def run_batch(self, batch):
        # Process batch (optional, for efficiency)
        return [self.run(**item) for item in batch]
```

## Lifecycle Methods

| Method | Required | Description |
|---|---|---|
| `setup()` | No | Initialize resources. Can receive inputs declared in its signature. Called once. |
| `run()` | Yes | Process a single request. Inputs are defined in the method signature. |
| `run_batch(batch)` | No | Process multiple requests. The default implementation calls `run()` for each item. |
### setup()

```python
def setup(self):
    """Initialize resources. Called once when the model container starts."""
    self.model = load_my_model()
    self.tokenizer = load_tokenizer()
```

Or with inputs:

```python
def setup(self, model_path: str):
    """Initialize with a specific model path."""
    self.model = load_model(model_path)
```

### run()

```python
def run(self, prompt: str, max_tokens: int = 512) -> dict:
    """Process a single request.

    Args:
        prompt: Text prompt to process
        max_tokens: Maximum output tokens

    Returns:
        dict with output data
    """
    result = self.model.generate(prompt, max_tokens=max_tokens)
    return {"text": result}
```

### run_batch()

```python
def run_batch(self, batch: list[dict]) -> list[dict]:
    """Process multiple requests efficiently.

    Args:
        batch: list of input dicts

    Returns:
        list of output dicts
    """
    return [self.run(**item) for item in batch]
```
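Overriding `run_batch()` pays off when the underlying model can process several inputs in one call. A minimal sketch, reusing the `text` input from the Basic Structure example and assuming a hypothetical `self.model.process_many()` that accepts a list (not part of the Mixtrain API):

```python
def run_batch(self, batch: list[dict]) -> list[dict]:
    texts = [item["text"] for item in batch]
    # Hypothetical batched call; substitute your model's real batch API
    results = self.model.process_many(texts)
    return [{"result": r} for r in results]
```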
## Defining Inputs

Define inputs as parameters in the `run()` method signature:

```python
from mixtrain import MixModel
from typing import Literal

class MyModel(MixModel):
    def run(
        self,
        prompt: str,                                   # Required input
        temperature: float = 0.7,                      # Optional with default
        style: Literal["normal", "fancy"] = "normal",  # Dropdown options in UI
    ):
        """Generate text.

        Args:
            prompt: Text prompt to process
            temperature: Sampling temperature (0.0 to 1.0)
            style: Output style
        """
        return {"result": self.model.generate(prompt, temperature, style)}
```

## Calling Models

Both styles work:
```python
model = MyModel()

# Keyword arguments
result = model.run(prompt="hello", temperature=0.8)

# Dict input
result = model.run({"prompt": "hello", "temperature": 0.8})
```
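`run_batch()` takes the same inputs as a list of dicts; a small sketch with illustrative values:

```python
# Batch call: each dict is unpacked into a run() call by default
results = model.run_batch([
    {"prompt": "hello", "temperature": 0.8},
    {"prompt": "world"},
])
```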
## Sandbox Configuration

Configure the runtime environment using the `sandbox()` function:

```python
from mixtrain import MixModel, sandbox

class GPUModel(MixModel):
    _sandbox = sandbox(
        image="nvcr.io/nvidia/pytorch:24.01-py3",
        gpu="T4",
        memory=8192,
        timeout=300,
    )
```

| Parameter | Type | Description |
|---|---|---|
| `image` | str | Docker image |
| `gpu` | str | GPU type: `"T4"`, `"A10G"`, `"L4"`, `"A100"`, `"H100"`, `"H200"`, `"B200"` |
| `cpu` | int | Number of CPU cores |
| `memory` | int | Memory in MB |
| `timeout` | int | Timeout in seconds |
| `idle_timeout` | int | Idle timeout in seconds |
| `cloud` | str | Cloud provider preference |
| `region` | str | Region preference |
| `ephemeral_disk` | int | Ephemeral disk size in MB |
| `block_network` | bool | Block network access |
| `mixtrain_version` | str | Mixtrain SDK version (e.g., `"0.1.23"`, `">=0.1.20"`) |
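As a sketch of how the remaining parameters combine in a single configuration (the base image and all values below are illustrative, not recommendations):

```python
from mixtrain import MixModel, sandbox

class LockedDownModel(MixModel):
    _sandbox = sandbox(
        image="python:3.11-slim",     # illustrative CPU-only base image
        cpu=4,
        memory=16384,                 # 16 GB
        timeout=600,
        idle_timeout=120,
        ephemeral_disk=10240,         # 10 GB scratch space
        block_network=True,           # no outbound network access
        mixtrain_version=">=0.1.20",
    )
```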
## Accessing MixClient

Every `MixModel` has a built-in `self.mix` client:
```python
class MyModel(MixModel):
    def run(self, prompt: str):
        # Access other Mixtrain resources
        secret = self.mix.get_secret("api_key")
        dataset = self.mix.get_dataset("my-dataset")
        return {"result": "..."}
```
## Complete Example

```python
from mixtrain import MixModel, sandbox
import torch

class ImageClassifier(MixModel):
    _sandbox = sandbox(
        image="nvcr.io/nvidia/pytorch:24.01-py3",
        gpu="T4",
        memory=4096,
        timeout=60,
    )

    def setup(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None  # Loaded on first run

    def run(
        self,
        image_url: str,
        model_path: str,
        threshold: float = 0.5,
    ):
        """Classify an image.

        Args:
            image_url: Image URL to classify
            model_path: Path to model weights
            threshold: Confidence threshold
        """
        if self.model is None:
            self.model = torch.load(model_path).to(self.device)
            self.model.eval()

        image = self._download_and_preprocess(image_url)
        with torch.no_grad():
            output = self.model(image.to(self.device))
            probs = torch.softmax(output, dim=1)
            conf, pred = probs.max(1)

        return {
            "label": self.classes[pred.item()],
            "confidence": conf.item(),
        }

    def run_batch(self, batch):
        # More efficient batch processing
        images = [self._download_and_preprocess(b["image_url"]) for b in batch]
        batch_tensor = torch.stack(images).to(self.device)
        with torch.no_grad():
            outputs = self.model(batch_tensor)
            probs = torch.softmax(outputs, dim=1)
            confs, preds = probs.max(1)
        return [
            {"label": self.classes[p.item()], "confidence": c.item()}
            for p, c in zip(preds, confs)
        ]
```