mlxdf 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxdf-0.1.0/PKG-INFO +148 -0
- mlxdf-0.1.0/README.md +134 -0
- mlxdf-0.1.0/pyproject.toml +38 -0
- mlxdf-0.1.0/src/mlxdf/__init__.py +24 -0
- mlxdf-0.1.0/src/mlxdf/compute/__init__.py +0 -0
- mlxdf-0.1.0/src/mlxdf/compute/compiler.py +215 -0
- mlxdf-0.1.0/src/mlxdf/core/__init__.py +5 -0
- mlxdf-0.1.0/src/mlxdf/core/categorical.py +250 -0
- mlxdf-0.1.0/src/mlxdf/core/dataframe.py +319 -0
- mlxdf-0.1.0/src/mlxdf/core/series.py +308 -0
- mlxdf-0.1.0/src/mlxdf/io/__init__.py +22 -0
- mlxdf-0.1.0/src/mlxdf/io/arrow.py +230 -0
- mlxdf-0.1.0/src/mlxdf/io/dlpack.py +158 -0
- mlxdf-0.1.0/src/mlxdf/io/parquet.py +82 -0
- mlxdf-0.1.0/src/mlxdf/ops/__init__.py +0 -0
- mlxdf-0.1.0/src/mlxdf/ops/arithmetic.py +0 -0
- mlxdf-0.1.0/src/mlxdf/ops/boolean.py +0 -0
- mlxdf-0.1.0/src/mlxdf/ops/groupby.py +244 -0
- mlxdf-0.1.0/src/mlxdf/ops/join.py +373 -0
- mlxdf-0.1.0/src/mlxdf/py.typed +0 -0
mlxdf-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: mlxdf
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: GPU-accelerated DataFrame library for Apple Silicon, built on MLX
|
|
5
|
+
Author: Mocus Zhang
|
|
6
|
+
Author-email: Mocus Zhang <mocusez@outlook.com>
|
|
7
|
+
License-Expression: MIT
|
|
8
|
+
Requires-Dist: mlx>=0.31.1
|
|
9
|
+
Requires-Dist: numpy>=1.24.0
|
|
10
|
+
Requires-Dist: pyarrow>=12.0 ; extra == 'arrow'
|
|
11
|
+
Requires-Python: >=3.11
|
|
12
|
+
Provides-Extra: arrow
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
|
|
15
|
+
# MLX-DF
|
|
16
|
+
|
|
17
|
+
GPU-accelerated DataFrame library for Apple Silicon, built on [MLX](https://github.com/ml-explore/mlx).
|
|
18
|
+
|
|
19
|
+
MLX-DF brings cuDF-style GPU DataFrame operations to Mac, exploiting Apple's unified memory for zero-copy CPU/GPU data sharing. The API mirrors Pandas for easy migration.
|
|
20
|
+
|
|
21
|
+
> [!WARNING]
|
|
22
|
+
> MLX-DF currently supports only Apple Silicon devices (M-series chips).
|
|
23
|
+
|
|
24
|
+
## Installation
|
|
25
|
+
|
|
26
|
+
```bash
|
|
27
|
+
pip install mlxdf
|
|
28
|
+
```
|
|
29
|
+
|
|
30
|
+
Using uv:
|
|
31
|
+
|
|
32
|
+
```bash
|
|
33
|
+
uv add mlxdf
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
With PyArrow/Parquet support:
|
|
37
|
+
|
|
38
|
+
```bash
|
|
39
|
+
pip install mlxdf[arrow]
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
Using uv:
|
|
43
|
+
|
|
44
|
+
```bash
|
|
45
|
+
uv add "mlxdf[arrow]"
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
From source:
|
|
49
|
+
|
|
50
|
+
```bash
|
|
51
|
+
uv sync
|
|
52
|
+
uv build
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
## Quick Start
|
|
56
|
+
|
|
57
|
+
```python
|
|
58
|
+
from mlxdf import MlxDataFrame, merge, read_parquet
|
|
59
|
+
|
|
60
|
+
# Create a DataFrame (string columns auto-detected as CategoricalSeries)
|
|
61
|
+
df = MlxDataFrame({
|
|
62
|
+
"product_id": [1.0, 2.0, 1.0, 3.0, 2.0],
|
|
63
|
+
"quantity": [5.0, 3.0, 2.0, 7.0, 1.0],
|
|
64
|
+
"category": ["A", "B", "A", "C", "B"],
|
|
65
|
+
})
|
|
66
|
+
|
|
67
|
+
# Filter
|
|
68
|
+
high_qty = df[df["quantity"] > 2.0]
|
|
69
|
+
|
|
70
|
+
# Computed columns
|
|
71
|
+
df["double_qty"] = df["quantity"] * 2
|
|
72
|
+
|
|
73
|
+
# GroupBy aggregation
|
|
74
|
+
result = df.groupby("category")["quantity"].sum()
|
|
75
|
+
result.show()
|
|
76
|
+
|
|
77
|
+
# Join two DataFrames
|
|
78
|
+
prices = MlxDataFrame({
|
|
79
|
+
"product_id": [1.0, 2.0, 3.0],
|
|
80
|
+
"price": [10.0, 25.0, 15.0],
|
|
81
|
+
})
|
|
82
|
+
joined = df.merge(prices, on="product_id", how="inner")
|
|
83
|
+
|
|
84
|
+
# Parquet I/O (requires mlxdf[arrow])
|
|
85
|
+
df.to_parquet("output.parquet")
|
|
86
|
+
df2 = read_parquet("output.parquet", columns=["product_id", "quantity"])
|
|
87
|
+
```
|
|
88
|
+
|
|
89
|
+
## Features
|
|
90
|
+
|
|
91
|
+
- **MlxSeries** — Column with boolean null mask, vectorized arithmetic, comparisons, and aggregations
|
|
92
|
+
- **CategoricalSeries** — Dictionary-encoded string column (55× faster filtering vs Pandas)
|
|
93
|
+
- **MlxDataFrame** — Dict-like table with column access, boolean filtering, head/tail/slicing
|
|
94
|
+
- **GroupBy** — Bincount/sort-based groupby with sum/mean/count/max/min aggregations
|
|
95
|
+
- **Join** — Hash-index join supporting inner/left/right/outer (4× faster vs Pandas at 200M rows)
|
|
96
|
+
- **Pandas Interop** — `to_pandas()` / `from_pandas()` with automatic type conversion
|
|
97
|
+
- **PyArrow & Parquet** — Read/write Parquet with column pruning and predicate pushdown
|
|
98
|
+
- **JIT Compilation** — `compile_fn` for fused GPU kernel execution
|
|
99
|
+
|
|
100
|
+
## Development
|
|
101
|
+
|
|
102
|
+
### Setup
|
|
103
|
+
|
|
104
|
+
```bash
|
|
105
|
+
uv sync
|
|
106
|
+
```
|
|
107
|
+
|
|
108
|
+
### Running Tests
|
|
109
|
+
|
|
110
|
+
```bash
|
|
111
|
+
# Run all unit tests (benchmarks are excluded by default)
|
|
112
|
+
uv run pytest
|
|
113
|
+
|
|
114
|
+
# Run a specific test file
|
|
115
|
+
uv run pytest tests/test_series.py
|
|
116
|
+
|
|
117
|
+
# Run a specific test case
|
|
118
|
+
uv run pytest tests/test_series.py::TestArithmetic::test_add_series -v
|
|
119
|
+
|
|
120
|
+
# Run with verbose output
|
|
121
|
+
uv run pytest -v
|
|
122
|
+
|
|
123
|
+
# Run and stop on first failure
|
|
124
|
+
uv run pytest -x
|
|
125
|
+
```
|
|
126
|
+
|
|
127
|
+
### Benchmarks
|
|
128
|
+
|
|
129
|
+
Benchmarks are integrated into pytest via the `bench` marker, defaulting to deselected so they don't slow down regular test runs.
|
|
130
|
+
|
|
131
|
+
```bash
|
|
132
|
+
# Run all benchmarks
|
|
133
|
+
uv run pytest -m bench
|
|
134
|
+
|
|
135
|
+
# Run a specific benchmark
|
|
136
|
+
uv run pytest -m bench -k parquet
|
|
137
|
+
uv run pytest -m bench -k tpch
|
|
138
|
+
uv run pytest -m bench -k categorical
|
|
139
|
+
uv run pytest -m bench -k compile
|
|
140
|
+
|
|
141
|
+
# Run both tests and benchmarks together
|
|
142
|
+
uv run pytest -m ""
|
|
143
|
+
|
|
144
|
+
# Run benchmark scripts directly (also works)
|
|
145
|
+
uv run python benchmarks/bench_vs_pandas.py
|
|
146
|
+
```
|
|
147
|
+
|
|
148
|
+
Available benchmarks: `bench_vs_pandas`, `bench_categorical`, `bench_parquet`, `bench_compile_df`, `bench_tpch_q1/q3/q4/q6/q18/q19`.
|
mlxdf-0.1.0/README.md
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
# MLX-DF
|
|
2
|
+
|
|
3
|
+
GPU-accelerated DataFrame library for Apple Silicon, built on [MLX](https://github.com/ml-explore/mlx).
|
|
4
|
+
|
|
5
|
+
MLX-DF brings cuDF-style GPU DataFrame operations to Mac, exploiting Apple's unified memory for zero-copy CPU/GPU data sharing. The API mirrors Pandas for easy migration.
|
|
6
|
+
|
|
7
|
+
> [!WARNING]
|
|
8
|
+
> MLX-DF currently supports only Apple Silicon devices (M-series chips).
|
|
9
|
+
|
|
10
|
+
## Installation
|
|
11
|
+
|
|
12
|
+
```bash
|
|
13
|
+
pip install mlxdf
|
|
14
|
+
```
|
|
15
|
+
|
|
16
|
+
Using uv:
|
|
17
|
+
|
|
18
|
+
```bash
|
|
19
|
+
uv add mlxdf
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
With PyArrow/Parquet support:
|
|
23
|
+
|
|
24
|
+
```bash
|
|
25
|
+
pip install mlxdf[arrow]
|
|
26
|
+
```
|
|
27
|
+
|
|
28
|
+
Using uv:
|
|
29
|
+
|
|
30
|
+
```bash
|
|
31
|
+
uv add "mlxdf[arrow]"
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
From source:
|
|
35
|
+
|
|
36
|
+
```bash
|
|
37
|
+
uv sync
|
|
38
|
+
uv build
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
## Quick Start
|
|
42
|
+
|
|
43
|
+
```python
|
|
44
|
+
from mlxdf import MlxDataFrame, merge, read_parquet
|
|
45
|
+
|
|
46
|
+
# Create a DataFrame (string columns auto-detected as CategoricalSeries)
|
|
47
|
+
df = MlxDataFrame({
|
|
48
|
+
"product_id": [1.0, 2.0, 1.0, 3.0, 2.0],
|
|
49
|
+
"quantity": [5.0, 3.0, 2.0, 7.0, 1.0],
|
|
50
|
+
"category": ["A", "B", "A", "C", "B"],
|
|
51
|
+
})
|
|
52
|
+
|
|
53
|
+
# Filter
|
|
54
|
+
high_qty = df[df["quantity"] > 2.0]
|
|
55
|
+
|
|
56
|
+
# Computed columns
|
|
57
|
+
df["double_qty"] = df["quantity"] * 2
|
|
58
|
+
|
|
59
|
+
# GroupBy aggregation
|
|
60
|
+
result = df.groupby("category")["quantity"].sum()
|
|
61
|
+
result.show()
|
|
62
|
+
|
|
63
|
+
# Join two DataFrames
|
|
64
|
+
prices = MlxDataFrame({
|
|
65
|
+
"product_id": [1.0, 2.0, 3.0],
|
|
66
|
+
"price": [10.0, 25.0, 15.0],
|
|
67
|
+
})
|
|
68
|
+
joined = df.merge(prices, on="product_id", how="inner")
|
|
69
|
+
|
|
70
|
+
# Parquet I/O (requires mlxdf[arrow])
|
|
71
|
+
df.to_parquet("output.parquet")
|
|
72
|
+
df2 = read_parquet("output.parquet", columns=["product_id", "quantity"])
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
## Features
|
|
76
|
+
|
|
77
|
+
- **MlxSeries** — Column with boolean null mask, vectorized arithmetic, comparisons, and aggregations
|
|
78
|
+
- **CategoricalSeries** — Dictionary-encoded string column (55× faster filtering vs Pandas)
|
|
79
|
+
- **MlxDataFrame** — Dict-like table with column access, boolean filtering, head/tail/slicing
|
|
80
|
+
- **GroupBy** — Bincount/sort-based groupby with sum/mean/count/max/min aggregations
|
|
81
|
+
- **Join** — Hash-index join supporting inner/left/right/outer (4× faster vs Pandas at 200M rows)
|
|
82
|
+
- **Pandas Interop** — `to_pandas()` / `from_pandas()` with automatic type conversion
|
|
83
|
+
- **PyArrow & Parquet** — Read/write Parquet with column pruning and predicate pushdown
|
|
84
|
+
- **JIT Compilation** — `compile_fn` for fused GPU kernel execution
|
|
85
|
+
|
|
86
|
+
## Development
|
|
87
|
+
|
|
88
|
+
### Setup
|
|
89
|
+
|
|
90
|
+
```bash
|
|
91
|
+
uv sync
|
|
92
|
+
```
|
|
93
|
+
|
|
94
|
+
### Running Tests
|
|
95
|
+
|
|
96
|
+
```bash
|
|
97
|
+
# Run all unit tests (benchmarks are excluded by default)
|
|
98
|
+
uv run pytest
|
|
99
|
+
|
|
100
|
+
# Run a specific test file
|
|
101
|
+
uv run pytest tests/test_series.py
|
|
102
|
+
|
|
103
|
+
# Run a specific test case
|
|
104
|
+
uv run pytest tests/test_series.py::TestArithmetic::test_add_series -v
|
|
105
|
+
|
|
106
|
+
# Run with verbose output
|
|
107
|
+
uv run pytest -v
|
|
108
|
+
|
|
109
|
+
# Run and stop on first failure
|
|
110
|
+
uv run pytest -x
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
### Benchmarks
|
|
114
|
+
|
|
115
|
+
Benchmarks are integrated into pytest via the `bench` marker, defaulting to deselected so they don't slow down regular test runs.
|
|
116
|
+
|
|
117
|
+
```bash
|
|
118
|
+
# Run all benchmarks
|
|
119
|
+
uv run pytest -m bench
|
|
120
|
+
|
|
121
|
+
# Run a specific benchmark
|
|
122
|
+
uv run pytest -m bench -k parquet
|
|
123
|
+
uv run pytest -m bench -k tpch
|
|
124
|
+
uv run pytest -m bench -k categorical
|
|
125
|
+
uv run pytest -m bench -k compile
|
|
126
|
+
|
|
127
|
+
# Run both tests and benchmarks together
|
|
128
|
+
uv run pytest -m ""
|
|
129
|
+
|
|
130
|
+
# Run benchmark scripts directly (also works)
|
|
131
|
+
uv run python benchmarks/bench_vs_pandas.py
|
|
132
|
+
```
|
|
133
|
+
|
|
134
|
+
Available benchmarks: `bench_vs_pandas`, `bench_categorical`, `bench_parquet`, `bench_compile_df`, `bench_tpch_q1/q3/q4/q6/q18/q19`.
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
[project]
name = "mlxdf"
version = "0.1.0"
description = "GPU-accelerated DataFrame library for Apple Silicon, built on MLX"
readme = "README.md"
license = "MIT"
authors = [
    { name = "Mocus Zhang", email = "mocusez@outlook.com" }
]
requires-python = ">=3.11"
dependencies = [
    "mlx>=0.31.1",
    "numpy>=1.24.0",
]

[project.optional-dependencies]
# Parquet / Arrow I/O support; install with: pip install mlxdf[arrow]
arrow = ["pyarrow>=12.0"]

[dependency-groups]
# Development tooling (tests + pandas/pyarrow used for interop tests).
dev = [
    "pytest>=8.0",
    "pandas>=2.0",
    "pyarrow>=12.0",
]
# Extra comparison libraries used only by the benchmark suite.
bench = [
    "polars==1.39.3",
    "duckdb==1.5.0",
]

[tool.pytest.ini_options]
# Deselect benchmark tests by default; run them with: uv run pytest -m bench
addopts = "-m 'not bench'"
markers = [
    "bench: benchmark tests (deselected by default, run with: uv run pytest -m bench)",
]

[build-system]
requires = ["uv_build>=0.9.18,<0.10.0"]
build-backend = "uv_build"
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""MLX-DF: GPU-accelerated DataFrame library for Apple Silicon."""
|
|
2
|
+
|
|
3
|
+
from mlxdf.core.series import MlxSeries
|
|
4
|
+
from mlxdf.core.categorical import CategoricalSeries
|
|
5
|
+
from mlxdf.core.dataframe import MlxDataFrame
|
|
6
|
+
from mlxdf.ops.join import merge
|
|
7
|
+
from mlxdf.compute.compiler import compile_fn, compile_df
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def read_parquet(path, *, columns=None, filters=None):
    """Read a Parquet file into an MlxDataFrame (requires PyArrow).

    The backend is imported lazily so that ``import mlxdf`` works even
    when the optional ``arrow`` extra is not installed.
    """
    from mlxdf.io.parquet import read_parquet as _backend_read

    frame = _backend_read(path, columns=columns, filters=filters)
    return frame
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Public API of the mlxdf package.
__all__ = [
    "MlxSeries",
    "CategoricalSeries",
    "MlxDataFrame",
    "merge",
    "compile_fn",
    "compile_df",
    "read_parquet",
]
|
|
File without changes
|
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
"""Compute engine helpers: mx.compile wrapper and forced materialization."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from functools import wraps
|
|
6
|
+
|
|
7
|
+
import mlx.core as mx
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def compile_fn(fn):
    """Return *fn* wrapped by ``mx.compile`` for fused GPU execution.

    Thin convenience wrapper; the returned callable has the same
    signature as *fn*.
    """
    compiled = mx.compile(fn)
    return compiled
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# ---------------------------------------------------------------------------
|
|
16
|
+
# Lightweight proxies for compile_df — these mimic MlxSeries / MlxDataFrame
|
|
17
|
+
# using only raw mx.array values so mx.compile can trace them.
|
|
18
|
+
# ---------------------------------------------------------------------------
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class _SeriesProxy:
    """Stand-in for MlxSeries while tracing under ``mx.compile``.

    Holds bare ``data`` and ``mask`` arrays and implements arithmetic
    with null-mask propagation identical to MlxSeries._binary_op:
    the result mask is the logical AND of both operand masks.
    """

    __slots__ = ("data", "mask", "name")

    def __init__(self, data: mx.array, mask: mx.array, name: str | None = None):
        self.data = data
        self.mask = mask
        self.name = name

    # -- arithmetic (result mask = AND of operand masks) --

    def _binary_op(self, other, op_fn) -> "_SeriesProxy":
        if isinstance(other, _SeriesProxy):
            combined = mx.logical_and(self.mask, other.mask)
            return _SeriesProxy(op_fn(self.data, other.data), combined)
        # Scalar operand: cast to the series dtype; mask is unchanged.
        rhs = mx.array(other, dtype=self.data.dtype)
        return _SeriesProxy(op_fn(self.data, rhs), self.mask)

    def __add__(self, other):
        return self._binary_op(other, mx.add)

    # Addition and multiplication are commutative, so the reflected
    # forms reuse the forward implementation.
    __radd__ = __add__

    def __sub__(self, other):
        return self._binary_op(other, mx.subtract)

    def __rsub__(self, other):
        if isinstance(other, _SeriesProxy):
            return other._binary_op(self, mx.subtract)
        lhs = mx.array(other, dtype=self.data.dtype)
        return _SeriesProxy(mx.subtract(lhs, self.data), self.mask)

    def __mul__(self, other):
        return self._binary_op(other, mx.multiply)

    __rmul__ = __mul__

    def __truediv__(self, other):
        return self._binary_op(other, mx.divide)

    def __rtruediv__(self, other):
        if isinstance(other, _SeriesProxy):
            return other._binary_op(self, mx.divide)
        lhs = mx.array(other, dtype=self.data.dtype)
        return _SeriesProxy(mx.divide(lhs, self.data), self.mask)

    def __neg__(self):
        return _SeriesProxy(mx.negative(self.data), self.mask, name=self.name)

    # -- elementwise comparisons (return proxies, not bools) --

    def _cmp_op(self, other, op_fn) -> "_SeriesProxy":
        if isinstance(other, _SeriesProxy):
            combined = mx.logical_and(self.mask, other.mask)
            return _SeriesProxy(op_fn(self.data, other.data), combined)
        # NOTE(review): unlike _binary_op, the scalar is NOT cast to
        # self.data.dtype here — presumably to mirror MlxSeries; confirm.
        return _SeriesProxy(op_fn(self.data, mx.array(other)), self.mask)

    def __gt__(self, other):
        return self._cmp_op(other, mx.greater)

    def __ge__(self, other):
        return self._cmp_op(other, mx.greater_equal)

    def __lt__(self, other):
        return self._cmp_op(other, mx.less)

    def __le__(self, other):
        return self._cmp_op(other, mx.less_equal)

    # Elementwise __eq__/__ne__; defining __eq__ makes proxies
    # unhashable, which is acceptable for trace-time objects.
    def __eq__(self, other):
        return self._cmp_op(other, mx.equal)

    def __ne__(self, other):
        return self._cmp_op(other, mx.not_equal)

    # -- null helpers --

    def fillna(self, value) -> "_SeriesProxy":
        """Replace masked (null) entries with *value*; result has no nulls."""
        if isinstance(value, mx.array):
            fill = value
        else:
            fill = mx.array(value, dtype=self.data.dtype)
        repaired = mx.where(self.mask, self.data, fill)
        all_valid = mx.ones(self.data.shape, dtype=mx.bool_)
        return _SeriesProxy(repaired, all_valid, name=self.name)

    def isna(self) -> "_SeriesProxy":
        """Boolean proxy that is True where the entry is null."""
        all_valid = mx.ones(self.mask.shape, dtype=mx.bool_)
        return _SeriesProxy(mx.logical_not(self.mask), all_valid, name=self.name)

    def notna(self) -> "_SeriesProxy":
        """Boolean proxy that is True where the entry is valid."""
        all_valid = mx.ones(self.mask.shape, dtype=mx.bool_)
        return _SeriesProxy(self.mask, all_valid, name=self.name)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class _DataFrameProxy:
    """Stand-in for MlxDataFrame while tracing under ``mx.compile``.

    Columns are held as :class:`_SeriesProxy` objects so user code can
    keep the familiar ``df['col']`` get/set syntax inside a compiled
    function.
    """

    __slots__ = ("_cols",)

    def __init__(self, data_dict: dict[str, mx.array], mask_dict: dict[str, mx.array]):
        self._cols: dict[str, _SeriesProxy] = {}
        for key, data in data_dict.items():
            self._cols[key] = _SeriesProxy(data, mask_dict[key], name=key)

    def __getitem__(self, key: str) -> _SeriesProxy:
        return self._cols[key]

    def __setitem__(self, key: str, value: _SeriesProxy):
        # Guard clause: only proxy columns can live in a traced frame.
        if not isinstance(value, _SeriesProxy):
            raise TypeError(
                f"compile_df only supports _SeriesProxy column assignment, got {type(value)}"
            )
        value.name = key
        self._cols[key] = value

    @property
    def columns(self) -> list[str]:
        return list(self._cols)

    def _to_dicts(self) -> tuple[dict[str, mx.array], dict[str, mx.array]]:
        """Flatten the proxy back into (data, mask) array dicts."""
        data: dict[str, mx.array] = {}
        masks: dict[str, mx.array] = {}
        for key, col in self._cols.items():
            data[key] = col.data
            masks[key] = col.mask
        return data, masks
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
# ---------------------------------------------------------------------------
|
|
162
|
+
# compile_df — DataFrame-aware JIT compilation
|
|
163
|
+
# ---------------------------------------------------------------------------
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def compile_df(fn):
    """JIT-compile a function that operates on :class:`MlxDataFrame`.

    The decorated function receives a lightweight proxy that supports the
    same column access (``df['col']``), arithmetic, and ``fillna`` as
    :class:`MlxDataFrame` / :class:`MlxSeries`. Under the hood, the
    proxy passes raw ``mx.array`` values through ``mx.compile`` so that
    all GPU operations are fused into a single kernel.

    Example::

        @compile_df
        def compute(df):
            df['wap'] = (df['bid_price'] * df['ask_size']
                         + df['ask_price'] * df['bid_size']) / \\
                        (df['bid_size'] + df['ask_size'])
            df['wap'] = df['wap'].fillna(0.0)
            return df

        result = compute(my_dataframe)  # MlxDataFrame in, MlxDataFrame out
    """
    from mlxdf.core.dataframe import MlxDataFrame
    from mlxdf.core.series import MlxSeries

    def _traced(data_dict, mask_dict):
        # Runs under mx.compile: wrap the raw arrays in a proxy frame,
        # apply the user function, and flatten the result back to dicts.
        frame = _DataFrameProxy(data_dict, mask_dict)
        return fn(frame)._to_dicts()

    compiled = mx.compile(_traced)

    @wraps(fn)
    def wrapper(df: MlxDataFrame) -> MlxDataFrame:
        # Flatten: DataFrame -> dict-of-arrays. Only numeric MlxSeries
        # columns are traced; other column types (e.g. CategoricalSeries)
        # are skipped and will be absent from the result.
        data_dict: dict[str, mx.array] = {}
        mask_dict: dict[str, mx.array] = {}
        for name, col in df._columns.items():
            if isinstance(col, MlxSeries):
                data_dict[name] = col.data
                mask_dict[name] = col.mask

        out_data, out_mask = compiled(data_dict, mask_dict)

        # Rebuild an MlxDataFrame from the compiled output arrays.
        rebuilt = {
            name: MlxSeries(out_data[name], mask=out_mask[name], name=name)
            for name in out_data
        }
        return MlxDataFrame(rebuilt)

    return wrapper
|