mplang-nightly 0.1.dev164__py3-none-any.whl → 0.1.dev166__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/core/expr/evaluator.py +55 -15
- mplang/kernels/__init__.py +28 -0
- mplang/kernels/builtin.py +91 -56
- mplang/kernels/crypto.py +39 -30
- mplang/kernels/mock_tee.py +10 -8
- mplang/kernels/phe.py +238 -39
- mplang/kernels/spu.py +134 -45
- mplang/kernels/sql_duckdb.py +8 -13
- mplang/kernels/stablehlo.py +15 -9
- mplang/kernels/value.py +626 -0
- mplang/protos/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/protos/v1alpha1/value_pb2.py +34 -0
- mplang/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/runtime/client.py +19 -8
- mplang/runtime/communicator.py +11 -4
- mplang/runtime/driver.py +16 -1
- mplang/runtime/link_comm.py +26 -79
- mplang/runtime/server.py +30 -29
- mplang/runtime/session.py +9 -0
- mplang/runtime/simulation.py +4 -5
- mplang/simp/__init__.py +1 -1
- {mplang_nightly-0.1.dev164.dist-info → mplang_nightly-0.1.dev166.dist-info}/METADATA +1 -1
- {mplang_nightly-0.1.dev164.dist-info → mplang_nightly-0.1.dev166.dist-info}/RECORD +26 -23
- {mplang_nightly-0.1.dev164.dist-info → mplang_nightly-0.1.dev166.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev164.dist-info → mplang_nightly-0.1.dev166.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev164.dist-info → mplang_nightly-0.1.dev166.dist-info}/licenses/LICENSE +0 -0
mplang/kernels/sql_duckdb.py
CHANGED
@@ -14,16 +14,14 @@
|
|
14
14
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
|
-
from typing import Any
|
18
|
-
|
19
17
|
from mplang.core.pfunc import PFunction
|
20
18
|
from mplang.kernels.base import kernel_def
|
19
|
+
from mplang.kernels.value import TableValue
|
21
20
|
|
22
21
|
|
23
22
|
@kernel_def("duckdb.run_sql")
|
24
|
-
def _duckdb_sql(pfunc: PFunction, *args:
|
23
|
+
def _duckdb_sql(pfunc: PFunction, *args: TableValue) -> TableValue:
|
25
24
|
import duckdb
|
26
|
-
import pandas as pd
|
27
25
|
|
28
26
|
# TODO: maybe we could translate the sql to duckdb dialect
|
29
27
|
# instead of raising an exception
|
@@ -36,12 +34,9 @@ def _duckdb_sql(pfunc: PFunction, *args: Any) -> Any:
|
|
36
34
|
if in_names is None:
|
37
35
|
raise ValueError("duckdb sql missing in_names attr")
|
38
36
|
for arg, name in zip(args, in_names, strict=True):
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
conn.register(name, df)
|
46
|
-
res_df = conn.execute(pfunc.fn_text).fetchdf()
|
47
|
-
return res_df
|
37
|
+
# Use Arrow directly for zero-copy data transfer
|
38
|
+
arrow_table = arg.to_arrow()
|
39
|
+
conn.register(name, arrow_table)
|
40
|
+
# Fetch result as Arrow table for consistency
|
41
|
+
res_arrow = conn.execute(pfunc.fn_text).fetch_arrow_table()
|
42
|
+
return TableValue(res_arrow)
|
mplang/kernels/stablehlo.py
CHANGED
@@ -18,11 +18,13 @@ from typing import Any
|
|
18
18
|
|
19
19
|
import jax
|
20
20
|
import jax.numpy as jnp
|
21
|
+
import numpy as np
|
21
22
|
from jax._src import xla_bridge
|
22
23
|
from jax.lib import xla_client as xc
|
23
24
|
|
24
25
|
from mplang.core.pfunc import PFunction
|
25
26
|
from mplang.kernels.base import cur_kctx, kernel_def
|
27
|
+
from mplang.kernels.value import TensorValue
|
26
28
|
|
27
29
|
|
28
30
|
@kernel_def("mlir.stablehlo")
|
@@ -61,13 +63,17 @@ def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
|
|
61
63
|
# Filter out arguments that were eliminated by JAX during compilation
|
62
64
|
runtime_args = tuple(args[i] for i in keep_indices)
|
63
65
|
|
64
|
-
|
65
|
-
for arg in runtime_args:
|
66
|
-
if
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
66
|
+
tensor_args: list[TensorValue] = []
|
67
|
+
for idx, arg in enumerate(runtime_args):
|
68
|
+
if not isinstance(arg, TensorValue):
|
69
|
+
raise TypeError(
|
70
|
+
f"StableHLO kernel expects TensorValue inputs, got {type(arg).__name__} at position {idx}"
|
71
|
+
)
|
72
|
+
tensor_args.append(arg)
|
73
|
+
|
74
|
+
jax_args = [
|
75
|
+
jax.device_put(jnp.asarray(tensor.to_numpy())) for tensor in tensor_args
|
76
|
+
]
|
71
77
|
|
72
78
|
try:
|
73
79
|
result = compiled.execute_sharded(jax_args)
|
@@ -75,9 +81,9 @@ def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
|
|
75
81
|
flat: list[Any] = []
|
76
82
|
for lst in arrays:
|
77
83
|
if isinstance(lst, list) and len(lst) == 1:
|
78
|
-
flat.append(
|
84
|
+
flat.append(TensorValue(np.asarray(lst[0])))
|
79
85
|
else:
|
80
|
-
flat.extend(
|
86
|
+
flat.extend(TensorValue(np.asarray(a)) for a in lst)
|
81
87
|
return tuple(flat)
|
82
88
|
except Exception as e: # pragma: no cover
|
83
89
|
raise RuntimeError(f"StableHLO execute failed: {e}") from e
|