from mixtrain import Dataset

Constructor

Dataset(name: str)

Creates a reference to an existing dataset on the platform. This is a lazy operation - no API call is made until you access data.

Parameter | Type | Description
name | str | Dataset name

dataset = Dataset("training-data")

Properties

Property | Type | Description
name | str | Dataset name
description | str | Dataset description
row_count | int | Number of rows
metadata | dict | Full metadata dictionary (cached)

Iteration

Row iteration

for row in dataset:
    print(row)  # {"col1": value, "col2": value}

Streams rows without loading the full dataset into memory.

Batch iteration

dataset.to_batches(size: int = 32) -> Iterator[dict[str, list]]

Yields batches as columnar dicts. Respects batch size regardless of underlying storage format.

for batch in dataset.to_batches(size=64):
    print(batch)  # {"col1": [v1, v2, ...], "col2": [v1, v2, ...]}

Creating Datasets

From files (persists to platform)

Dataset.from_file(
    name: str,
    file_path: str,
    description: str = None,
    column_types: dict = None
) -> Dataset

Parameter | Type | Description
name | str | Dataset name
file_path | str | Path to data file
description | str | Optional description
column_types | dict | Column type mappings for rich UI rendering

Supported formats: .parquet, .csv, .tsv

dataset = Dataset.from_file(
    name="training-data",
    file_path="data.parquet",
    description="Training dataset"
)

# With column types for multimodal data
from mixtrain import Image, Video, Embedding

dataset = Dataset.from_file(
    name="multimodal-data",
    file_path="data.csv",
    column_types={
        "image_url": Image,
        "video_url": Video,
        "embedding": Embedding
    }
)

From in-memory data

# From Python dict
ds = Dataset.from_dict({"col1": [1, 2, 3], "col2": ["a", "b", "c"]})

# From pandas DataFrame
ds = Dataset.from_pandas(df)

# From Arrow table
ds = Dataset.from_arrow(table)

# From HuggingFace datasets
ds = Dataset.from_huggingface("imdb", split="train")

# From PyTorch dataset
ds = Dataset.from_torch(torch_dataset)

These create in-memory datasets. Use .save() to persist to the platform.

Saving Datasets

save()

dataset.save(
    name: str = None,
    overwrite: bool = False,
    description: str = None,
    column_types: dict = None
) -> Dataset

Save dataset to platform.

# Save in-memory dataset
ds = Dataset.from_dict({"x": [1, 2, 3]})
ds.save("my-dataset", description="My dataset")

# Save transformed dataset
ds = Dataset("source-data").shuffle(42).filter(lambda x: x["label"] == 1)
ds.save("filtered-data")

append_to()

dataset.append_to(name: str) -> Dataset

Append this dataset's rows to an existing dataset on the platform.
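
For example, a minimal sketch: build rows in memory and append them to a previously saved dataset (the target name is illustrative, and the appended columns are assumed to match the target's schema):

new_rows = Dataset.from_dict({"x": [4, 5, 6]})
new_rows.append_to("my-dataset")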

Export Methods

to_arrow()

dataset.to_arrow() -> pyarrow.Table

Get dataset as Arrow table (lazy-loaded, cached).
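
For example, inspecting the table's size and schema (standard pyarrow.Table attributes):

table = dataset.to_arrow()
print(table.num_rows, table.schema)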

to_pandas()

dataset.to_pandas() -> pandas.DataFrame

Convert to pandas DataFrame.
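
For example, pulling the dataset into pandas for quick exploration (assumes pandas is installed):

df = dataset.to_pandas()
print(df.describe())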

to_table()

dataset.to_table() -> pyarrow.Table

Alias for to_arrow().

to_tensors()

dataset.to_tensors() -> dict[str, Tensor | list]

Convert to dict of PyTorch tensors. Uses zero-copy for numeric columns.

tensors = dataset.to_tensors()
print(tensors["label"])  # tensor([0, 1, 0, 1, ...])

to_torch()

dataset.to_torch(batch_size: int = None) -> DataLoader

Get a PyTorch DataLoader with zero-copy tensor conversion.

# Unbatched - yields individual rows
loader = dataset.to_torch()
for row in loader:
    print(row)  # {"col1": value, "col2": value}

# Batched - yields dicts of tensors
loader = dataset.to_torch(batch_size=32)
for batch in loader:
    print(batch["features"].shape)  # torch.Size([32, ...])

to_huggingface()

dataset.to_huggingface() -> datasets.Dataset

Convert to HuggingFace Dataset.
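
For example, converting and then using standard datasets.Dataset accessors (assumes the HuggingFace datasets library is installed):

hf_ds = dataset.to_huggingface()
print(hf_ds.features)
print(hf_ds[0])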

Transformations

All transformations return a new Dataset; the original dataset is left unchanged.
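
For example, chaining the transformations documented below produces new Dataset objects while the source stays untouched (the dataset name is illustrative):

base = Dataset("training-data")
small = base.shuffle(seed=42).sample(n=100, seed=42)  # base is unchanged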

shuffle()

dataset.shuffle(seed: int = None) -> Dataset

Randomly shuffle rows.
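
For example, shuffling with a fixed seed for a reproducible order:

shuffled = dataset.shuffle(seed=42)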

sample()

dataset.sample(n: int, seed: int = None) -> Dataset

Randomly sample n rows.
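
For example, drawing a reproducible sample (the size is illustrative):

subset = dataset.sample(n=1000, seed=42)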

select()

dataset.select(indices: list[int]) -> Dataset

Select rows by indices.
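
For example, picking specific rows by position:

first_five = dataset.select([0, 1, 2, 3, 4])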

cols()

dataset.cols(columns: list[str]) -> Dataset

Select columns.

ds.cols(["text", "label"])

head()

dataset.head(n: int = 5) -> Dataset

Select the first n rows.

filter()

dataset.filter(fn: Callable[[dict], bool]) -> Dataset

Filter rows with a Python function.

positive = dataset.filter(lambda x: x["label"] == 1)

map()

dataset.map(fn: Callable, batched: bool = False) -> Dataset

Apply function to rows or batches.

# Row-by-row
ds.map(lambda x: {**x, "text_len": len(x["text"])})

# Batched (faster for vectorized operations)
ds.map(lambda batch: {**batch, "doubled": [v * 2 for v in batch["value"]]}, batched=True)

join()

dataset.join(
    other: Dataset,
    keys: str | list[str],
    right_keys: str | list[str] = None,
    join_type: str = "inner"
) -> Dataset

Join with another dataset.

Parameter | Type | Description
other | Dataset | Right table to join
keys | str or list[str] | Column(s) to join on from left table
right_keys | str or list[str] | Column(s) from right table (defaults to keys)
join_type | str | "inner", "left outer", "right outer", or "full outer"

joined = users.join(orders, keys="user_id")

train_test_split()

dataset.train_test_split(test_size: float = 0.2, seed: int = None) -> dict[str, Dataset]

Split into train and test sets.

splits = dataset.train_test_split(test_size=0.2, seed=42)
train_ds = splits["train"]
test_ds = splits["test"]

SQL Queries

query()

dataset.query(sql: str) -> Dataset

Execute a SQL query via DuckDB. Within the query, the dataset is available as a table named data.

filtered = dataset.query("SELECT * FROM data WHERE score > 0.8")
stats = dataset.query("SELECT label, COUNT(*) as cnt FROM data GROUP BY label")

query_multiple()

Dataset.query_multiple(datasets: dict[str, Dataset], sql: str) -> Dataset

Query across multiple datasets.

result = Dataset.query_multiple({
    "users": Dataset("users"),
    "orders": Dataset("orders"),
}, "SELECT * FROM users u JOIN orders o ON u.id = o.user_id")

Metadata

set_column_types()

dataset.set_column_types(column_types: dict) -> None

Update column types for rich UI rendering.

from mixtrain import Image, Audio

dataset.set_column_types({
    "image_url": Image,
    "audio_url": Audio
})

update_metadata()

dataset.update_metadata(description: str = None, column_types: dict = None) -> dict

Update the dataset's description and/or column types.
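
For example, updating only the description (the text is illustrative):

dataset.update_metadata(description="Cleaned and deduplicated training split")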

versions()

dataset.versions() -> list[dict]

List available versions/snapshots.
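
For example, listing snapshots (the keys of each version dict depend on the platform's metadata):

for version in dataset.versions():
    print(version)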

delete()

dataset.delete() -> None

Delete the dataset.
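
For example, guarding the call with Dataset.exists (documented under Class Methods below; the name is illustrative):

if Dataset.exists("old-dataset"):
    Dataset("old-dataset").delete()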

refresh()

dataset.refresh() -> None

Clear cached data.

Class Methods

Dataset.exists()

Check if a dataset exists.

Dataset.exists(name: str) -> bool

Parameter | Type | Description
name | str | Dataset name to check

Returns: bool - True if the dataset exists, False otherwise

if not Dataset.exists("my-dataset"):
    Dataset.from_pandas(df).save("my-dataset")

Helper Functions

list_datasets()

from mixtrain import list_datasets

datasets = list_datasets()
for ds in datasets:
    print(f"{ds.name}: {ds.row_count} rows")

get_dataset()

from mixtrain import get_dataset

dataset = get_dataset("my-dataset")
