Mixtrain provides dataset management with rich multimodal support.
Quick Start
from mixtrain import Dataset
# Create a dataset from local metadata. Local image files are uploaded on save.
photos = Dataset.from_dict({
    "image": ["images/cat.jpg", "images/dog.jpg"],
    "caption": ["a cat on a sofa", "a dog in a park"],
    "label": [0, 1],
})
photos.save("pet-photos")
# Read it back, filter it, and stream batches for training.
dataset = Dataset("pet-photos")
cleaned_dataset = dataset.filter("image.width >= 512")
for batch in cleaned_dataset.to_torch(batch_size=32):
    images = batch["image"]
    captions = batch["caption"]
This dataset will be available on Mixtrain, where you can view images, inspect columns, and query the data. Typical usage in training:
from mixtrain import Dataset
# Load and prepare data
ds = Dataset("training-data")
splits = ds.shuffle(42).train_test_split(test_size=0.2) # returns {"train": ..., "test": ...}
# Get PyTorch DataLoaders
train_loader = splits["train"].to_torch(batch_size=32)
val_loader = splits["test"].to_torch(batch_size=32)
# Training loop
for batch in train_loader:
    inputs = batch["features"]  # tensor
    labels = batch["labels"]  # tensor
    # ... training step
Dataset Versions
Every update to a dataset creates a new version. By default, the latest version is used on all reads.
from mixtrain import Dataset
latest = Dataset("training-data")  # latest version
Read a specific version of the dataset:
baseline = Dataset("training-data", version=3)
for batch in baseline.batches(64):
    train(batch)  # your training step
You can also use the list_versions() method to list all versions of the dataset, and you can browse and query every version in the web UI.
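One way to picture versioned reads is a log of immutable snapshots: each update appends a new one, an unpinned read returns the newest, and a pinned read keeps returning the same rows. The toy class below is a hypothetical in-memory sketch of that behavior, not Mixtrain's storage model; the class name, the 1-based version numbering, and every method name other than list_versions are assumptions made for illustration.

```python
from __future__ import annotations

from typing import Optional


class ToyVersionedTable:
    """Hypothetical in-memory model: every update stores a new snapshot."""

    def __init__(self) -> None:
        self._snapshots: list[tuple[dict, ...]] = []  # _snapshots[0] is version 1

    def update(self, rows: list[dict]) -> int:
        # Copy the rows so later edits to the caller's data cannot alter this version.
        self._snapshots.append(tuple(dict(row) for row in rows))
        return len(self._snapshots)  # number of the version just created

    def list_versions(self) -> list[int]:
        return list(range(1, len(self._snapshots) + 1))

    def read(self, version: Optional[int] = None) -> tuple[dict, ...]:
        # version=None reads the latest version, like the default read behavior.
        if version is None:
            version = len(self._snapshots)
        if not 1 <= version <= len(self._snapshots):
            raise ValueError(f"unknown version: {version}")
        return self._snapshots[version - 1]


table = ToyVersionedTable()
table.update([{"x": 1}])
table.update([{"x": 1}, {"x": 2}])
print(table.list_versions())       # [1, 2]
print(len(table.read()))           # 2 (latest)
print(len(table.read(version=1)))  # 1 (pinned)
```

Pinning a version in this model is what makes an experiment reproducible: later updates add snapshots but never rewrite the one you pinned.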
Incremental reads
Use added_since() to stream rows added between two versions. from_version is exclusive and to_version is inclusive:
ds = Dataset("photos")
for batch in ds.added_since(from_version=7, to_version=10):  # rows added in versions 8, 9, and 10
    embed(batch)
to_version=None means the latest version, and from_version=None starts at the earliest version. Dataset-triggered routines use the same mechanism to process only newly appended rows.
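The bounds form a half-open interval (from_version, to_version], so consecutive reads can hand off at a shared boundary without reprocessing or skipping rows. The pure-Python sketch below models that contract over a hypothetical mapping from version number to the rows appended in that version; the function and its input format are illustrative only and are not part of the Mixtrain SDK.

```python
from typing import Dict, Iterator, List, Optional


def rows_added_between(
    appended: Dict[int, List[dict]],
    from_version: Optional[int] = None,
    to_version: Optional[int] = None,
) -> Iterator[dict]:
    """Yield rows appended after from_version, up to and including to_version."""
    if not appended:
        return
    versions = sorted(appended)
    # None bounds: start before the earliest version, stop at the latest one.
    lower = versions[0] - 1 if from_version is None else from_version
    upper = versions[-1] if to_version is None else to_version
    for version in versions:
        if lower < version <= upper:  # exclusive lower bound, inclusive upper bound
            yield from appended[version]


# Hypothetical history: version v appended a single row {"id": v}.
history = {v: [{"id": v}] for v in range(1, 11)}
print([r["id"] for r in rows_added_between(history, from_version=7, to_version=10)])
# [8, 9, 10]

# Hand off at a shared boundary: (start, 7] then (7, latest] covers every row once.
first = [r["id"] for r in rows_added_between(history, to_version=7)]
rest = [r["id"] for r in rows_added_between(history, from_version=7)]
print(first + rest == list(range(1, 11)))  # True
```

In this model, a consumer that records the last version it processed and passes it as from_version on its next run sees only the rows appended since then.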
Creating Datasets
From files
Load data from local files, then save it to Mixtrain:
from mixtrain import Dataset
dataset = Dataset.from_file("data.parquet")
dataset.save("training-data", description="Training dataset")
Supported formats: .csv, .parquet, .jsonl
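For reference, a .jsonl file holds one JSON object per line. The stdlib-only sketch below writes such a file and pivots its rows into the column-oriented dict shape that Dataset.from_dict() takes in the examples on this page. The helper names are hypothetical, and this is not how Mixtrain itself parses files.

```python
import json
import os
import tempfile
from typing import Dict, List


def write_jsonl(path: str, rows: List[dict]) -> None:
    # One JSON object per line, each terminated by a newline.
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")


def jsonl_to_columns(path: str) -> Dict[str, list]:
    # Read every non-blank line, then pivot rows into {column: [values, ...]}.
    with open(path, encoding="utf-8") as f:
        rows = [json.loads(line) for line in f if line.strip()]
    keys: List[str] = []
    for row in rows:
        for key in row:
            if key not in keys:
                keys.append(key)  # keep first-seen column order
    # Rows that lack a key get None in that column.
    return {key: [row.get(key) for row in rows] for key in keys}


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "data.jsonl")
    write_jsonl(path, [{"x": 1, "y": "a"}, {"x": 2, "y": "b"}, {"x": 3}])
    columns = jsonl_to_columns(path)

print(columns)  # {'x': [1, 2, 3], 'y': ['a', 'b', None]}
```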
For multimodal CSV/Parquet files, put local paths or remote URLs in file-backed columns and set column types when needed:
from mixtrain import Dataset, Image
dataset = Dataset.from_file("captions.csv")
dataset.save("image-captions", column_types={"image": Image})
From in-memory data
Create datasets from various Python sources:
import pandas as pd
import pyarrow as pa
from mixtrain import Dataset
# From Python dict
ds = Dataset.from_dict({"x": [1, 2, 3], "y": ["a", "b", "c"]})
# From pandas DataFrame
df = pd.DataFrame({"x": [1, 2, 3], "y": ["a", "b", "c"]})
ds = Dataset.from_pandas(df)
# From Arrow table
table = pa.table({"x": [1, 2, 3], "y": ["a", "b", "c"]})
ds = Dataset.from_arrow(table)
# From a Hugging Face Hub dataset
ds = Dataset.from_huggingface("imdb", split="train")
# Save to platform
ds.save("my-dataset", description="My dataset")
Column Types
Column types control rich rendering in the UI (images, video players, audio, etc.).
Auto-detection (default): when you call save(), types are automatically inferred from data content by inspecting URLs, file extensions, and value patterns:
from mixtrain import Dataset
# image_url detected as Image, video_url as Video, etc.
ds = Dataset.from_file("data.csv")
ds.save("my-data")
Explicit types: pass a dict to set exactly the column types you want.
from mixtrain import Dataset, Image, Video, Audio, Embedding, Tensor
ds = Dataset.from_file("data.csv")
ds.save("my-data", column_types={
"photo": Image,
"sound": Audio,
"action": Tensor, # N-dimensional arrays (e.g. robot action sequences)
# any other columns are left untyped
})
Explicitly set some column types and infer the rest: chain with_column_types({...}) before save(), and save() (whose column_types default is "auto") fills in the remaining columns:
ds.with_column_types({"photo": Image}).save("my-data")  # photo set as Image, others inferred
Disable auto-detection: pass None to skip type inference entirely:
ds.save("my-data", column_types=None)
Supported column types are documented in the Types reference. Browse datasets in the web UI to view and explore your data visually.
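As a rough mental model of one auto-detection signal, the file extension, the sketch below guesses a type per value and assigns a column type only when every non-null value agrees. It is hypothetical: the extension table, the function names, and the "all values must agree" rule are assumptions, not Mixtrain's actual detection logic, which also looks at URLs and value patterns. It returns type names as strings instead of the mixtrain Image/Video/Audio classes so it runs without the SDK.

```python
from pathlib import PurePosixPath
from typing import Iterable, Optional
from urllib.parse import urlparse

# Hypothetical extension table for illustration only.
EXTENSION_TYPES = {
    ".jpg": "Image", ".jpeg": "Image", ".png": "Image", ".webp": "Image",
    ".mp4": "Video", ".mov": "Video", ".webm": "Video",
    ".wav": "Audio", ".mp3": "Audio", ".flac": "Audio",
}


def guess_type(value: object) -> Optional[str]:
    if not isinstance(value, str):
        return None
    # For URLs, look only at the path so query strings do not hide the suffix.
    path = urlparse(value).path if "://" in value else value
    return EXTENSION_TYPES.get(PurePosixPath(path).suffix.lower())


def infer_column_type(values: Iterable[object]) -> Optional[str]:
    """Return a type only if every non-null value maps to the same one."""
    guesses = {guess_type(v) for v in values if v is not None}
    if len(guesses) == 1:
        (only,) = guesses
        return only  # may still be None if nothing matched
    return None  # empty or mixed column: leave untyped


print(infer_column_type(["images/cat.jpg", "https://example.com/a/dog.PNG?sig=1"]))  # Image
print(infer_column_type(["clip.mp4", "song.wav"]))  # None (mixed types)
print(infer_column_type(["a cat on a sofa"]))       # None (plain text)
```

Leaving ambiguous columns untyped is a conservative choice for a sketch like this; when inference guesses wrong, explicit types always win.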
Working with File-Backed Columns
File-backed columns (File, Image, Video, Audio, 3D models, robotics logs) are
stored as references that Mixtrain can render in the UI and use in workflows.
load_files() downloads the referenced files locally with caching and parallel
prefetching:
ds = Dataset("photo-dataset")
# "auto": images and files arrive as bytes, video/audio & other file-backed columns arrive as local file paths
loaded = ds.load_files()
for row in loaded.head(10):
    image_bytes = row["image"]  # bytes
    video_path = row["video"]  # local file path (str)
# Explicitly set the mode for each column
loaded = ds.load_files(to="path") # everything as local paths
loaded = ds.load_files(to={"image": "path", "video": "bytes"})  # image as path, video as bytes
What each file-backed column becomes, depending on the to= value:
| to= value | You get | Use for |
|---|---|---|
| "auto" (default) | bytes for images/files, local paths for the rest | most cases |
| "bytes" | raw bytes | GPU JPEG decode, sending to APIs, exports |
| "path" | local file path (str) | video/audio decoders that want seekable files |
| False | the raw file reference record | full manual control |
File-backed columns are fetched in parallel and cached locally in read-only mode.
The cache lives at $MIXTRAIN_CACHE_DIR (defaults to /data/cache on Mixtrain). You can convert bytes/paths to PIL images, frames, tensors, etc. in your code (see PyTorch Integration).
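The sketch below shows the general shape of parallel prefetching into a read-only local cache using only the standard library. It is an illustration under stated assumptions, not Mixtrain's implementation: fetch is a caller-supplied download function (a fake one is used here), and the hashed file names, the function names, and the temp-file-then-rename step are choices made for this example rather than Mixtrain's documented cache layout.

```python
import hashlib
import os
import stat
import tempfile
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, List

READ_ONLY = stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH  # 0o444


def cache_path(cache_dir: str, uri: str) -> str:
    # Hash the reference so any URI maps to a stable, filesystem-safe filename.
    return os.path.join(cache_dir, hashlib.sha256(uri.encode("utf-8")).hexdigest())


def fetch_cached(uri: str, cache_dir: str, fetch: Callable[[str], bytes]) -> str:
    path = cache_path(cache_dir, uri)
    if os.path.exists(path):  # cache hit: no download
        return path
    fd, tmp = tempfile.mkstemp(dir=cache_dir)  # unique temp file in the same directory
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(fetch(uri))
        os.chmod(tmp, READ_ONLY)
        os.replace(tmp, path)  # atomic rename: readers never see a partial file
    except BaseException:
        if os.path.exists(tmp):
            os.remove(tmp)
        raise
    return path


def prefetch(
    uris: List[str], cache_dir: str, fetch: Callable[[str], bytes], workers: int = 8
) -> Dict[str, str]:
    """Download missing URIs in parallel; return a {uri: local_path} mapping."""
    os.makedirs(cache_dir, exist_ok=True)
    unique = list(dict.fromkeys(uris))  # dedupe so each URI is fetched at most once
    with ThreadPoolExecutor(max_workers=workers) as pool:
        paths = list(pool.map(lambda u: fetch_cached(u, cache_dir, fetch), unique))
    return dict(zip(unique, paths))


calls: List[str] = []


def fake_fetch(uri: str) -> bytes:
    # Stand-in for a real download; records each call so cache hits are visible.
    calls.append(uri)
    return uri.encode("utf-8")


with tempfile.TemporaryDirectory() as cache_dir:
    refs = ["s3://bucket/a.jpg", "s3://bucket/b.jpg", "s3://bucket/a.jpg"]
    local = prefetch(refs, cache_dir, fake_fetch)
    prefetch(refs, cache_dir, fake_fetch)  # second pass: every URI is a cache hit
    with open(local["s3://bucket/a.jpg"], "rb") as f:
        content = f.read()
    mode = stat.S_IMODE(os.stat(local["s3://bucket/a.jpg"]).st_mode)

print(content)        # b's3://bucket/a.jpg'
print(sorted(calls))  # ['s3://bucket/a.jpg', 's3://bucket/b.jpg']
print(oct(mode))      # 0o444
```

Writing to a temporary file and renaming it into place means a concurrent reader only ever sees complete files, and marking the result read-only guards against accidental in-place edits of shared cached data.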
Files on save
save() handles file-backed columns automatically: local files are uploaded to Mixtrain, and save() also records useful media stats where available, such as image
width/height, video/audio duration_seconds, and type-specific fields like
fps, num_frames, sample_rate, and channels. Stats can be queried without
downloading the files:
ds = Dataset("photo-dataset")
big = ds.sql("SELECT * FROM t WHERE image.width >= 512", table_name="t")
The copy_files and stats options control what save() does with file-backed columns:
ds = Dataset.from_file("captions.csv")  # has a local image path column
ds.save("photo-dataset")  # local files uploaded, types inferred
ds.save("photo-dataset", copy_files=True)  # also copy remote URLs into storage
ds.save("photo-dataset", copy_files=False)  # write the table as-is, don't touch files
ds.save("photo-dataset", stats=False)  # skip stat probing
Iterating Over Data
Row-by-row
Stream rows without loading the full dataset:
for row in dataset:
    print(row["text"], row["label"])
Batched
Get batches as columnar dicts:
for batch in dataset.batches(32):  # dict mapping each column to a list of up to 32 values
    texts = batch["text"]
    labels = batch["label"]
Caching locally
cache() materializes a dataset locally for fast, offline-safe access.
ds = Dataset("photo-dataset").cache() # cache the dataset locally at /data/cache/
prep = ds.filter("score > 0.8").cache(path="/data/prep-v2")  # cache a filtered view at a custom path
PyTorch Integration
DataLoader
Get a PyTorch DataLoader with zero-copy tensor conversion:
# Unbatched
loader = dataset.to_torch()
for row in loader:
    print(row)
# Batched with tensors
loader = dataset.to_torch(batch_size=32)
for batch in loader:
    features = batch["features"]  # torch.Tensor
    labels = batch["labels"]  # torch.Tensor
Workers and distributed training
to_torch() shards data automatically: each DataLoader worker (and each DDP rank, via
rank/world_size) reads a disjoint partition — no duplicated samples, and each worker
only fetches the files for its own partition:
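loader = dataset.to_torch(batch_size=32, num_workers=4)
# Distributed (DDP): pass your rank and world size (assumes the process group is initialized)
import torch.distributed as dist
loader = dataset.to_torch(batch_size=32, num_workers=4, rank=dist.get_rank(), world_size=dist.get_world_size())
for epoch in range(num_epochs):  # num_epochs comes from your training config
    loader.set_epoch(epoch)  # reshuffles the windowed shuffle each epoch
    for batch in loader:
        ...  # training step
File-backed datasets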
File-backed columns are loaded automatically. You can convert them in a
collate_fn, which runs inside the worker processes:
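import io
from PIL import Image as PILImage  # alias avoids clashing with mixtrain's Image type
def collate(batch):
    # In the default "auto" mode, each image arrives as raw bytes
    batch["image"] = [PILImage.open(io.BytesIO(b)) for b in batch["image"]]
    return batch
loader = dataset.to_torch(batch_size=64, num_workers=4, collate_fn=collate)
Direct tensor conversion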
tensors = dataset.to_tensors()
print(tensors["label"])  # tensor([0, 1, 0, 1, ...])
Transformations
All transformations are lazily evaluated:
import pyarrow as pa
from mixtrain import Dataset
ds = Dataset("training-data")
# Shuffle and sample
shuffled = ds.shuffle(seed=42)
windowed = ds.shuffle(global_=False, window=10_000) # bounded-memory streaming shuffle
sample = ds.sample(100, seed=42)
# Filter with a Python function or a SQL-like expression
positive = ds.filter(lambda x: x["label"] == 1)
recent = ds.filter("score > 0.8 and label != None")
# Map rows (declare the output schema so the pipeline stays lazy)
with_length = ds.map(
    lambda x: {"text": x["text"], "text_len": len(x["text"])},
    schema=pa.schema([("text", pa.string()), ("text_len", pa.int64())]),
)
# Select columns and rows
subset = ds.select(["text", "label"]).head(100)
# Chain operations — nothing executes until you consume
processed = ds.shuffle(42).filter(lambda x: x["score"] > 0.8).head(1000)
processed.save("processed-data")
SQL Queries
In SQL queries, the dataset is exposed as a table named data by default (override the name with table_name=):
dataset = Dataset("training-data")
# Filter with SQL
filtered = dataset.sql("SELECT * FROM data WHERE score > 0.8")
# Aggregations, with a custom table name
label_counts = dataset.sql("SELECT label, COUNT(*) AS cnt FROM t GROUP BY label", table_name="t")
Joining Datasets
videos = Dataset("videos")
captions = Dataset("captions")
# Inner join
joined = videos.join(captions, keys="video_id")
# Left outer join
joined = videos.join(captions, keys="video_id", join_type="left outer")
Next Steps
- Dataset API Reference - Complete SDK documentation
- CLI Reference - Command-line interface