mplang-nightly 0.1.dev167__py3-none-any.whl → 0.1.dev169__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/core/expr/evaluator.py +3 -2
- mplang/device.py +6 -8
- mplang/ops/__init__.py +2 -13
- mplang/ops/ibis_cc.py +3 -3
- mplang/ops/jax_cc.py +7 -7
- mplang/ops/{sql.py → sql_cc.py} +5 -4
- mplang/simp/__init__.py +2 -2
- {mplang_nightly-0.1.dev167.dist-info → mplang_nightly-0.1.dev169.dist-info}/METADATA +1 -1
- {mplang_nightly-0.1.dev167.dist-info → mplang_nightly-0.1.dev169.dist-info}/RECORD +12 -12
- {mplang_nightly-0.1.dev167.dist-info → mplang_nightly-0.1.dev169.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev167.dist-info → mplang_nightly-0.1.dev169.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev167.dist-info → mplang_nightly-0.1.dev169.dist-info}/licenses/LICENSE +0 -0
mplang/core/expr/evaluator.py
CHANGED
@@ -27,8 +27,6 @@ from __future__ import annotations
|
|
27
27
|
from dataclasses import dataclass
|
28
28
|
from typing import Any, Protocol
|
29
29
|
|
30
|
-
import numpy as np
|
31
|
-
|
32
30
|
from mplang.core.comm import ICommunicator
|
33
31
|
from mplang.core.expr.ast import (
|
34
32
|
AccessExpr,
|
@@ -234,11 +232,14 @@ class EvalSemantic:
|
|
234
232
|
"""Convert Value payloads to numpy/python equivalents when possible."""
|
235
233
|
if value is None:
|
236
234
|
return None
|
235
|
+
|
237
236
|
if isinstance(value, Value):
|
238
237
|
# Try to_numpy first for broader compatibility
|
239
238
|
to_numpy = getattr(value, "to_numpy", None)
|
240
239
|
if callable(to_numpy):
|
241
240
|
arr = to_numpy()
|
241
|
+
import numpy as np
|
242
|
+
|
242
243
|
if isinstance(arr, np.ndarray):
|
243
244
|
if arr.size == 1:
|
244
245
|
return arr.item()
|
mplang/device.py
CHANGED
@@ -37,8 +37,8 @@ from mplang.core.context_mgr import cur_ctx
|
|
37
37
|
from mplang.core.tensor import TensorType
|
38
38
|
from mplang.ops import builtin, crypto, ibis_cc, jax_cc, tee
|
39
39
|
from mplang.ops.base import FeOperation
|
40
|
-
from mplang.ops.ibis_cc import
|
41
|
-
from mplang.ops.jax_cc import
|
40
|
+
from mplang.ops.ibis_cc import IbisRunner
|
41
|
+
from mplang.ops.jax_cc import JaxRunner
|
42
42
|
from mplang.simp import mpi, smpc
|
43
43
|
|
44
44
|
# Automatic transfer between devices when parameter is not on the target device.
|
@@ -82,7 +82,7 @@ _is_mpobj = lambda x: isinstance(x, MPObject)
|
|
82
82
|
def _device_run_spu(
|
83
83
|
dev_info: Device, op: FeOperation, *args: Any, **kwargs: Any
|
84
84
|
) -> Any:
|
85
|
-
if not isinstance(op,
|
85
|
+
if not isinstance(op, JaxRunner):
|
86
86
|
raise ValueError("SPU device only supports JAX frontend.")
|
87
87
|
fn, *aargs = args
|
88
88
|
var = smpc.srun(fn)(*aargs, **kwargs)
|
@@ -92,7 +92,7 @@ def _device_run_spu(
|
|
92
92
|
def _device_run_tee(
|
93
93
|
dev_info: Device, op: FeOperation, *args: Any, **kwargs: Any
|
94
94
|
) -> Any:
|
95
|
-
if not isinstance(op,
|
95
|
+
if not isinstance(op, JaxRunner) and not isinstance(op, IbisRunner):
|
96
96
|
raise ValueError("TEE device only supports JAX and Ibis frontend.")
|
97
97
|
assert len(dev_info.members) == 1
|
98
98
|
rank = dev_info.members[0].rank
|
@@ -159,11 +159,9 @@ def device(dev_id: str, *, fe_type: str = "jax") -> Callable:
|
|
159
159
|
return _device_run(dev_id, fn, *args, **kwargs)
|
160
160
|
else:
|
161
161
|
if fe_type == "jax":
|
162
|
-
return _device_run(dev_id, jax_cc.
|
162
|
+
return _device_run(dev_id, jax_cc.run_jax, fn, *args, **kwargs)
|
163
163
|
elif fe_type == "ibis":
|
164
|
-
return _device_run(
|
165
|
-
dev_id, ibis_cc.ibis_compile, fn, *args, **kwargs
|
166
|
-
)
|
164
|
+
return _device_run(dev_id, ibis_cc.run_ibis, fn, *args, **kwargs)
|
167
165
|
else:
|
168
166
|
raise ValueError(f"Unsupported frontend type: {fe_type}")
|
169
167
|
|
mplang/ops/__init__.py
CHANGED
@@ -19,28 +19,17 @@ This module contains compilers that transform high-level functions into
|
|
19
19
|
portable, serializable intermediate representations.
|
20
20
|
"""
|
21
21
|
|
22
|
-
from mplang.ops import builtin
|
23
|
-
from mplang.ops import crypto as crypto
|
24
|
-
from mplang.ops import ibis_cc as ibis_cc
|
25
|
-
from mplang.ops import jax_cc as jax_cc
|
26
|
-
from mplang.ops import phe as phe
|
27
|
-
from mplang.ops import spu as spu
|
28
|
-
from mplang.ops import tee as tee
|
22
|
+
from mplang.ops import builtin, crypto, ibis_cc, jax_cc, phe, spu, sql_cc, tee
|
29
23
|
from mplang.ops.base import FeOperation as FeOperation
|
30
|
-
from mplang.ops.ibis_cc import ibis_compile as ibis_compile
|
31
|
-
from mplang.ops.jax_cc import jax_compile as jax_compile
|
32
|
-
from mplang.ops.sql import sql_run as sql_run
|
33
24
|
|
34
25
|
__all__ = [
|
35
26
|
"FeOperation",
|
36
27
|
"builtin",
|
37
28
|
"crypto",
|
38
29
|
"ibis_cc",
|
39
|
-
"ibis_compile",
|
40
30
|
"jax_cc",
|
41
|
-
"jax_compile",
|
42
31
|
"phe",
|
43
32
|
"spu",
|
44
|
-
"
|
33
|
+
"sql_cc",
|
45
34
|
"tee",
|
46
35
|
]
|
mplang/ops/ibis_cc.py
CHANGED
@@ -95,8 +95,8 @@ def is_ibis_function(func: Callable) -> bool:
|
|
95
95
|
_IBIS_MOD = stateless_mod("ibis")
|
96
96
|
|
97
97
|
|
98
|
-
class
|
99
|
-
"""Ibis
|
98
|
+
class IbisRunner(FeOperation):
|
99
|
+
"""Ibis runner frontend operation."""
|
100
100
|
|
101
101
|
def trace(
|
102
102
|
self, func: Callable, *args: Any, **kwargs: Any
|
@@ -136,4 +136,4 @@ class IbisCompiler(FeOperation):
|
|
136
136
|
return pfunc, in_vars, treedef
|
137
137
|
|
138
138
|
|
139
|
-
|
139
|
+
run_ibis = IbisRunner(_IBIS_MOD, "run")
|
mplang/ops/jax_cc.py
CHANGED
@@ -149,11 +149,11 @@ def jax2stablehlo(
|
|
149
149
|
return pfn, in_vars, out_tree
|
150
150
|
|
151
151
|
|
152
|
-
class
|
153
|
-
"""JAX
|
152
|
+
class JaxRunner(FeOperation):
|
153
|
+
"""JAX function runner frontend operation."""
|
154
154
|
|
155
155
|
def trace(
|
156
|
-
self,
|
156
|
+
self, jax_fn: Callable, *args: Any, **kwargs: Any
|
157
157
|
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
|
158
158
|
"""
|
159
159
|
JAX compilation helper function.
|
@@ -162,21 +162,21 @@ class JaxCompiler(FeOperation):
|
|
162
162
|
along with variable arguments for evaluation.
|
163
163
|
|
164
164
|
Args:
|
165
|
-
|
165
|
+
jax_fn: The JAX function to compile
|
166
166
|
*args: Positional arguments to the function
|
167
167
|
**kwargs: Keyword arguments to the function
|
168
168
|
|
169
169
|
Returns:
|
170
|
-
tuple[PFunction, list[MPObject],
|
170
|
+
tuple[PFunction, list[MPObject], PyTreeDef]: The compiled PFunction, input variables, and output tree
|
171
171
|
"""
|
172
172
|
|
173
173
|
def is_variable(arg: Any) -> bool:
|
174
174
|
return isinstance(arg, MPObject)
|
175
175
|
|
176
|
-
pfunc, in_vars, out_tree = jax2stablehlo(is_variable,
|
176
|
+
pfunc, in_vars, out_tree = jax2stablehlo(is_variable, jax_fn, *args, **kwargs)
|
177
177
|
return pfunc, in_vars, out_tree
|
178
178
|
|
179
179
|
|
180
180
|
_JAX_MOD = stateless_mod("jax")
|
181
181
|
|
182
|
-
|
182
|
+
run_jax = JaxRunner(_JAX_MOD, "run")
|
mplang/ops/{sql.py → sql_cc.py}
RENAMED
@@ -22,15 +22,16 @@ from mplang.ops.base import FeOperation, stateless_mod
|
|
22
22
|
_SQL_MOD = stateless_mod("sql")
|
23
23
|
|
24
24
|
|
25
|
-
class
|
25
|
+
class SqlRunner(FeOperation):
|
26
26
|
def __init__(self, dialect: str = "duckdb"):
|
27
27
|
# Bind to sql module with a stable op name for registry/dispatch
|
28
28
|
super().__init__(_SQL_MOD, "run")
|
29
29
|
self._dialect = dialect
|
30
30
|
|
31
|
+
# TODO(jint): we should deduce out_type according to query and in_tables' schema
|
31
32
|
def trace(
|
32
33
|
self,
|
33
|
-
|
34
|
+
query: str,
|
34
35
|
out_type: TableType,
|
35
36
|
in_tables: dict[str, MPObject] | None = None,
|
36
37
|
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
|
@@ -48,7 +49,7 @@ class SqlFE(FeOperation):
|
|
48
49
|
pfn = PFunction(
|
49
50
|
fn_type="sql.run",
|
50
51
|
fn_name="",
|
51
|
-
fn_text=
|
52
|
+
fn_text=query,
|
52
53
|
ins_info=tuple(ins_info),
|
53
54
|
outs_info=(out_type,),
|
54
55
|
in_names=tuple(in_names),
|
@@ -58,4 +59,4 @@ class SqlFE(FeOperation):
|
|
58
59
|
return pfn, in_vars, treedef
|
59
60
|
|
60
61
|
|
61
|
-
|
62
|
+
run_sql = SqlRunner("duckdb")
|
mplang/simp/__init__.py
CHANGED
@@ -139,10 +139,10 @@ def run_impl(
|
|
139
139
|
pfunc, eval_args, out_tree = func(*args, **kwargs)
|
140
140
|
else:
|
141
141
|
if ibis_cc.is_ibis_function(func):
|
142
|
-
pfunc, eval_args, out_tree = ibis_cc.
|
142
|
+
pfunc, eval_args, out_tree = ibis_cc.run_ibis(func, *args, **kwargs)
|
143
143
|
else:
|
144
144
|
# unknown python callable, treat it as jax function
|
145
|
-
pfunc, eval_args, out_tree = jax_cc.
|
145
|
+
pfunc, eval_args, out_tree = jax_cc.run_jax(func, *args, **kwargs)
|
146
146
|
results = peval(pfunc, eval_args, pmask)
|
147
147
|
return out_tree.unflatten(results)
|
148
148
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
mplang/__init__.py,sha256=ofO-F-CNoVIxpMpTJtTJoQtKegJcHwcOJLzoVispiyc,1852
|
2
2
|
mplang/api.py,sha256=ssmv0_CyZPFORhOUJ84Jo6NwRJSK7_Ono3n7ZjEg4sA,3058
|
3
|
-
mplang/device.py,sha256=
|
3
|
+
mplang/device.py,sha256=7X2rPp3hnHFVAYrvBLqdB5XEbURd3VGzXWbw-6bJ1sU,12488
|
4
4
|
mplang/analysis/__init__.py,sha256=CTHFvRsi-nFngojqjn08UaR3RY9i7CJ7T2UdR95kCrk,1056
|
5
5
|
mplang/analysis/diagram.py,sha256=ffwgD12gL1_KH1uJ_EYkjmIlDrfxYJJkWj-wHl09_Xk,19520
|
6
6
|
mplang/core/__init__.py,sha256=lWxlEKfRwX7FNDzgyKZ1fiDMaCiqkyg0j5mKlZD_v7g,2244
|
@@ -20,7 +20,7 @@ mplang/core/tensor.py,sha256=86u6DogSZMoL0w5XjtTmQm2PhA_VjwybN1b6U4Zzphg,2361
|
|
20
20
|
mplang/core/tracer.py,sha256=dVMfUeCMmPz4o6tLXewGMW1Kpy5gpZORvr9w4MhwDtM,14288
|
21
21
|
mplang/core/expr/__init__.py,sha256=qwiSTUOcanFJLyK8HZ13_L1ZDrybqpPXIlTHAyeumE8,1988
|
22
22
|
mplang/core/expr/ast.py,sha256=KE46KTtlH9RA2V_EzWVKCKolsycgTmt7SotUrOc8Qxs,20923
|
23
|
-
mplang/core/expr/evaluator.py,sha256=
|
23
|
+
mplang/core/expr/evaluator.py,sha256=EFy71vYUL2xLHCtMkWlYJpyGyujDdVSAx8ByET-62qQ,23297
|
24
24
|
mplang/core/expr/printer.py,sha256=VblKGnO0OUfzH7EBkszwRNxQUB8QyyC7BlJWJEUv9so,9546
|
25
25
|
mplang/core/expr/transformer.py,sha256=TyL-8FjrVvDq_C9X7kAuKkiqt2XdZM-okjzVQj0A33s,4893
|
26
26
|
mplang/core/expr/utils.py,sha256=VDTJ_-CsdHtVy9wDaGa7XdFxQ7o5lYYaeqcgsAhkbNI,2625
|
@@ -37,15 +37,15 @@ mplang/kernels/spu.py,sha256=vh-WG-uDVPvK11CDzc-f58sUalGts_eWJPjeuxLANfY,12508
|
|
37
37
|
mplang/kernels/sql_duckdb.py,sha256=vq4UCth_PCsH8dxcpx7mMnsq54KtjD8kxwISH7tj3qg,1631
|
38
38
|
mplang/kernels/stablehlo.py,sha256=gQyg-2ANyA1TjRg90MZ79mn4cHoXhU7g5GFUCuYXyKs,3231
|
39
39
|
mplang/kernels/value.py,sha256=2NZ0UvaLyRYb3bTCqL_fQXMzbHk1qbQ3j9xiv4v5E0A,20726
|
40
|
-
mplang/ops/__init__.py,sha256=
|
40
|
+
mplang/ops/__init__.py,sha256=ryPlNziAfc4rwVazdJ6VG7kuUD0bqlRZ2smY5HNIRGI,1018
|
41
41
|
mplang/ops/base.py,sha256=h67_SHWNZGUuTCuMll-9kDgGvlPhlFov7WAQCHTmUvw,18258
|
42
42
|
mplang/ops/builtin.py,sha256=D7T8rRF9g05VIw9T72lsncF5cDQqaT37eapBieRKvRI,9363
|
43
43
|
mplang/ops/crypto.py,sha256=9CeFJrYmvjmgx-3WQl6jHXh8VafRpT4QBunbzsPF8Uc,3646
|
44
|
-
mplang/ops/ibis_cc.py,sha256=
|
45
|
-
mplang/ops/jax_cc.py,sha256=
|
44
|
+
mplang/ops/ibis_cc.py,sha256=a5OqZVRZ1NzugQPYigdlJcGKbMZHqKh1xkiJen-LtCU,4242
|
45
|
+
mplang/ops/jax_cc.py,sha256=kVhJM8i8oPd-yPqyeaZ1hfVxcZPzNhTwjhltDh50hyY,7809
|
46
46
|
mplang/ops/phe.py,sha256=SatswExjZWPed8y3qA33BCwIWbvsgHCuCAz_pv2RLLw,6790
|
47
47
|
mplang/ops/spu.py,sha256=UHr5DSoqG08xDYER_11OsMVjGGNXXxsvkFoVvXU8uik,4989
|
48
|
-
mplang/ops/
|
48
|
+
mplang/ops/sql_cc.py,sha256=-9uf75gOxLQlFiKjDm75qIx8Gbun7unvkOxezdSLGlE,2112
|
49
49
|
mplang/ops/tee.py,sha256=bOpS_BXG12D6bONikzdF2yt0oVZj9Jyd0g_3IXP8VgE,1281
|
50
50
|
mplang/protos/v1alpha1/mpir_pb2.py,sha256=Bros37t-4LMJbuUYVSM65rImUYTtZDhNTIADGbZCKp0,7522
|
51
51
|
mplang/protos/v1alpha1/mpir_pb2.pyi,sha256=dLxAtFW7mgFR-HYxC4ExI4jbtEWUGTKBvcKhI3BJ8m0,20972
|
@@ -64,7 +64,7 @@ mplang/runtime/link_comm.py,sha256=ZHNcis8QDu2rcyyF3rhpxMiJDkczoMA_c0iZ2GDW_bA,2
|
|
64
64
|
mplang/runtime/server.py,sha256=CdmBmpbylEl7XeZj26i0rUmTrPTvl2CVdRgbtR02gcg,16543
|
65
65
|
mplang/runtime/session.py,sha256=I2711V-pPRCYibNgBhjboUURdubnL6ltCoh5RvFVabs,10641
|
66
66
|
mplang/runtime/simulation.py,sha256=1I_8dIqxivxtYnQK0ofz0oXk3lXXh3-zN0lmNFnWucA,11615
|
67
|
-
mplang/simp/__init__.py,sha256=
|
67
|
+
mplang/simp/__init__.py,sha256=U-JUvaO3EEg0cvEtU1Zihh674vJKQBCDy19gI9t1f-0,11581
|
68
68
|
mplang/simp/mpi.py,sha256=Wv_Q16TQ3rdLam6OzqXiefIGSMmagGkso09ycyOkHEs,4774
|
69
69
|
mplang/simp/random.py,sha256=7PVgWNL1j7Sf3MqT5PRiWplUu-0dyhF3Ub566iqX86M,3898
|
70
70
|
mplang/simp/smpc.py,sha256=tdH54aU4T-GIDPhpmf9NCeJC0G67PdOYc04cyUkOnwE,7119
|
@@ -73,8 +73,8 @@ mplang/utils/crypto.py,sha256=rvPomBFtznRHc3RPi6Aip9lsU8zW2oxBqGv1K3vn7Rs,1052
|
|
73
73
|
mplang/utils/func_utils.py,sha256=vCJcZmu0bEbqhOQKdpttV2_MBllIcPSN0b8U4WjNGGo,5164
|
74
74
|
mplang/utils/spu_utils.py,sha256=S3L9RBkBe2AvSuMSQQ12cBY5Y1NPthubvErSX_7nj1A,4158
|
75
75
|
mplang/utils/table_utils.py,sha256=aC-IZOKkSmFkpr3NZchLM0Wt0GOn-rg_xHBHREWBwAU,2202
|
76
|
-
mplang_nightly-0.1.
|
77
|
-
mplang_nightly-0.1.
|
78
|
-
mplang_nightly-0.1.
|
79
|
-
mplang_nightly-0.1.
|
80
|
-
mplang_nightly-0.1.
|
76
|
+
mplang_nightly-0.1.dev169.dist-info/METADATA,sha256=3Ml9Mvi3n9iBnvcVp7dp7lFJABk1hUoIP5AB5BbeQFE,16547
|
77
|
+
mplang_nightly-0.1.dev169.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
78
|
+
mplang_nightly-0.1.dev169.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
|
79
|
+
mplang_nightly-0.1.dev169.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
80
|
+
mplang_nightly-0.1.dev169.dist-info/RECORD,,
|
File without changes
|
{mplang_nightly-0.1.dev167.dist-info → mplang_nightly-0.1.dev169.dist-info}/entry_points.txt
RENAMED
File without changes
|
{mplang_nightly-0.1.dev167.dist-info → mplang_nightly-0.1.dev169.dist-info}/licenses/LICENSE
RENAMED
File without changes
|