from mixtrain import Dataset

Constructor

Dataset(name: str, *, version: int | None = None)

Creates a reference to an existing dataset on the platform. The reference is lazy: no API call is made until you access data or schema.

Parameter   Type          Description
name        str           Dataset name
version     int | None    Optional dataset version (v1, v2, ...). None reads the latest version.

dataset = Dataset("training-data")
baseline = Dataset("training-data", version=3)

Properties

Property       Type              Description
name           str               Dataset name
schema         pyarrow.Schema    Column schema
description    str               Dataset description
column_types   dict[str, str]    Semantic column types (image, video, ...)
metadata       dict              Platform metadata (schema, column types, statistics, snapshots)

Versions

Mixtrain datasets are versioned on each update.

latest = Dataset("training-data")
v3 = Dataset("training-data", version=3)

list_versions()

dataset.list_versions() -> list[DatasetVersion]

List dataset versions, newest first. Each entry includes version, snapshot_id, operation, added_records, total_records, and timestamp_ms when available.

for version in Dataset("photos").list_versions():
    print(version.version, version.operation)

added_since()

dataset.added_since(
    from_version: int | None = None,
    to_version: int | None = None,
) -> Iterator[pyarrow.RecordBatch]

Iterate over rows added between two versions. from_version is exclusive and to_version is inclusive. None starts from the earliest version or ends at the latest version.

for batch in Dataset("photos").added_since(from_version=4, to_version=7):
    process(batch)

added_since() is append-oriented in the current implementation, and it is the mechanism used by dataset-triggered routines.
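
The version bounds can be pictured with a small pure-Python sketch. It does not call Mixtrain: select_versions is a hypothetical helper, and the plain list of version numbers stands in for what list_versions() would report.

```python
def select_versions(versions, from_version=None, to_version=None):
    """Return the versions whose added rows fall in (from_version, to_version]."""
    selected = []
    for v in sorted(versions):
        if from_version is not None and v <= from_version:
            continue  # from_version is exclusive
        if to_version is not None and v > to_version:
            continue  # to_version is inclusive
        selected.append(v)
    return selected


versions = [1, 2, 3, 4, 5, 6, 7, 8]
print(select_versions(versions, from_version=4, to_version=7))  # [5, 6, 7]
print(select_versions(versions, to_version=2))                  # [1, 2]
print(select_versions(versions, from_version=6))                # [7, 8]
```

Under these assumptions, a routine that last processed version 4 would pass from_version=4 and receive only rows added by later versions.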

Iteration

Row iteration

for row in dataset:
    print(row)  # {"col1": value, "col2": value}

Streams rows without loading the full dataset into memory. dataset.rows(load_files=...) is the explicit form; file-backed columns arrive as Image/Video/File handles by default (load_files="auto"), as raw records with load_files=False, or as loaded content with load_files="bytes" / "path" / a per-column dict. See load_files().

batches()

dataset.batches(
    size: int = None,
    format: str = "dict",      # "dict" | "arrow"
    load_files: bool | str | dict = "auto",
) -> Iterator[dict[str, list]] | Iterator[pyarrow.RecordBatch]

Streams batches. size=None keeps the plan's natural chunk sizes. format="dict" yields columnar dicts (file columns as handles under "auto"); format="arrow" yields RecordBatches (file columns stay native record structs under "auto"). The explicit load_files values (False, "bytes", "path", per-column dict) mean the same thing in every format.

for batch in dataset.batches(64):
    print(batch)  # {"col1": [v1, v2, ...], "col2": [v1, v2, ...]}

for batch in dataset.batches(64, load_files="bytes"):
    send(batch["image"])  # list[bytes], fetched in parallel
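
To make the columnar shape of format="dict" concrete, here is a minimal pure-Python sketch that groups row dicts into columnar batches. to_columnar_batches is a hypothetical helper for illustration, not part of Mixtrain, and it assumes every row carries the same keys.

```python
from itertools import islice


def to_columnar_batches(rows, size):
    """Group an iterable of row dicts into columnar dicts of at most `size` rows."""
    it = iter(rows)
    while True:
        chunk = list(islice(it, size))
        if not chunk:
            return
        # Assumes every row has the same keys as the first row in the chunk.
        yield {key: [row[key] for row in chunk] for key in chunk[0]}


rows = [{"id": i, "label": i % 2} for i in range(5)]
batches = list(to_columnar_batches(rows, 2))
print(batches[0])   # {'id': [0, 1], 'label': [0, 1]}
print(batches[-1])  # {'id': [4], 'label': [0]}
```

The last batch is shorter than size when the row count is not a multiple of it.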

Creating Datasets

From files

Dataset.from_file(file_path: str) -> Dataset

Load a local data file into an in-memory dataset. Use .save() to persist it to Mixtrain.

Supported formats: .csv, .parquet, .jsonl (all stream in batches)

dataset = Dataset.from_file("data.parquet")
dataset.save("training-data", description="Training dataset")

From in-memory data

# From Python dict
ds = Dataset.from_dict({"col1": [1, 2, 3], "col2": ["a", "b", "c"]})

# From pandas DataFrame
ds = Dataset.from_pandas(df)

# From Arrow table
ds = Dataset.from_arrow(table)

# From a re-iterable batch source
ds = Dataset.from_batches(batch_factory, schema=schema)

# From HuggingFace datasets (streams by default)
ds = Dataset.from_huggingface("imdb", split="train")

# From PyTorch dataset (schema= is required)
ds = Dataset.from_torch(
    torch_dataset,
    schema=pa.schema([("data", pa.float32()), ("label", pa.int64())]),
)

These create in-memory or streaming datasets. Use .save() to persist to the platform.

Saving Datasets

save()

dataset.save(
    name: str,
    description: str = None,
    column_types: dict | str | None = "auto",
    copy_files: bool | str = "auto",
    stats: bool = True,
) -> Dataset

Save dataset to the platform. Files referenced by file-backed columns are handled automatically: local files are uploaded so they can be viewed and used from Mixtrain, while remote URLs are kept as references by default.

Parameter      Type                 Description
name           str                  Dataset name to create on the platform
description    str                  Optional description
column_types   dict | str | None    "auto" (default) infers types for untyped columns. A dict sets exactly the named columns (no inference of others). None disables type detection. To set some columns and infer the rest, chain with_column_types() before save().
copy_files     bool | str           "auto" (default) uploads local files and keeps remote URLs as references. True also copies remote files into workspace storage. False writes the table as-is without touching file contents.
stats          bool                 Record intrinsic file stats (image width/height, video/audio duration_seconds, plus fields like fps, num_frames, sample_rate, and channels) per media column during upload. Queryable via sql() struct access, e.g. image.width >= 512.

# Auto-detects column types and uploads local files (default)
ds = Dataset.from_dict({"img": ["cat.jpg", "dog.png"], "label": [0, 1]})
ds.save("my-dataset")  # img detected as Image, files uploaded

# Explicit types for exactly these columns (no inference of others)
from mixtrain import Image
ds.save("my-dataset", column_types={"photo": Image})

# Pin some columns and let save() infer the rest: chain with_column_types()
ds.with_column_types({"photo": Image}).save("my-dataset")

# Copy remote URLs into workspace storage too
ds.save("my-dataset", copy_files=True)

append_to()

dataset.append_to(
    name: str,
    copy_files: bool | str = "auto",
    stats: bool = True,
) -> Dataset

Append rows to an existing dataset. File columns follow the target table's column types and the same copy_files contract as save(): local files are uploaded before rows are written.

Export Methods

collect()

dataset.collect(max_bytes: int = None, warn: bool = True) -> pyarrow.Table

Execute the plan and materialize the result as an Arrow table. Warns on large results; pass max_bytes to fail instead of materializing more than a budget.

to_arrow() / to_pandas()

dataset.to_arrow() -> pyarrow.Table
dataset.to_pandas() -> pandas.DataFrame

Materialize as an Arrow table or pandas DataFrame (equivalent to collect()).

to_tensors()

dataset.to_tensors() -> dict[str, Tensor | list]

Convert to a dict of PyTorch tensors. Uses zero-copy for numeric columns.

to_torch()

dataset.to_torch(
    batch_size: int = None,
    num_workers: int = 0,
    rank: int = 0,
    world_size: int = 1,
    drop_last: bool = False,
    load_files: bool | str | dict = "auto",
    prefetch_batches: int = 2,
    fetch_workers: int = 8,
    fetch_timeout: float = 30.0,
    max_inflight_bytes: int = 256_000_000,
    **dataloader_kwargs,
) -> DataLoader

Get a PyTorch DataLoader. Data is sharded automatically across DataLoader workers and distributed ranks (rank/world_size) so each worker reads a disjoint partition and fetches only its own media files. (Each worker still runs its own table scan; call cache() first so many workers share one local mirror of a large platform table.)

File-backed columns are loaded automatically (load_files= accepts the same values as load_files(); False disables); batches are pipelined with background prefetch. fetch_workers controls parallel fetches, fetch_timeout sets the per-request timeout, and MIXTRAIN_CACHE_DIR controls the local blob cache location.

Extra kwargs (e.g. collate_fn, pin_memory) pass through to torch.utils.data.DataLoader. Shuffle with .shuffle() before calling to_torch() and use loader.set_epoch(epoch) to reshuffle windowed shuffles per epoch.

# Batched - yields dicts of tensors (numeric columns are zero-copy)
loader = dataset.to_torch(batch_size=32, num_workers=4)
for batch in loader:
    print(batch["features"].shape)  # torch.Size([32, ...])

# Media: images arrive as bytes — convert in collate_fn (runs in workers)
import io
from PIL import Image as PILImage

def collate(batch):
    batch["image"] = [PILImage.open(io.BytesIO(b)) for b in batch["image"]]
    return batch

loader = dataset.to_torch(batch_size=64, num_workers=4, collate_fn=collate)

to_huggingface()

dataset.to_huggingface() -> datasets.Dataset

Convert to a Hugging Face Dataset.

Transformations

All transformations are lazy and return a new Dataset (immutable).
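
As an illustration of what "lazy and immutable" means here, the following sketch records steps on a frozen object and runs them only on iteration. LazyPlan is a hypothetical stand-in written for this example; it says nothing about how Mixtrain represents its plans internally.

```python
from dataclasses import dataclass
from itertools import islice


@dataclass(frozen=True)
class LazyPlan:
    """Hypothetical stand-in for a lazy dataset: a source plus recorded steps."""
    source: tuple
    steps: tuple = ()

    def filter(self, fn):
        # Return a new plan; the original is left untouched.
        return LazyPlan(self.source, self.steps + (("filter", fn),))

    def head(self, n):
        return LazyPlan(self.source, self.steps + (("head", n),))

    def __iter__(self):
        # Nothing runs until iteration starts.
        rows = iter(self.source)
        for kind, arg in self.steps:
            if kind == "filter":
                rows = filter(arg, rows)
            elif kind == "head":
                rows = islice(rows, arg)
        return rows


base = LazyPlan(tuple(range(10)))
evens = base.filter(lambda x: x % 2 == 0)
first_two = evens.head(2)
print(base.steps)       # ()
print(list(first_two))  # [0, 2]
print(list(evens))      # [0, 2, 4, 6, 8]
```

Because each call returns a new object, evens can still be iterated in full after first_two was derived from it.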

shuffle()

dataset.shuffle(seed: int = None, global_: bool = True, window: int = 10_000) -> Dataset

Randomly shuffle rows. global_=True is a full shuffle (may buffer in DuckDB); global_=False is a bounded-memory streaming window shuffle. Without a seed, a stable seed is generated so re-iteration is consistent.
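
A bounded-memory window shuffle can be sketched in a few lines. This is a generic buffer-based version for illustration, assuming an explicit seed; window_shuffle is a hypothetical helper and may differ from what global_=False does internally.

```python
import random


def window_shuffle(rows, window, seed):
    """Shuffle a stream while holding at most `window` rows in memory."""
    rng = random.Random(seed)
    buffer = []
    for row in rows:
        if len(buffer) < window:
            buffer.append(row)  # fill the buffer first
            continue
        # Emit a random buffered row and put the incoming row in its slot.
        i = rng.randrange(window)
        yield buffer[i]
        buffer[i] = row
    # Drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


first = list(window_shuffle(range(20), window=5, seed=0))
second = list(window_shuffle(range(20), window=5, seed=0))
print(first == second)                   # True
print(sorted(first) == list(range(20)))  # True
```

In this sketch a row can be emitted at most window - 1 positions earlier than its original position, which is why a small window mixes rows only locally while a full shuffle needs the whole dataset buffered.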

sample()

dataset.sample(n: int, seed: int = None) -> Dataset

Random sample of n rows (streaming reservoir sample).
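
For readers unfamiliar with reservoir sampling, here is a sketch of the classic single-pass variant (often called Algorithm R). Whether Mixtrain uses this exact variant is not stated; reservoir_sample is a hypothetical helper.

```python
import random


def reservoir_sample(rows, n, seed=None):
    """Pick n rows uniformly from a stream of unknown length in one pass."""
    rng = random.Random(seed)
    reservoir = []
    for i, row in enumerate(rows):
        if i < n:
            reservoir.append(row)  # keep the first n rows
        else:
            # Replace a kept row with probability n / (i + 1).
            j = rng.randint(0, i)  # inclusive on both ends
            if j < n:
                reservoir[j] = row
    return reservoir


picked = reservoir_sample(iter(range(1_000)), n=5, seed=42)
print(len(picked))  # 5
```

Memory use stays proportional to n regardless of how many rows the stream yields, and a stream shorter than n is returned whole.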

select()

dataset.select(columns: list[str]) -> Dataset

Keep only the named columns. Pushes the projection down to the storage scan when possible.

take()

dataset.take(indices: list[int]) -> Dataset

Keep only the rows at the given indices.

head() / slice()

dataset.head(n: int = 5) -> Dataset
dataset.slice(start: int, stop: int) -> Dataset

First n rows / a contiguous row range. Limits push down to the storage scan, so Dataset("big").head(5) reads only a few rows.

filter()

dataset.filter(predicate: str | Callable[[dict], bool]) -> Dataset

Filter rows with a SQL-like expression string or a Python function. Expressions are evaluated vectorized and push down into the Iceberg scan when the filter sits directly on a platform table — prefer them for large datasets; functions run per row in Python.

recent = dataset.filter("score > 0.8 and label != None")
positive = dataset.filter(lambda x: x["label"] == 1)

map() / map_batches()

dataset.map(fn: Callable[[dict], dict], *, schema: pyarrow.Schema) -> Dataset
dataset.map_batches(fn: Callable[[RecordBatch], RecordBatch], *, schema: pyarrow.Schema) -> Dataset

Apply a function per row, or per Arrow batch (faster for vectorized work). The output schema is required so the pipeline stays lazy without executing your function early.

import pyarrow as pa

ds = dataset.map(
    lambda x: {"text": x["text"], "text_len": len(x["text"])},
    schema=pa.schema([("text", pa.string()), ("text_len", pa.int64())]),
)

join()

dataset.join(other: Dataset, keys: str, join_type: str = "inner") -> Dataset

Join with another dataset on a shared key column.

Parameter   Type      Description
other       Dataset   Right table to join
keys        str       Column to join on
join_type   str       "inner", "left outer", "right outer", "full outer"

joined = users.join(orders, keys="user_id")
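
The difference between "inner" and "left outer" can be shown on plain row dicts. This sketch covers only those two join types and pads unmatched rows with None; join_rows is a hypothetical helper, not Mixtrain's join implementation.

```python
from collections import defaultdict


def join_rows(left, right, key, join_type="inner"):
    """Join two lists of row dicts on `key`; supports "inner" and "left outer"."""
    if join_type not in ("inner", "left outer"):
        raise ValueError(f"unsupported join_type: {join_type}")
    by_key = defaultdict(list)
    for row in right:
        by_key[row[key]].append(row)
    # Right-side columns (minus the key) used to pad unmatched left rows.
    right_cols = [c for c in (right[0] if right else {}) if c != key]
    out = []
    for row in left:
        matches = by_key.get(row[key], [])
        for match in matches:
            out.append({**row, **match})
        if not matches and join_type == "left outer":
            out.append({**row, **{c: None for c in right_cols}})
    return out


users = [{"user_id": 1, "name": "ada"}, {"user_id": 2, "name": "bo"}]
orders = [{"user_id": 1, "total": 30}, {"user_id": 1, "total": 12}]
inner = join_rows(users, orders, "user_id")
left_outer = join_rows(users, orders, "user_id", "left outer")
print(len(inner))      # 2
print(left_outer[-1])  # {'user_id': 2, 'name': 'bo', 'total': None}
```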

train_test_split()

dataset.train_test_split(*, test_size: float = 0.2, seed: int = None) -> dict[str, Dataset]

Split into deterministic, non-overlapping train and test sets.

splits = dataset.train_test_split(test_size=0.2, seed=42)
train_ds = splits["train"]
test_ds = splits["test"]
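
One common way to get a deterministic, non-overlapping split is to hash each row's position together with the seed. The sketch below shows that idea only; how Mixtrain actually assigns rows is not documented here, and assign_split is a hypothetical helper.

```python
import hashlib


def assign_split(row_index, test_size=0.2, seed=0):
    """Map a row index to "train" or "test" using a seeded hash."""
    digest = hashlib.sha256(f"{seed}:{row_index}".encode()).digest()
    bucket = int.from_bytes(digest[:8], "big") / 2**64  # uniform in [0, 1)
    return "test" if bucket < test_size else "train"


assignments = [assign_split(i, test_size=0.2, seed=42) for i in range(1_000)]
train = [i for i, split in enumerate(assignments) if split == "train"]
test = [i for i, split in enumerate(assignments) if split == "test"]
print(set(train).isdisjoint(test))  # True
print(len(train) + len(test))       # 1000
```

Each row lands in exactly one split, and re-running with the same seed reproduces the same assignment without storing any state.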

shard()

dataset.shard(index: int, count: int) -> Dataset

Restrict the dataset to one of count disjoint partitions (for manual distributed reading; to_torch() does this automatically).
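
The idea of disjoint partitions can be sketched with a simple modulo rule. This is an assumption for illustration: Mixtrain may partition by files or row groups instead, and both shard_rows and global_shard are hypothetical helpers.

```python
def shard_rows(rows, index, count):
    """Yield the rows belonging to shard `index` out of `count` shards."""
    if not 0 <= index < count:
        raise ValueError("index must satisfy 0 <= index < count")
    for position, row in enumerate(rows):
        if position % count == index:
            yield row


def global_shard(rank, world_size, worker_id, num_workers):
    """Flatten (rank, worker) into one shard index and the total shard count."""
    return rank * num_workers + worker_id, world_size * num_workers


parts = [list(shard_rows(range(10), i, 3)) for i in range(3)]
print(parts)  # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
print(global_shard(rank=1, world_size=2, worker_id=3, num_workers=4))  # (7, 8)
```

The second helper shows one way a distributed rank and a DataLoader worker id could be combined into a single (index, count) pair when sharding manually.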

File Methods

load_files()

dataset.load_files(
    to: bool | str | dict = "auto",
    columns: list[str] = None,
    workers: int = 8,
    max_inflight_bytes: int = 256_000_000,
    on_error: str = "raise",
    fetch_timeout: float = 30.0,
) -> Dataset

Load file-backed columns locally. Files are downloaded in parallel and cached (in $MIXTRAIN_CACHE_DIR), so re-iteration does not re-download them.

to controls what each file-backed column becomes:

Value              Column value                                                    Use for
"auto" (default)   bytes for image/file, local path for video/audio/3d/mcap/rrd    most cases
"bytes"            raw bytes                                                       GPU decode, sending to APIs
"path"             local file path (str, read-only)                                decoders that want seekable files
False              raw media record                                                manual control

Pass a dict for per-column values, for example load_files({"image": "path"}). A dict is an exact spec: only the columns it names are loaded, and other media columns are left as records (lazy handles when consumed as Python objects). on_error="null" maps failed fetches to None instead of raising.
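
The on_error behavior can be illustrated with a small parallel fetch loop built on the standard library. fetch_all and fake_fetch are hypothetical names for this sketch; it mirrors the described "raise" and "null" semantics but is not Mixtrain's fetcher.

```python
from concurrent.futures import ThreadPoolExecutor


def fetch_all(uris, fetch, workers=8, on_error="raise"):
    """Fetch every URI in parallel, preserving input order."""
    def safe_fetch(uri):
        try:
            return fetch(uri)
        except Exception:
            if on_error == "null":
                return None  # a failed fetch becomes a null value
            raise

    with ThreadPoolExecutor(max_workers=workers) as pool:
        return list(pool.map(safe_fetch, uris))


def fake_fetch(uri):
    # Stand-in for a real download; URIs containing "missing" fail.
    if "missing" in uri:
        raise FileNotFoundError(uri)
    return uri.encode()


results = fetch_all(["a.jpg", "missing.jpg"], fake_fetch, on_error="null")
print(results)  # [b'a.jpg', None]
```

With on_error="raise" the same call would propagate the FileNotFoundError, which is the safer default when a missing file should stop a training run.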

with_column_types()

dataset.with_column_types(column_types: str | dict = "auto") -> Dataset

Set column types (image, video, audio, ...) on the schema. "auto" infers from a data sample; pass a dict for explicit mappings. save() applies this automatically, so you only need it to override inference before saving or calling load_files().

SQL Queries

sql()

dataset.sql(sql: str, table_name: str = "data", seed: int = None) -> Dataset

Execute SQL via DuckDB. The dataset is registered as data by default; pass table_name to choose another alias. query() is an alias for sql().

filtered = dataset.sql("SELECT * FROM data WHERE score > 0.8")
stats = dataset.sql("SELECT label, COUNT(*) as cnt FROM t GROUP BY label", table_name="t")

Inspection

explain()

dataset.explain() -> str

Render the lazy query plan.

print(Dataset("d").filter("x > 1").head(5).explain())

cache()

dataset.cache(path: str = None, load_files: bool = True) -> Dataset

Materialize the dataset locally so iteration is fast and offline-safe.

With no arguments, platform table scans are mirrored to local Parquet under $MIXTRAIN_CACHE_DIR/datasets/<name>/<snapshot_id> — reused across runs and processes, and invalidated automatically when the table gets a new snapshot. Transforms in the plan re-run against the local mirror. With path=, this dataset's full output is materialized at an explicit location and reused on later calls (the caller owns invalidation). load_files=True also fetches file-backed column contents into the local file cache.

ds = Dataset("photo-dataset").cache()   # table + files local; train offline
prep = ds.filter("image.width >= 512").cache(path="/data/prep-v2")
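
Keying the mirror directory by snapshot ID is what makes invalidation automatic: a new snapshot simply maps to a directory that does not exist yet. The sketch below reproduces the documented layout with the standard library; mirror_path and is_mirror_fresh are hypothetical helpers, and the cache directory is passed in explicitly rather than read from $MIXTRAIN_CACHE_DIR.

```python
import tempfile
from pathlib import Path


def mirror_path(cache_dir, name, snapshot_id):
    """Build <cache_dir>/datasets/<name>/<snapshot_id>."""
    return Path(cache_dir) / "datasets" / name / str(snapshot_id)


def is_mirror_fresh(cache_dir, name, current_snapshot_id):
    """A mirror is usable only if one exists for the current snapshot."""
    return mirror_path(cache_dir, name, current_snapshot_id).is_dir()


with tempfile.TemporaryDirectory() as cache_dir:
    # Pretend snapshot 101 was mirrored earlier.
    mirror_path(cache_dir, "photos", 101).mkdir(parents=True)
    fresh_same = is_mirror_fresh(cache_dir, "photos", 101)
    fresh_new = is_mirror_fresh(cache_dir, "photos", 102)

print(fresh_same)  # True
print(fresh_new)   # False
```

With an explicit path= there is no snapshot component in the location, which is consistent with the caller owning invalidation in that mode.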

Metadata

Dataset metadata includes version history when the platform returns it. Snapshot entries include:

Field             Description
sequence_number   User-facing dataset version (v1, v2, ...)
snapshot_id       Iceberg snapshot ID
operation         Commit operation, such as append
added_records     Records added by the commit when available
total_records     Total records after the commit when available
timestamp_ms      Commit timestamp

set_column_types()

dataset.set_column_types(column_types: dict) -> None

Update column types for rich UI rendering.

from mixtrain import Image, Audio

dataset.set_column_types({"image_url": Image, "audio_url": Audio})

update()

dataset.update(*, description: str = None, column_types: dict = None) -> dict

Update platform metadata.

delete()

dataset.delete() -> None

Delete the dataset.

Class Methods

Dataset.exists()

Dataset.exists(name: str) -> bool

Check if a dataset exists.

if not Dataset.exists("my-dataset"):
    Dataset.from_pandas(df).save("my-dataset")

Helper Functions

list_datasets()

from mixtrain import list_datasets

datasets = list_datasets()
for ds in datasets:
    print(ds.name)

get_dataset()

from mixtrain import get_dataset

dataset = get_dataset("my-dataset")
