```python
from mixtrain import Dataset
```
Constructor
```python
Dataset(name: str, *, version: int | None = None)
```
Creates a reference to an existing dataset on the platform. The reference is lazy: no API call is made until you access data or the schema.
| Parameter | Type | Description |
|---|---|---|
| name | str | Dataset name |
| version | int \| None | Optional dataset version (v1, v2, ...). None reads the latest version. |
```python
dataset = Dataset("training-data")
baseline = Dataset("training-data", version=3)
```
Properties
| Property | Type | Description |
|---|---|---|
| name | str | Dataset name |
| schema | pyarrow.Schema | Column schema |
| description | str | Dataset description |
| column_types | dict[str, str] | Semantic column types (image, video, ...) |
| metadata | dict | Platform metadata (schema, column types, statistics, snapshots) |
Versions
Mixtrain datasets are versioned on each update.
```python
latest = Dataset("training-data")
v3 = Dataset("training-data", version=3)
```
list_versions()
```python
dataset.list_versions() -> list[DatasetVersion]
```
List dataset versions, newest first. Each entry includes version, snapshot_id, operation, added_records, total_records, and timestamp_ms when available.
```python
for version in Dataset("photos").list_versions():
    print(version.version, version.operation)
```
added_since()
```python
dataset.added_since(
    from_version: int | None = None,
    to_version: int | None = None,
) -> Iterator[pyarrow.RecordBatch]
```
Iterate over rows added between two versions. from_version is exclusive and to_version is inclusive. from_version=None starts from the earliest version, and to_version=None ends at the latest version.
```python
for batch in Dataset("photos").added_since(from_version=4, to_version=7):
    process(batch)  # rows added in v5, v6, and v7
```
In the current implementation, added_since() is append-oriented; it is the mechanism used by dataset-triggered routines.
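The exclusive/inclusive bounds are easy to get off by one. The pure-Python sketch below models which versions' appended rows a call covers. versions_in_range is a hypothetical helper, not part of the mixtrain API, and it assumes versions are numbered 1, 2, ... with no gaps and that from_version=None includes the rows of v1.

```python
def versions_in_range(latest, from_version=None, to_version=None):
    """Return the versions whose appended rows added_since() would cover.

    Hypothetical model of the documented bounds: from_version is exclusive,
    to_version is inclusive. Assumes versions are numbered 1..latest with no
    gaps and that from_version=None includes the rows of v1.
    """
    start = 1 if from_version is None else from_version + 1  # exclusive lower bound
    stop = latest if to_version is None else to_version  # inclusive upper bound
    return list(range(start, stop + 1))


print(versions_in_range(latest=9, from_version=4, to_version=7))  # [5, 6, 7]
print(versions_in_range(latest=3))  # [1, 2, 3]
```

Passing the same value for both bounds covers nothing, because the lower bound is exclusive.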
Iteration
Row iteration
```python
for row in dataset:
    print(row)  # {"col1": value, "col2": value}
```
Streams rows without loading the full dataset into memory. dataset.rows(load_files=...) is the explicit form. By default (load_files="auto"), file-backed columns arrive as Image/Video/File handles; load_files=False yields raw records, and load_files="bytes", "path", or a per-column dict yields loaded content. See load_files().
batches()
```python
dataset.batches(
    size: int = None,
    format: str = "dict",  # "dict" | "arrow"
    load_files: bool | str | dict = "auto",
) -> Iterator[dict[str, list]] | Iterator[pyarrow.RecordBatch]
```
Streams batches. size=None keeps the plan's natural chunk sizes. format="dict" yields columnar dicts (file columns as handles under "auto"); format="arrow" yields RecordBatches (file columns stay native record structs under "auto"). The explicit load_files values (False, "bytes", "path", per-column dict) mean the same thing in every format.
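The format="dict" batch shape is columnar: each key maps to a list holding that column's values for every row in the batch. A minimal stdlib sketch of that reshaping, assuming rows arrive as dicts that all share the same keys (to_columnar_batches is a hypothetical helper, not the library's implementation):

```python
from itertools import islice


def to_columnar_batches(rows, size):
    """Group an iterable of row dicts into columnar dict batches of up to `size` rows."""
    it = iter(rows)
    while True:
        chunk = list(islice(it, size))
        if not chunk:
            return
        # Transpose list-of-rows into dict-of-columns, keeping the first row's key order.
        yield {key: [row[key] for row in chunk] for key in chunk[0]}


rows = [{"col1": i, "col2": chr(ord("a") + i)} for i in range(5)]
for batch in to_columnar_batches(rows, size=2):
    print(batch)
# {'col1': [0, 1], 'col2': ['a', 'b']}
# {'col1': [2, 3], 'col2': ['c', 'd']}
# {'col1': [4], 'col2': ['e']}
```

The last batch is shorter when the row count is not a multiple of size.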
```python
for batch in dataset.batches(64):
    print(batch)  # {"col1": [v1, v2, ...], "col2": [v1, v2, ...]}

for batch in dataset.batches(64, load_files="bytes"):
    send(batch["image"])  # list[bytes], fetched in parallel
```
Creating Datasets
From files
```python
Dataset.from_file(file_path: str) -> Dataset
```
Load a local data file into an in-memory dataset. Use .save() to persist it to Mixtrain.
Supported formats: .csv, .parquet, .jsonl (all stream in batches).
```python
dataset = Dataset.from_file("data.parquet")
dataset.save("training-data", description="Training dataset")
```
From in-memory data
```python
import pyarrow as pa

# From a Python dict
ds = Dataset.from_dict({"col1": [1, 2, 3], "col2": ["a", "b", "c"]})
# From a pandas DataFrame
ds = Dataset.from_pandas(df)
# From an Arrow table
ds = Dataset.from_arrow(table)
# From a re-iterable batch source
ds = Dataset.from_batches(batch_factory, schema=schema)
# From HuggingFace datasets (streams by default)
ds = Dataset.from_huggingface("imdb", split="train")
# From a PyTorch dataset (schema= is required)
ds = Dataset.from_torch(
    torch_dataset,
    schema=pa.schema([("data", pa.float32()), ("label", pa.int64())]),
)
```
These create in-memory or streaming datasets. Use .save() to persist them to the platform.
Saving Datasets
save()
```python
dataset.save(
    name: str,
    description: str = None,
    column_types: dict | str | None = "auto",
    copy_files: bool | str = "auto",
    stats: bool = True,
) -> Dataset
```
Save the dataset to the platform. Files referenced by file-backed columns are handled automatically: local files are uploaded so they can be viewed and used from Mixtrain, while remote URLs are kept as references by default.
| Parameter | Type | Description |
|---|---|---|
| name | str | Dataset name to create on the platform |
| description | str | Optional description |
| column_types | dict \| str \| None | "auto" (default) infers types for untyped columns. A dict sets exactly the named columns (no inference of others). None disables type detection. To set some columns and infer the rest, chain with_column_types() before save(). |
| copy_files | bool \| str | "auto" (default) uploads local files and keeps remote URLs as references. True also copies remote files into workspace storage. False writes the table as-is without touching file contents. |
| stats | bool | Record intrinsic file stats (image width/height, video/audio duration_seconds, plus fields like fps, num_frames, sample_rate, and channels) per media column during upload. Queryable via sql() struct access, e.g. image.width >= 512. |
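The copy_files contract in the table reduces to a per-value decision. The sketch below is a hypothetical model of that rule, not mixtrain code; it assumes a value counts as remote when it carries a URL scheme such as s3:// or https://, and as local otherwise (including file:// paths).

```python
from urllib.parse import urlparse


def file_action(value: str, copy_files="auto") -> str:
    """Return what save() would do with one file reference (hypothetical model).

    "upload" = local file uploaded, "reference" = remote URL kept as-is,
    "copy" = remote file pulled into workspace storage, "as_is" = untouched.
    """
    if copy_files is False:
        return "as_is"  # table written as-is, file contents untouched
    scheme = urlparse(value).scheme
    # Assumption: a scheme of 2+ characters marks a remote URL; this keeps
    # Windows drive letters such as "C:" from being mistaken for a scheme.
    is_remote = len(scheme) > 1 and scheme != "file"
    if not is_remote:
        return "upload"  # local files are uploaded under "auto" and True
    return "copy" if copy_files is True else "reference"


print(file_action("cat.jpg"))  # upload
print(file_action("https://example.com/dog.png"))  # reference
print(file_action("https://example.com/dog.png", copy_files=True))  # copy
print(file_action("cat.jpg", copy_files=False))  # as_is
```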
```python
from mixtrain import Image

# Auto-detects column types and uploads local files (default)
ds = Dataset.from_dict({"img": ["cat.jpg", "dog.png"], "label": [0, 1]})
ds.save("my-dataset")  # img detected as Image, files uploaded

# Explicit types for exactly these columns (no inference of others)
ds.save("my-dataset", column_types={"img": Image})

# Pin some columns and let save() infer the rest: chain with_column_types()
ds.with_column_types({"img": Image}).save("my-dataset")

# Copy remote URLs into workspace storage too
ds.save("my-dataset", copy_files=True)
```
append_to()
```python
dataset.append_to(
    name: str,
    copy_files: bool | str = "auto",
    stats: bool = True,
) -> Dataset
```
Append rows to an existing dataset. File columns follow the target table's column types and the same copy_files contract as save(): local files are uploaded before rows are written.
Export Methods
collect()
```python
dataset.collect(max_bytes: int = None, warn: bool = True) -> pyarrow.Table
```
Execute the plan and materialize the result as an Arrow table. Warns on large results; pass max_bytes to fail instead of materializing more than that budget.
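max_bytes acts as a hard budget on materialization. The stdlib sketch below shows the fail-fast pattern on a stream of byte chunks; collect_with_budget is hypothetical, it measures plain bytes objects for simplicity, and how mixtrain sizes Arrow data is not specified here.

```python
def collect_with_budget(chunks, max_bytes=None):
    """Concatenate byte chunks, failing as soon as the running total exceeds max_bytes."""
    out, total = [], 0
    for chunk in chunks:
        total += len(chunk)
        if max_bytes is not None and total > max_bytes:
            # Stop before keeping more than the budget instead of growing unbounded.
            raise ValueError(f"result exceeds max_bytes={max_bytes} (at least {total} bytes)")
        out.append(chunk)
    return b"".join(out)


print(collect_with_budget([b"ab", b"cd"], max_bytes=4))  # b'abcd'
```

Failing early matters because the budget check runs while streaming, before the whole result has been held in memory.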
to_arrow() / to_pandas()
```python
dataset.to_arrow() -> pyarrow.Table
dataset.to_pandas() -> pandas.DataFrame
```
Materialize as an Arrow table or pandas DataFrame (equivalent to collect()).
to_tensors()
```python
dataset.to_tensors() -> dict[str, Tensor | list]
```
Convert to a dict of PyTorch tensors. Numeric columns are converted zero-copy.
to_torch()
```python
dataset.to_torch(
    batch_size: int = None,
    num_workers: int = 0,
    rank: int = 0,
    world_size: int = 1,
    drop_last: bool = False,
    load_files: bool | str | dict = "auto",
    prefetch_batches: int = 2,
    fetch_workers: int = 8,
    fetch_timeout: float = 30.0,
    max_inflight_bytes: int = 256_000_000,
    **dataloader_kwargs,
) -> DataLoader
```
Get a PyTorch DataLoader. Data is sharded automatically across DataLoader workers and distributed ranks (rank/world_size), so each worker reads a disjoint partition and fetches only its own media files. Each worker still runs its own table scan, so call cache() first to let many workers share one local mirror of a large platform table.
File-backed columns are loaded automatically (load_files= accepts the same values as load_files(); False disables loading), and batches are pipelined with background prefetch. fetch_workers controls parallel fetches, fetch_timeout sets the per-request timeout, and the MIXTRAIN_CACHE_DIR environment variable controls the local blob cache location. Extra kwargs (e.g. collate_fn, pin_memory) pass through to torch.utils.data.DataLoader.
To shuffle, call .shuffle() before to_torch(), and call loader.set_epoch(epoch) to reshuffle windowed shuffles each epoch.
```python
import io

from PIL import Image as PILImage

# Batched: yields dicts of tensors (numeric columns are zero-copy)
loader = dataset.to_torch(batch_size=32, num_workers=4)
for batch in loader:
    print(batch["features"].shape)  # torch.Size([32, ...])

# Media: images arrive as bytes; convert them in collate_fn (runs in the workers)
def collate(batch):
    batch["image"] = [PILImage.open(io.BytesIO(b)) for b in batch["image"]]
    return batch

loader = dataset.to_torch(batch_size=64, num_workers=4, collate_fn=collate)
```
to_huggingface()
```python
dataset.to_huggingface() -> datasets.Dataset
```
Convert to a HuggingFace Dataset.
Transformations
All transformations are lazy and return a new Dataset (immutable).
shuffle()
```python
dataset.shuffle(seed: int = None, global_: bool = True, window: int = 10_000) -> Dataset
```
Randomly shuffle rows. global_=True performs a full shuffle (which may buffer data in DuckDB); global_=False performs a bounded-memory streaming shuffle over a buffer of window rows. Without a seed, a stable seed is generated so re-iteration yields a consistent order.
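A streaming window shuffle keeps a fixed-size buffer and, as each new row arrives, emits a randomly chosen buffered row in its place. Memory stays bounded at window rows, but the result is only locally random: no row can come out more than window positions ahead of where it went in. A minimal sketch of that general technique with a seeded random.Random; it illustrates the algorithm, not mixtrain's implementation.

```python
import random


def window_shuffle(items, window, seed):
    """Yield items in a locally shuffled order using a buffer of at most `window` items."""
    rng = random.Random(seed)  # seeded, so the same seed reproduces the same order
    buffer = []
    for item in items:
        if len(buffer) < window:
            buffer.append(item)  # fill the buffer first
            continue
        # Buffer full: emit a random buffered item and put the new one in its slot.
        i = rng.randrange(window)
        yield buffer[i]
        buffer[i] = item
    # Drain the remaining buffered items in random order.
    rng.shuffle(buffer)
    yield from buffer


order = list(window_shuffle(range(10), window=4, seed=7))
print(sorted(order) == list(range(10)))  # True: a permutation of the input
```

A larger window gives a more thorough shuffle at the cost of more buffered rows, which is the trade-off global_=False exposes.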
sample()
```python
dataset.sample(n: int, seed: int = None) -> Dataset
```
Random sample of n rows (streaming reservoir sample).
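Reservoir sampling (Algorithm R) draws a uniform sample of n items in a single pass with O(n) memory, which is what makes a streaming sample possible without knowing the total row count in advance. A stdlib sketch of the classic algorithm; the exact variant mixtrain uses is not specified here.

```python
import random


def reservoir_sample(items, n, seed=None):
    """Return a uniform random sample of up to n items from an iterable, in one pass."""
    rng = random.Random(seed)
    reservoir = []
    for i, item in enumerate(items):
        if i < n:
            reservoir.append(item)  # fill the reservoir with the first n items
        else:
            # Keep item i with probability n / (i + 1), replacing a random slot.
            j = rng.randint(0, i)
            if j < n:
                reservoir[j] = item
    return reservoir


print(len(reservoir_sample(range(1000), 10, seed=1)))  # 10
```

When the input has fewer than n items, the whole input is returned.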
select()
```python
dataset.select(columns: list[str]) -> Dataset
```
Keep only the named columns. Pushes the projection down to the storage scan when possible.
take()
```python
dataset.take(indices: list[int]) -> Dataset
```
Keep only the rows at the given indices.
head() / slice()
```python
dataset.head(n: int = 5) -> Dataset
dataset.slice(start: int, stop: int) -> Dataset
```
head() returns the first n rows; slice() returns a contiguous row range. Limits push down to the storage scan, so Dataset("big").head(5) reads only a few rows.
filter()
```python
dataset.filter(predicate: str | Callable[[dict], bool]) -> Dataset
```
Filter rows with a SQL-like expression string or a Python function. Expressions are evaluated vectorized and push down into the Iceberg scan when the filter sits directly on a platform table, so prefer them for large datasets; functions run per row in Python.
```python
high_score = dataset.filter("score > 0.8 and label != None")
positive = dataset.filter(lambda x: x["label"] == 1)
```
map() / map_batches()
```python
dataset.map(fn: Callable[[dict], dict], *, schema: pyarrow.Schema) -> Dataset
dataset.map_batches(fn: Callable[[RecordBatch], RecordBatch], *, schema: pyarrow.Schema) -> Dataset
```
Apply a function per row, or per Arrow batch (faster for vectorized work). The output schema is required so the pipeline stays lazy without executing your function early.
```python
import pyarrow as pa

ds = dataset.map(
    lambda x: {"text": x["text"], "text_len": len(x["text"])},
    schema=pa.schema([("text", pa.string()), ("text_len", pa.int64())]),
)
```
join()
```python
dataset.join(other: Dataset, keys: str, join_type: str = "inner") -> Dataset
```
Join with another dataset.
| Parameter | Type | Description |
|---|---|---|
| other | Dataset | Right table to join |
| keys | str | Column to join on |
| join_type | str | "inner", "left outer", "right outer", "full outer" |
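The join types follow standard relational semantics. As an illustration only, here is a pure-Python hash join over lists of row dicts covering "inner" and "left outer"; hash_join is hypothetical and says nothing about how mixtrain executes joins (the users/orders rows below are made-up sample data).

```python
from collections import defaultdict


def hash_join(left, right, key, join_type="inner"):
    """Join two lists of row dicts on `key`; supports "inner" and "left outer"."""
    if join_type not in ("inner", "left outer"):
        raise ValueError(f"unsupported join_type in this sketch: {join_type}")
    index = defaultdict(list)
    for row in right:
        index[row[key]].append(row)  # build side: group right rows by key
    right_cols = {col for row in right for col in row}
    out = []
    for row in left:  # probe side
        matches = index.get(row[key], [])
        for match in matches:
            out.append({**row, **match})  # right values win on shared column names
        if not matches and join_type == "left outer":
            # Unmatched left rows are kept; right-only columns become None.
            out.append({**{col: None for col in right_cols}, **row})
    return out


users = [{"user_id": 1, "name": "a"}, {"user_id": 2, "name": "b"}]
orders = [{"user_id": 1, "item": "x"}, {"user_id": 1, "item": "y"}]
print(len(hash_join(users, orders, "user_id")))  # 2
print(len(hash_join(users, orders, "user_id", "left outer")))  # 3
```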
```python
joined = users.join(orders, keys="user_id")
```
train_test_split()
```python
dataset.train_test_split(*, test_size: float = 0.2, seed: int = None) -> dict[str, Dataset]
```
Split into deterministic, non-overlapping train and test sets.
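One common way to get a deterministic, non-overlapping split is to hash each row's position together with the seed and compare the result with test_size: every row lands in exactly one split, and the same seed always reproduces the same assignment. The sketch below uses that approach with hashlib; it is an assumption for illustration, mixtrain may assign rows differently, and with hashing the test fraction is only approximately test_size.

```python
import hashlib


def assign_split(row_index: int, test_size: float = 0.2, seed: int = 0) -> str:
    """Deterministically assign a row to "train" or "test" from (seed, row_index)."""
    digest = hashlib.sha256(f"{seed}:{row_index}".encode()).digest()
    # Keep 53 bits so the float is exactly representable and strictly below 1.0.
    u = (int.from_bytes(digest[:8], "big") >> 11) / 2**53
    return "test" if u < test_size else "train"


labels = [assign_split(i, 0.2, seed=42) for i in range(10_000)]
print(labels.count("test") / len(labels))  # close to 0.2
```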
```python
splits = dataset.train_test_split(test_size=0.2, seed=42)
train_ds = splits["train"]
test_ds = splits["test"]
```
shard()
```python
dataset.shard(index: int, count: int) -> Dataset
```
Restrict the dataset to one of count disjoint partitions (for manual distributed reading; to_torch() does this automatically).
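The stdlib sketch below partitions rows round-robin by position, which makes the count partitions disjoint and jointly complete. It also shows one plausible way to fold a distributed rank and a DataLoader worker id into a single (index, count) pair. Both the modulo scheme and global_shard are assumptions for illustration, not documented mixtrain behavior.

```python
def shard(rows, index: int, count: int):
    """Keep every count-th row starting at `index` (round-robin, disjoint across indices)."""
    if not 0 <= index < count:
        raise ValueError("index must satisfy 0 <= index < count")
    return [row for i, row in enumerate(rows) if i % count == index]


def global_shard(rank: int, world_size: int, worker_id: int, num_workers: int):
    """Hypothetical: combine rank and worker id into one (index, count) pair."""
    workers = max(num_workers, 1)  # num_workers=0 means loading in the main process
    return rank * workers + worker_id, world_size * workers


parts = [shard(range(10), i, 3) for i in range(3)]
print(parts)  # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
print(global_shard(rank=1, world_size=2, worker_id=3, num_workers=4))  # (7, 8)
```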
File Methods
load_files()
```python
dataset.load_files(
    to: bool | str | dict = "auto",
    columns: list[str] = None,
    workers: int = 8,
    max_inflight_bytes: int = 256_000_000,
    on_error: str = "raise",
    fetch_timeout: float = 30.0,
) -> Dataset
```
Load file-backed columns locally. Files are downloaded in parallel and cached in $MIXTRAIN_CACHE_DIR, so re-iteration does not re-download them.
The to parameter controls what each file-backed column becomes:
| Value | Column value | Use for |
|---|---|---|
| "auto" (default) | bytes for image/file, local path for video/audio/3d/mcap/rrd | most cases |
| "bytes" | raw bytes | GPU decode, sending to APIs |
| "path" | local file path (str, read-only) | decoders that want seekable files |
| False | raw media record | manual control |
Pass a dict for per-column values, e.g. load_files({"image": "path"}). A dict is an exact spec: only the columns it names are loaded, and other media columns are left as records (lazy handles when consumed as Python objects). on_error="null" maps failed fetches to None instead of raising.
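The to rules combine a per-type default for "auto" with the exact-spec behavior of dicts. The sketch below models that resolution in plain Python; resolve_load_modes is hypothetical, it assumes column types are lowercase strings such as "image" or "video", and it assumes any type not listed in the table gets bytes under "auto".

```python
AUTO_PATH_TYPES = {"video", "audio", "3d", "mcap", "rrd"}  # "auto" resolves to a local path


def resolve_load_modes(column_types, to="auto"):
    """Map each file-backed column to "bytes", "path", or False (left as a record)."""
    if isinstance(to, dict):
        # Exact spec: only the named columns are loaded; the rest stay as records.
        specs = {col: to.get(col, False) for col in column_types}
    else:
        specs = {col: to for col in column_types}
    modes = {}
    for col, spec in specs.items():
        if spec == "auto":
            # Assumption: types not in AUTO_PATH_TYPES (image, file, ...) get bytes.
            spec = "path" if column_types[col] in AUTO_PATH_TYPES else "bytes"
        modes[col] = spec
    return modes


types = {"image": "image", "clip": "video"}
print(resolve_load_modes(types))  # {'image': 'bytes', 'clip': 'path'}
print(resolve_load_modes(types, {"image": "path"}))  # {'image': 'path', 'clip': False}
```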
with_column_types()
```python
dataset.with_column_types(column_types: str | dict = "auto") -> Dataset
```
Set column types (image, video, audio, ...) on the schema. "auto" infers from a data sample; pass a dict for explicit mappings. save() applies this automatically, so you only need it to override inference before saving or calling load_files().
SQL Queries
sql()
```python
dataset.sql(sql: str, table_name: str = "data", seed: int = None) -> Dataset
```
Execute SQL via DuckDB. The dataset is registered as data by default; pass table_name to choose another alias. query() is an alias for sql().
```python
filtered = dataset.sql("SELECT * FROM data WHERE score > 0.8")
stats = dataset.sql("SELECT label, COUNT(*) AS cnt FROM t GROUP BY label", table_name="t")
```
Inspection
explain()
```python
dataset.explain() -> str
```
Render the lazy query plan.
```python
print(Dataset("d").filter("x > 1").head(5).explain())
```
cache()
```python
dataset.cache(path: str = None, load_files: bool = True) -> Dataset
```
Materialize the dataset locally so iteration is fast and offline-safe. With no arguments, platform table scans are mirrored to local Parquet under $MIXTRAIN_CACHE_DIR/datasets/<name>/<snapshot_id>; the mirror is reused across runs and processes and invalidated automatically when the table gets a new snapshot. Transforms in the plan re-run against the local mirror. With path=, the dataset's full output is materialized at an explicit location and reused on later calls (the caller owns invalidation). load_files=True also fetches file-backed column contents into the local file cache.
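Because the default mirror path includes the snapshot id, a new snapshot naturally maps to a fresh directory, which is one way the documented automatic invalidation can work. A small sketch of that path computation; mirror_dir is hypothetical, and the fallback used when MIXTRAIN_CACHE_DIR is unset (~/.cache/mixtrain) is a placeholder, not the documented default.

```python
import os
from pathlib import Path


def mirror_dir(name: str, snapshot_id: int, env=None) -> Path:
    """Return $MIXTRAIN_CACHE_DIR/datasets/<name>/<snapshot_id> for a table mirror."""
    env = os.environ if env is None else env
    # Hypothetical fallback; the real default location is not documented here.
    root = Path(env.get("MIXTRAIN_CACHE_DIR", Path.home() / ".cache" / "mixtrain"))
    return root / "datasets" / name / str(snapshot_id)


env = {"MIXTRAIN_CACHE_DIR": "/tmp/mx"}
print(mirror_dir("photo-dataset", 42, env))  # /tmp/mx/datasets/photo-dataset/42 on POSIX
```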
```python
ds = Dataset("photo-dataset").cache()  # table + files local; train offline
prep = ds.filter("image.width >= 512").cache(path="/data/prep-v2")
```
Metadata
Dataset metadata includes version history when the platform returns it. Snapshot entries include:
| Field | Description |
|---|---|
| sequence_number | User-facing dataset version (v1, v2, ...) |
| snapshot_id | Iceberg snapshot ID |
| operation | Commit operation, such as append |
| added_records | Records added by the commit, when available |
| total_records | Total records after the commit, when available |
| timestamp_ms | Commit timestamp |
set_column_types()
```python
dataset.set_column_types(column_types: dict) -> None
```
Update column types for rich UI rendering.
```python
from mixtrain import Audio, Image

dataset.set_column_types({"image_url": Image, "audio_url": Audio})
```
update()
```python
dataset.update(*, description: str = None, column_types: dict = None) -> dict
```
Update platform metadata.
delete()
```python
dataset.delete() -> None
```
Delete the dataset.
Class Methods
Dataset.exists()
Check if a dataset exists.
```python
Dataset.exists(name: str) -> bool
```
```python
if not Dataset.exists("my-dataset"):
    Dataset.from_pandas(df).save("my-dataset")
```
Helper Functions
list_datasets()
```python
from mixtrain import list_datasets

datasets = list_datasets()
for ds in datasets:
    print(ds.name)
```
get_dataset()
```python
from mixtrain import get_dataset

dataset = get_dataset("my-dataset")
```