mplang-nightly 0.1.dev170__py3-none-any.whl → 0.1.dev172__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +1 -1
- mplang/core/expr/printer.py +2 -6
- mplang/core/pfunc.py +2 -2
- mplang/core/primitive.py +12 -12
- mplang/device.py +6 -6
- mplang/kernels/{builtin.py → basic.py} +24 -24
- mplang/kernels/context.py +13 -13
- mplang/ops/__init__.py +2 -2
- mplang/ops/base.py +1 -1
- mplang/ops/{builtin.py → basic.py} +14 -14
- mplang/simp/__init__.py +3 -3
- {mplang_nightly-0.1.dev170.dist-info → mplang_nightly-0.1.dev172.dist-info}/METADATA +1 -1
- {mplang_nightly-0.1.dev170.dist-info → mplang_nightly-0.1.dev172.dist-info}/RECORD +17 -17
- /mplang/{api.py → host.py} +0 -0
- {mplang_nightly-0.1.dev170.dist-info → mplang_nightly-0.1.dev172.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev170.dist-info → mplang_nightly-0.1.dev172.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev170.dist-info → mplang_nightly-0.1.dev172.dist-info}/licenses/LICENSE +0 -0
mplang/__init__.py
CHANGED
@@ -24,7 +24,6 @@ except PackageNotFoundError:
|
|
24
24
|
__version__ = "0.0.0-dev"
|
25
25
|
|
26
26
|
from mplang import analysis
|
27
|
-
from mplang.api import CompileOptions, compile, evaluate, fetch
|
28
27
|
from mplang.core import (
|
29
28
|
DType,
|
30
29
|
InterpContext,
|
@@ -38,6 +37,7 @@ from mplang.core import (
|
|
38
37
|
)
|
39
38
|
from mplang.core.cluster import ClusterSpec, Device, Node, RuntimeInfo
|
40
39
|
from mplang.core.context_mgr import cur_ctx, set_ctx, with_ctx
|
40
|
+
from mplang.host import CompileOptions, compile, evaluate, fetch
|
41
41
|
from mplang.runtime.driver import Driver
|
42
42
|
from mplang.runtime.simulation import Simulator
|
43
43
|
|
mplang/core/expr/printer.py
CHANGED
@@ -164,13 +164,9 @@ class Printer(ExprVisitor):
|
|
164
164
|
arg_names = [self._var_name(arg) for arg in expr.args]
|
165
165
|
fn_type = expr.pfunc.fn_type
|
166
166
|
|
167
|
-
# for well known
|
168
|
-
if fn_type == "
|
167
|
+
# for well known basic functions
|
168
|
+
if fn_type == "basic.constant":
|
169
169
|
return self._print_const(expr.pfunc, expr.mptypes)
|
170
|
-
elif fn_type == "builtin.rank":
|
171
|
-
return self._do_print("prank", [], mptypes=expr.mptypes)
|
172
|
-
elif fn_type == "builtin.prand":
|
173
|
-
return self._do_print("prand", [], mptypes=expr.mptypes)
|
174
170
|
|
175
171
|
attrs = {"fn_type": fn_type}
|
176
172
|
if expr.pfunc.fn_name:
|
mplang/core/pfunc.py
CHANGED
@@ -33,7 +33,7 @@ class PFunction:
|
|
33
33
|
|
34
34
|
PFunction serves as a unified interface for describing single-party computations
|
35
35
|
in multi-party computing scenarios. It can represent both:
|
36
|
-
1. Built-in operations (e.g., "spu.makeshares", "
|
36
|
+
1. Built-in operations (e.g., "spu.makeshares", "basic.read")
|
37
37
|
2. User-defined programmable functions with custom code
|
38
38
|
|
39
39
|
The PFunction accepts a list of typed inputs (TensorType/TableType). For
|
@@ -47,7 +47,7 @@ class PFunction:
|
|
47
47
|
|
48
48
|
Args:
|
49
49
|
fn_type: The type/category identifier of this PFunction, indicating which
|
50
|
-
backend or handler should process it (e.g., "spu.makeshares", "
|
50
|
+
backend or handler should process it (e.g., "spu.makeshares", "basic.read",
|
51
51
|
"mlir.stablehlo"). This serves as a routing mechanism for execution.
|
52
52
|
ins_info: Type information for input parameters (TensorType or TableType)
|
53
53
|
outs_info: Type information for output values (TensorType or TableType)
|
mplang/core/primitive.py
CHANGED
@@ -48,7 +48,7 @@ from mplang.core.pfunc import PFunction
|
|
48
48
|
from mplang.core.table import TableLike
|
49
49
|
from mplang.core.tensor import ScalarType, Shape, TensorLike
|
50
50
|
from mplang.core.tracer import TraceContext, TraceVar, trace
|
51
|
-
from mplang.ops import
|
51
|
+
from mplang.ops import basic
|
52
52
|
from mplang.utils.func_utils import var_demorph, var_morph
|
53
53
|
|
54
54
|
|
@@ -133,10 +133,10 @@ def trace_before_apply(fn: Callable[P, R], make_call: bool) -> Callable[P, R]:
|
|
133
133
|
return wrapped
|
134
134
|
|
135
135
|
|
136
|
-
def
|
136
|
+
def bltin_function(fn: Callable[P, R]) -> Callable[P, R]:
|
137
137
|
"""Decorator to trace a Python function as an opaque primitive call (`CallExpr`).
|
138
138
|
|
139
|
-
When a function decorated with `@
|
139
|
+
When a function decorated with `@bltin_function` is called within a `TraceContext`, it is
|
140
140
|
not inlined. Instead, it is traced separately in a forked context, and a `CallExpr`
|
141
141
|
node is inserted into the main graph. This is useful for encapsulating complex
|
142
142
|
operations or third-party library calls as single, opaque nodes.
|
@@ -156,7 +156,7 @@ def primitive(fn: Callable[P, R]) -> Callable[P, R]:
|
|
156
156
|
|
157
157
|
Example:
|
158
158
|
```python
|
159
|
-
@
|
159
|
+
@bltin_function
|
160
160
|
def my_op(x: MPObject) -> MPObject:
|
161
161
|
# Complex logic traced as a single CallExpr node
|
162
162
|
return x + 1
|
@@ -222,7 +222,7 @@ def pmask() -> Mask:
|
|
222
222
|
return _tracer().mask
|
223
223
|
|
224
224
|
|
225
|
-
@
|
225
|
+
@bltin_function
|
226
226
|
def prank() -> MPObject:
|
227
227
|
"""Multi-party get the rank (party identifier) of each party.
|
228
228
|
|
@@ -243,12 +243,12 @@ def prank() -> MPObject:
|
|
243
243
|
Note:
|
244
244
|
Each party in the current party mask independently produces its own rank value.
|
245
245
|
"""
|
246
|
-
pfunc, eval_args, out_tree =
|
246
|
+
pfunc, eval_args, out_tree = basic.rank()
|
247
247
|
results = peval(pfunc, eval_args)
|
248
248
|
return out_tree.unflatten(results) # type: ignore[no-any-return]
|
249
249
|
|
250
250
|
|
251
|
-
@
|
251
|
+
@bltin_function
|
252
252
|
def prand(shape: Shape = ()) -> MPObject:
|
253
253
|
"""Multi-party generate a private random (uint64) tensor with the given shape.
|
254
254
|
|
@@ -271,7 +271,7 @@ def prand(shape: Shape = ()) -> MPObject:
|
|
271
271
|
private random values. The randomness is local to each party and is
|
272
272
|
not shared or revealed to other parties.
|
273
273
|
"""
|
274
|
-
pfunc, eval_args, out_tree =
|
274
|
+
pfunc, eval_args, out_tree = basic.prand(shape)
|
275
275
|
results = peval(pfunc, eval_args)
|
276
276
|
return out_tree.unflatten(results) # type: ignore[no-any-return]
|
277
277
|
|
@@ -305,19 +305,19 @@ def constant(data: TensorLike | ScalarType | TableLike) -> MPObject:
|
|
305
305
|
Note that the constant primitive is not designed to carry large tables efficiently -
|
306
306
|
consider using dedicated table loading mechanisms for substantial datasets.
|
307
307
|
"""
|
308
|
-
pfunc, eval_args, out_tree =
|
308
|
+
pfunc, eval_args, out_tree = basic.constant(data)
|
309
309
|
results = peval(pfunc, eval_args)
|
310
310
|
return out_tree.unflatten(results) # type: ignore[no-any-return]
|
311
311
|
|
312
312
|
|
313
|
-
@
|
313
|
+
@bltin_function
|
314
314
|
def debug_print(obj: MPObject, prefix: str = "") -> MPObject:
|
315
315
|
"""Print local value of obj on owning parties and pass it through.
|
316
316
|
|
317
317
|
Returns the same MPObject value to keep it alive against DCE and to
|
318
318
|
support usage like: x = debug_print(x, prefix="x=").
|
319
319
|
"""
|
320
|
-
pfunc, eval_args, out_tree =
|
320
|
+
pfunc, eval_args, out_tree = basic.debug_print(obj, prefix=prefix)
|
321
321
|
results = peval(pfunc, eval_args)
|
322
322
|
return out_tree.unflatten(results) # type: ignore[no-any-return]
|
323
323
|
|
@@ -445,7 +445,7 @@ def set_mask(arg: MPObject, mask: Mask) -> MPObject:
|
|
445
445
|
The underlying implementation uses JAX identity function with the
|
446
446
|
specified execution mask.
|
447
447
|
"""
|
448
|
-
pfunc, eval_args, out_tree =
|
448
|
+
pfunc, eval_args, out_tree = basic.identity(arg)
|
449
449
|
results = peval(pfunc, eval_args, mask)
|
450
450
|
return out_tree.unflatten(results) # type: ignore[no-any-return]
|
451
451
|
|
mplang/device.py
CHANGED
@@ -29,13 +29,13 @@ from typing import Any
|
|
29
29
|
|
30
30
|
from jax.tree_util import tree_map
|
31
31
|
|
32
|
-
import mplang.
|
32
|
+
import mplang.host as mapi
|
33
33
|
from mplang import simp
|
34
34
|
from mplang.core import InterpContext, MPObject, primitive
|
35
35
|
from mplang.core.cluster import ClusterSpec, Device
|
36
36
|
from mplang.core.context_mgr import cur_ctx
|
37
37
|
from mplang.core.tensor import TensorType
|
38
|
-
from mplang.ops import
|
38
|
+
from mplang.ops import basic, crypto, ibis_cc, jax_cc, tee
|
39
39
|
from mplang.ops.base import FeOperation
|
40
40
|
from mplang.ops.ibis_cc import IbisRunner
|
41
41
|
from mplang.ops.jax_cc import JaxRunner
|
@@ -209,11 +209,11 @@ def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
|
|
209
209
|
sess_p, sess_t = _ensure_tee_session(frm_dev_id, to_dev_id, frm_rank, tee_rank)
|
210
210
|
# Bytes-only path: pack -> enc -> p2p -> dec -> unpack (with static out type)
|
211
211
|
obj_ty = TensorType.from_obj(obj)
|
212
|
-
b = simp.runAt(frm_rank,
|
212
|
+
b = simp.runAt(frm_rank, basic.pack)(obj)
|
213
213
|
ct = simp.runAt(frm_rank, crypto.enc)(b, sess_p)
|
214
214
|
ct_at_tee = mpi.p2p(frm_rank, tee_rank, ct)
|
215
215
|
b_at_tee = simp.runAt(tee_rank, crypto.dec)(ct_at_tee, sess_t)
|
216
|
-
pt_at_tee = simp.runAt(tee_rank,
|
216
|
+
pt_at_tee = simp.runAt(tee_rank, basic.unpack)(b_at_tee, out_ty=obj_ty)
|
217
217
|
return tree_map(partial(_set_devid, dev_id=to_dev_id), pt_at_tee) # type: ignore[no-any-return]
|
218
218
|
elif frm_to_pair == ("TEE", "PPU"):
|
219
219
|
# Transparent encryption from TEE to a specific PPU using the reverse-direction session key
|
@@ -223,11 +223,11 @@ def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
|
|
223
223
|
# Ensure bidirectional session established for this pair
|
224
224
|
sess_p, sess_t = _ensure_tee_session(to_dev_id, frm_dev_id, ppu_rank, tee_rank)
|
225
225
|
obj_ty = TensorType.from_obj(obj)
|
226
|
-
b = simp.runAt(tee_rank,
|
226
|
+
b = simp.runAt(tee_rank, basic.pack)(obj)
|
227
227
|
ct = simp.runAt(tee_rank, crypto.enc)(b, sess_t)
|
228
228
|
ct_at_ppu = mpi.p2p(tee_rank, ppu_rank, ct)
|
229
229
|
b_at_ppu = simp.runAt(ppu_rank, crypto.dec)(ct_at_ppu, sess_p)
|
230
|
-
pt_at_ppu = simp.runAt(ppu_rank,
|
230
|
+
pt_at_ppu = simp.runAt(ppu_rank, basic.unpack)(b_at_ppu, out_ty=obj_ty)
|
231
231
|
return tree_map(partial(_set_devid, dev_id=to_dev_id), pt_at_ppu) # type: ignore[no-any-return]
|
232
232
|
else:
|
233
233
|
supported = [
|
@@ -25,17 +25,17 @@ from mplang.runtime.data_providers import get_provider, resolve_uri
|
|
25
25
|
from mplang.utils import table_utils
|
26
26
|
|
27
27
|
|
28
|
-
@kernel_def("
|
28
|
+
@kernel_def("basic.identity")
|
29
29
|
def _identity(pfunc: PFunction, value: Value) -> Value:
|
30
30
|
# Runtime guarantees exactly one argument; no extra arity checks here.
|
31
31
|
return value
|
32
32
|
|
33
33
|
|
34
|
-
@kernel_def("
|
34
|
+
@kernel_def("basic.read")
|
35
35
|
def _read(pfunc: PFunction) -> Value:
|
36
36
|
path = pfunc.attrs.get("path")
|
37
37
|
if path is None:
|
38
|
-
raise ValueError("missing path attr for
|
38
|
+
raise ValueError("missing path attr for basic.read")
|
39
39
|
out_t = pfunc.outs_info[0]
|
40
40
|
uri = resolve_uri(str(path))
|
41
41
|
prov = get_provider(uri.scheme)
|
@@ -45,7 +45,7 @@ def _read(pfunc: PFunction) -> Value:
|
|
45
45
|
try:
|
46
46
|
data = prov.read(uri, out_t, ctx=ctx)
|
47
47
|
except Exception as e: # pragma: no cover - provider errors
|
48
|
-
raise RuntimeError(f"
|
48
|
+
raise RuntimeError(f"basic.read failed: {e}") from e
|
49
49
|
|
50
50
|
if isinstance(out_t, TableType):
|
51
51
|
if isinstance(data, TableValue):
|
@@ -56,15 +56,15 @@ def _read(pfunc: PFunction) -> Value:
|
|
56
56
|
return data
|
57
57
|
return TensorValue(np.asarray(data))
|
58
58
|
raise TypeError(
|
59
|
-
f"
|
59
|
+
f"basic.read only supports TableType/TensorType outputs, got {type(out_t).__name__}"
|
60
60
|
)
|
61
61
|
|
62
62
|
|
63
|
-
@kernel_def("
|
63
|
+
@kernel_def("basic.write")
|
64
64
|
def _write(pfunc: PFunction, obj: Value) -> Value:
|
65
65
|
path = pfunc.attrs.get("path")
|
66
66
|
if path is None:
|
67
|
-
raise ValueError("missing path attr for
|
67
|
+
raise ValueError("missing path attr for basic.write")
|
68
68
|
uri = resolve_uri(str(path))
|
69
69
|
prov = get_provider(uri.scheme)
|
70
70
|
if prov is None:
|
@@ -74,16 +74,16 @@ def _write(pfunc: PFunction, obj: Value) -> Value:
|
|
74
74
|
try:
|
75
75
|
prov.write(uri, obj, ctx=ctx)
|
76
76
|
except Exception as e: # pragma: no cover
|
77
|
-
raise RuntimeError(f"
|
77
|
+
raise RuntimeError(f"basic.write failed: {e}") from e
|
78
78
|
return obj
|
79
79
|
|
80
80
|
|
81
|
-
@kernel_def("
|
81
|
+
@kernel_def("basic.constant")
|
82
82
|
def _constant(pfunc: PFunction) -> Value:
|
83
83
|
"""Return constants as Value types (TensorValue or TableValue)."""
|
84
84
|
data_bytes = pfunc.attrs.get("data_bytes")
|
85
85
|
if data_bytes is None:
|
86
|
-
raise ValueError("missing data_bytes attr for
|
86
|
+
raise ValueError("missing data_bytes attr for basic.constant")
|
87
87
|
out_t = pfunc.outs_info[0]
|
88
88
|
fmt = pfunc.attrs.get("data_format")
|
89
89
|
if isinstance(out_t, TableType):
|
@@ -98,7 +98,7 @@ def _constant(pfunc: PFunction) -> Value:
|
|
98
98
|
return TensorValue(arr)
|
99
99
|
|
100
100
|
|
101
|
-
@kernel_def("
|
101
|
+
@kernel_def("basic.rank")
|
102
102
|
def _rank(pfunc: PFunction) -> TensorValue:
|
103
103
|
"""Return rank as TensorValue."""
|
104
104
|
ctx = cur_kctx()
|
@@ -106,7 +106,7 @@ def _rank(pfunc: PFunction) -> TensorValue:
|
|
106
106
|
return TensorValue(arr)
|
107
107
|
|
108
108
|
|
109
|
-
@kernel_def("
|
109
|
+
@kernel_def("basic.prand")
|
110
110
|
def _prand(pfunc: PFunction) -> TensorValue:
|
111
111
|
"""Return random data as TensorValue."""
|
112
112
|
shape = pfunc.attrs.get("shape", ())
|
@@ -118,7 +118,7 @@ def _prand(pfunc: PFunction) -> TensorValue:
|
|
118
118
|
return TensorValue(data)
|
119
119
|
|
120
120
|
|
121
|
-
@kernel_def("
|
121
|
+
@kernel_def("basic.table_to_tensor")
|
122
122
|
def _table_to_tensor(pfunc: PFunction, table: TableValue) -> TensorValue:
|
123
123
|
"""Convert table to tensor, return as TensorValue."""
|
124
124
|
arrow_table = table.to_arrow()
|
@@ -131,7 +131,7 @@ def _table_to_tensor(pfunc: PFunction, table: TableValue) -> TensorValue:
|
|
131
131
|
return TensorValue(mat)
|
132
132
|
|
133
133
|
|
134
|
-
@kernel_def("
|
134
|
+
@kernel_def("basic.tensor_to_table")
|
135
135
|
def _tensor_to_table(pfunc: PFunction, tensor: TensorValue) -> TableValue:
|
136
136
|
"""Convert tensor to table, return as TableValue."""
|
137
137
|
import pyarrow as pa # type: ignore
|
@@ -168,7 +168,7 @@ def _summ(v: Value) -> str:
|
|
168
168
|
return f"<unprintable {type(v).__name__}: {e}>"
|
169
169
|
|
170
170
|
|
171
|
-
@kernel_def("
|
171
|
+
@kernel_def("basic.debug_print")
|
172
172
|
def _debug_print(pfunc: PFunction, val: Value) -> Value:
|
173
173
|
prefix = pfunc.attrs.get("prefix", "")
|
174
174
|
ctx = cur_kctx()
|
@@ -176,16 +176,16 @@ def _debug_print(pfunc: PFunction, val: Value) -> Value:
|
|
176
176
|
return val
|
177
177
|
|
178
178
|
|
179
|
-
@kernel_def("
|
179
|
+
@kernel_def("basic.pack")
|
180
180
|
def _pack(pfunc: PFunction, value: Value) -> TensorValue:
|
181
181
|
outs_info = pfunc.outs_info
|
182
182
|
if len(outs_info) != 1:
|
183
|
-
raise ValueError("
|
183
|
+
raise ValueError("basic.pack expects single output type")
|
184
184
|
out_ty = outs_info[0]
|
185
185
|
if not isinstance(out_ty, TensorType):
|
186
|
-
raise TypeError("
|
186
|
+
raise TypeError("basic.pack must return TensorType")
|
187
187
|
if out_ty.dtype.numpy_dtype() != np.uint8:
|
188
|
-
raise TypeError("
|
188
|
+
raise TypeError("basic.pack output dtype must be uint8")
|
189
189
|
|
190
190
|
if isinstance(value, TableValue):
|
191
191
|
# Serialize Arrow table using IPC stream for consistency with Value serde
|
@@ -203,14 +203,14 @@ def _pack(pfunc: PFunction, value: Value) -> TensorValue:
|
|
203
203
|
arr = value.to_numpy()
|
204
204
|
return TensorValue(np.frombuffer(arr.tobytes(order="C"), dtype=np.uint8))
|
205
205
|
|
206
|
-
raise TypeError(f"
|
206
|
+
raise TypeError(f"basic.pack does not support Value type {type(value).__name__}")
|
207
207
|
|
208
208
|
|
209
|
-
@kernel_def("
|
209
|
+
@kernel_def("basic.unpack")
|
210
210
|
def _unpack(pfunc: PFunction, packed: TensorValue) -> Value:
|
211
211
|
outs_info = pfunc.outs_info
|
212
212
|
if len(outs_info) != 1:
|
213
|
-
raise ValueError("
|
213
|
+
raise ValueError("basic.unpack expects single output type")
|
214
214
|
out_ty = outs_info[0]
|
215
215
|
|
216
216
|
b = packed.to_numpy().astype(np.uint8, copy=False).reshape(-1)
|
@@ -219,7 +219,7 @@ def _unpack(pfunc: PFunction, packed: TensorValue) -> Value:
|
|
219
219
|
np_dtype = out_ty.dtype.numpy_dtype()
|
220
220
|
shape = tuple(out_ty.shape)
|
221
221
|
if any(dim < 0 for dim in shape):
|
222
|
-
raise ValueError("
|
222
|
+
raise ValueError("basic.unpack does not support dynamic tensor shapes")
|
223
223
|
elem_count = int(np.prod(shape))
|
224
224
|
expected = elem_count * np.dtype(np_dtype).itemsize
|
225
225
|
if b.size != expected:
|
@@ -239,4 +239,4 @@ def _unpack(pfunc: PFunction, packed: TensorValue) -> Value:
|
|
239
239
|
table = reader.read_all()
|
240
240
|
return TableValue(table)
|
241
241
|
|
242
|
-
raise TypeError("
|
242
|
+
raise TypeError("basic.unpack output type must be TensorType or TableType")
|
mplang/kernels/context.py
CHANGED
@@ -35,7 +35,7 @@ def _ensure_impl_imported() -> None:
|
|
35
35
|
global _IMPL_IMPORTED
|
36
36
|
if _IMPL_IMPORTED:
|
37
37
|
return
|
38
|
-
from mplang.kernels import
|
38
|
+
from mplang.kernels import basic as _impl_basic # noqa: F401
|
39
39
|
from mplang.kernels import crypto as _impl_crypto # noqa: F401
|
40
40
|
from mplang.kernels import mock_tee as _impl_tee # noqa: F401
|
41
41
|
from mplang.kernels import phe as _impl_phe # noqa: F401
|
@@ -49,18 +49,18 @@ def _ensure_impl_imported() -> None:
|
|
49
49
|
# imports consolidated above
|
50
50
|
|
51
51
|
_DEFAULT_BINDINGS: dict[str, str] = {
|
52
|
-
#
|
53
|
-
"
|
54
|
-
"
|
55
|
-
"
|
56
|
-
"
|
57
|
-
"
|
58
|
-
"
|
59
|
-
"
|
60
|
-
"
|
61
|
-
"
|
62
|
-
"
|
63
|
-
"
|
52
|
+
# basic
|
53
|
+
"basic.identity": "basic.identity",
|
54
|
+
"basic.read": "basic.read",
|
55
|
+
"basic.write": "basic.write",
|
56
|
+
"basic.constant": "basic.constant",
|
57
|
+
"basic.rank": "basic.rank",
|
58
|
+
"basic.prand": "basic.prand",
|
59
|
+
"basic.table_to_tensor": "basic.table_to_tensor",
|
60
|
+
"basic.tensor_to_table": "basic.tensor_to_table",
|
61
|
+
"basic.debug_print": "basic.debug_print",
|
62
|
+
"basic.pack": "basic.pack",
|
63
|
+
"basic.unpack": "basic.unpack",
|
64
64
|
# crypto
|
65
65
|
"crypto.keygen": "crypto.keygen",
|
66
66
|
"crypto.enc": "crypto.enc",
|
mplang/ops/__init__.py
CHANGED
@@ -19,12 +19,12 @@ This module contains compilers that transform high-level functions into
|
|
19
19
|
portable, serializable intermediate representations.
|
20
20
|
"""
|
21
21
|
|
22
|
-
from mplang.ops import
|
22
|
+
from mplang.ops import basic, crypto, ibis_cc, jax_cc, phe, spu, sql_cc, tee
|
23
23
|
from mplang.ops.base import FeOperation as FeOperation
|
24
24
|
|
25
25
|
__all__ = [
|
26
26
|
"FeOperation",
|
27
|
-
"
|
27
|
+
"basic",
|
28
28
|
"crypto",
|
29
29
|
"ibis_cc",
|
30
30
|
"jax_cc",
|
mplang/ops/base.py
CHANGED
@@ -129,7 +129,7 @@ class FeModule(ABC):
|
|
129
129
|
- You need compilation/stateful behavior/dynamic routing, multiple PFunctions, or complex capture flows.
|
130
130
|
|
131
131
|
Tips:
|
132
|
-
- Keep routing information in PFunction.fn_type (e.g., "
|
132
|
+
- Keep routing information in PFunction.fn_type (e.g., "basic.read", "sql.run", "mlir.stablehlo").
|
133
133
|
- Avoid backend-specific logic in kernels; only validate and shape types.
|
134
134
|
- Prefer keyword-only attributes in typed_op kernels for clarity (def op(x: MPObject, *, attr: int)).
|
135
135
|
"""
|
@@ -23,10 +23,10 @@ from mplang.core.tensor import ScalarType, Shape, TensorLike, TensorType
|
|
23
23
|
from mplang.ops.base import stateless_mod
|
24
24
|
from mplang.utils import table_utils
|
25
25
|
|
26
|
-
|
26
|
+
_BASIC_MOD = stateless_mod("basic")
|
27
27
|
|
28
28
|
|
29
|
-
@
|
29
|
+
@_BASIC_MOD.simple_op()
|
30
30
|
def identity(x: TensorType) -> TensorType:
|
31
31
|
"""Return the input type unchanged.
|
32
32
|
|
@@ -40,7 +40,7 @@ def identity(x: TensorType) -> TensorType:
|
|
40
40
|
return x
|
41
41
|
|
42
42
|
|
43
|
-
@
|
43
|
+
@_BASIC_MOD.simple_op()
|
44
44
|
def read(*, path: str, ty: TensorType) -> TensorType:
|
45
45
|
"""Declare reading a value of type ``ty`` from ``path`` (type-only).
|
46
46
|
|
@@ -63,7 +63,7 @@ def read(*, path: str, ty: TensorType) -> TensorType:
|
|
63
63
|
return ty
|
64
64
|
|
65
65
|
|
66
|
-
@
|
66
|
+
@_BASIC_MOD.simple_op()
|
67
67
|
def write(x: TensorType, *, path: str) -> TensorType:
|
68
68
|
"""Declare writing the input value to ``path`` and return the same type.
|
69
69
|
|
@@ -77,7 +77,7 @@ def write(x: TensorType, *, path: str) -> TensorType:
|
|
77
77
|
return x
|
78
78
|
|
79
79
|
|
80
|
-
@
|
80
|
+
@_BASIC_MOD.op_def()
|
81
81
|
def constant(
|
82
82
|
data: TensorLike | ScalarType | TableLike,
|
83
83
|
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
|
@@ -89,7 +89,7 @@ def constant(
|
|
89
89
|
|
90
90
|
Returns:
|
91
91
|
Tuple[PFunction, list[MPObject], PyTreeDef]:
|
92
|
-
- PFunction: ``fn_type='
|
92
|
+
- PFunction: ``fn_type='basic.constant'`` with one output whose type
|
93
93
|
matches ``data``; payload serialized via ``data_bytes`` with
|
94
94
|
``data_format`` ('bytes[numpy]' or 'bytes[csv]').
|
95
95
|
- list[MPObject]: Empty (no inputs captured).
|
@@ -120,7 +120,7 @@ def constant(
|
|
120
120
|
data_format = "bytes[numpy]"
|
121
121
|
|
122
122
|
pfunc = PFunction(
|
123
|
-
fn_type="
|
123
|
+
fn_type="basic.constant",
|
124
124
|
ins_info=(),
|
125
125
|
outs_info=(out_type,),
|
126
126
|
data_bytes=data_bytes,
|
@@ -130,7 +130,7 @@ def constant(
|
|
130
130
|
return pfunc, [], treedef
|
131
131
|
|
132
132
|
|
133
|
-
@
|
133
|
+
@_BASIC_MOD.simple_op()
|
134
134
|
def rank() -> TensorType:
|
135
135
|
"""Return the scalar UINT64 tensor type for the current party rank.
|
136
136
|
|
@@ -140,7 +140,7 @@ def rank() -> TensorType:
|
|
140
140
|
return TensorType(UINT64, ())
|
141
141
|
|
142
142
|
|
143
|
-
@
|
143
|
+
@_BASIC_MOD.simple_op()
|
144
144
|
def prand(*, shape: Shape = ()) -> TensorType:
|
145
145
|
"""Declare a private random UINT64 tensor with the given shape.
|
146
146
|
|
@@ -153,7 +153,7 @@ def prand(*, shape: Shape = ()) -> TensorType:
|
|
153
153
|
return TensorType(UINT64, shape)
|
154
154
|
|
155
155
|
|
156
|
-
@
|
156
|
+
@_BASIC_MOD.simple_op()
|
157
157
|
def debug_print(
|
158
158
|
x: TensorType | TableType, *, prefix: str = ""
|
159
159
|
) -> TableType | TensorType:
|
@@ -169,7 +169,7 @@ def debug_print(
|
|
169
169
|
return x
|
170
170
|
|
171
171
|
|
172
|
-
@
|
172
|
+
@_BASIC_MOD.simple_op()
|
173
173
|
def pack(x: TensorType | TableType) -> TensorType:
|
174
174
|
"""Serialize a tensor/table into a byte vector (type-only).
|
175
175
|
|
@@ -189,7 +189,7 @@ def pack(x: TensorType | TableType) -> TensorType:
|
|
189
189
|
return TensorType(UINT8, (-1,))
|
190
190
|
|
191
191
|
|
192
|
-
@
|
192
|
+
@_BASIC_MOD.simple_op()
|
193
193
|
def unpack(b: TensorType, *, out_ty: TensorType | TableType) -> TensorType | TableType:
|
194
194
|
"""Deserialize a byte vector into the explicit output type.
|
195
195
|
|
@@ -215,7 +215,7 @@ def unpack(b: TensorType, *, out_ty: TensorType | TableType) -> TensorType | Tab
|
|
215
215
|
return out_ty
|
216
216
|
|
217
217
|
|
218
|
-
@
|
218
|
+
@_BASIC_MOD.simple_op()
|
219
219
|
def table_to_tensor(table: TableType, *, number_rows: int) -> TensorType:
|
220
220
|
"""Convert a homogeneous-typed table to a dense 2D tensor.
|
221
221
|
|
@@ -248,7 +248,7 @@ def table_to_tensor(table: TableType, *, number_rows: int) -> TensorType:
|
|
248
248
|
return TensorType(first, shape) # type: ignore[arg-type]
|
249
249
|
|
250
250
|
|
251
|
-
@
|
251
|
+
@_BASIC_MOD.simple_op()
|
252
252
|
def tensor_to_table(tensor: TensorType, *, column_names: list[str]) -> TableType:
|
253
253
|
"""Convert a rank-2 tensor into a table with named columns.
|
254
254
|
|
mplang/simp/__init__.py
CHANGED
@@ -105,7 +105,7 @@ def run_impl(
|
|
105
105
|
The result of evaluating the function through the appropriate handler
|
106
106
|
|
107
107
|
Raises:
|
108
|
-
ValueError: If
|
108
|
+
ValueError: If basic.write is called without required arguments
|
109
109
|
TypeError: If the function compilation or evaluation fails
|
110
110
|
RuntimeError: If the underlying peval execution encounters errors
|
111
111
|
|
@@ -114,11 +114,11 @@ def run_impl(
|
|
114
114
|
|
115
115
|
>>> tensor_info = TensorType(shape=(10, 10), dtype=np.float32)
|
116
116
|
>>> attrs = {"format": "binary"}
|
117
|
-
>>> result = run_impl(
|
117
|
+
>>> result = run_impl(basic.read, "data/input.bin", tensor_info, attrs)
|
118
118
|
|
119
119
|
Writing data to a file:
|
120
120
|
|
121
|
-
>>> run_impl(
|
121
|
+
>>> run_impl(basic.write, data, "data/output.bin")
|
122
122
|
|
123
123
|
Running a JAX function:
|
124
124
|
|
@@ -1,6 +1,6 @@
|
|
1
|
-
mplang/__init__.py,sha256=
|
2
|
-
mplang/
|
3
|
-
mplang/
|
1
|
+
mplang/__init__.py,sha256=tZmMm5LXY5QXKDvyYw1ysSStRky8CHelPQ9T0qgmfGs,1853
|
2
|
+
mplang/device.py,sha256=b9H1I-MFtL7hvJ38xoq65QdH1vGgPl9fNLQ-IwZWRvE,12479
|
3
|
+
mplang/host.py,sha256=ssmv0_CyZPFORhOUJ84Jo6NwRJSK7_Ono3n7ZjEg4sA,3058
|
4
4
|
mplang/analysis/__init__.py,sha256=CTHFvRsi-nFngojqjn08UaR3RY9i7CJ7T2UdR95kCrk,1056
|
5
5
|
mplang/analysis/diagram.py,sha256=ffwgD12gL1_KH1uJ_EYkjmIlDrfxYJJkWj-wHl09_Xk,19520
|
6
6
|
mplang/core/__init__.py,sha256=lWxlEKfRwX7FNDzgyKZ1fiDMaCiqkyg0j5mKlZD_v7g,2244
|
@@ -13,23 +13,23 @@ mplang/core/mask.py,sha256=14DFxaA446lGjN4dzTuQgm9Shcn34rYI87YJHg0YGNQ,10693
|
|
13
13
|
mplang/core/mpir.py,sha256=3NyHa1cDnUaw3XWIUgyOMXfZ9JS-30COb29AoXYcRtM,38251
|
14
14
|
mplang/core/mpobject.py,sha256=0pHSd7SrAFTScCFcB9ziDztElYQn-oIZOKBx47B3QX0,3732
|
15
15
|
mplang/core/mptype.py,sha256=7Cp2e58uUX-uqTp6QxuioOMJ8BzLBPXlWG5rRakv2uo,13773
|
16
|
-
mplang/core/pfunc.py,sha256=
|
17
|
-
mplang/core/primitive.py,sha256=
|
16
|
+
mplang/core/pfunc.py,sha256=WOGmMr4HCUELED5QxYbhhyQJRDXrA5Bk3tPbZWpwmw8,5102
|
17
|
+
mplang/core/primitive.py,sha256=vu60-k0fSAWWidcWDC0_FGvrRZww12oGXjB8CR9F6Yo,43889
|
18
18
|
mplang/core/table.py,sha256=BqTBZn7Tfwce4vzl3XYhaX5hVmKagVq9-YoERDta6d8,5892
|
19
19
|
mplang/core/tensor.py,sha256=86u6DogSZMoL0w5XjtTmQm2PhA_VjwybN1b6U4Zzphg,2361
|
20
20
|
mplang/core/tracer.py,sha256=dVMfUeCMmPz4o6tLXewGMW1Kpy5gpZORvr9w4MhwDtM,14288
|
21
21
|
mplang/core/expr/__init__.py,sha256=qwiSTUOcanFJLyK8HZ13_L1ZDrybqpPXIlTHAyeumE8,1988
|
22
22
|
mplang/core/expr/ast.py,sha256=K-rNqlpgkdjVzwSrLgunYnL4zWl1USJGLOgfz0qJNO4,20959
|
23
23
|
mplang/core/expr/evaluator.py,sha256=rpzZQPPVtxBvUuCx-9_bFmzr_7tfAQjPlP_rqpWjgIo,23313
|
24
|
-
mplang/core/expr/printer.py,sha256=
|
24
|
+
mplang/core/expr/printer.py,sha256=Ec6tCLtOUYqu0i1ZmtRvuSLjGpMqB30SM5CYZ_l2CqA,9660
|
25
25
|
mplang/core/expr/transformer.py,sha256=gez9eedVsWoLasSgWvPmGR8WfQnGXPlldWeVFEjqyYo,4904
|
26
26
|
mplang/core/expr/utils.py,sha256=VDTJ_-CsdHtVy9wDaGa7XdFxQ7o5lYYaeqcgsAhkbNI,2625
|
27
27
|
mplang/core/expr/visitor.py,sha256=2Ge-I5N-wH8VVXy8d2WyNaEv8x6seiRx9peyH9S2BYU,2044
|
28
28
|
mplang/core/expr/walk.py,sha256=lXkGJEEuvKGDqQihbxXPxfz2RfR1Q1zYUlt11iooQW0,11889
|
29
29
|
mplang/kernels/__init__.py,sha256=eooIUklLSg-cvyGk6uDSwZ3bUAjM6AXtHw_YdbUamYo,1052
|
30
30
|
mplang/kernels/base.py,sha256=-YV4Aj5fs6GT4ehS6Tyi8WQ-amxn5edHTFJRQzyjHXY,3826
|
31
|
-
mplang/kernels/
|
32
|
-
mplang/kernels/context.py,sha256=
|
31
|
+
mplang/kernels/basic.py,sha256=thE4jAoozsk_3_t7ahyQj8J9QQlC79CyNRPCqHnyQhk,8967
|
32
|
+
mplang/kernels/context.py,sha256=OTXuqZ8ziu1fVXem6lIn4DMygOAeSgySm_-dfzYakEA,13552
|
33
33
|
mplang/kernels/crypto.py,sha256=y5epCht71QenQnSbn5xRB0DCnb55Wm83iBJ1KRUedUU,4323
|
34
34
|
mplang/kernels/mock_tee.py,sha256=IMTIy5-tEMqB8bD1FG0Ki5UcH8dLAQUFIlFnktXUDX0,2492
|
35
35
|
mplang/kernels/phe.py,sha256=xglWuxPiclKVL1_YHgeN4KSUudqKwP5c4IskB-QjC1Y,72711
|
@@ -37,9 +37,9 @@ mplang/kernels/spu.py,sha256=vh-WG-uDVPvK11CDzc-f58sUalGts_eWJPjeuxLANfY,12508
|
|
37
37
|
mplang/kernels/sql_duckdb.py,sha256=vq4UCth_PCsH8dxcpx7mMnsq54KtjD8kxwISH7tj3qg,1631
|
38
38
|
mplang/kernels/stablehlo.py,sha256=gQyg-2ANyA1TjRg90MZ79mn4cHoXhU7g5GFUCuYXyKs,3231
|
39
39
|
mplang/kernels/value.py,sha256=2NZ0UvaLyRYb3bTCqL_fQXMzbHk1qbQ3j9xiv4v5E0A,20726
|
40
|
-
mplang/ops/__init__.py,sha256=
|
41
|
-
mplang/ops/base.py,sha256=
|
42
|
-
mplang/ops/
|
40
|
+
mplang/ops/__init__.py,sha256=QIRs49KTG-vUavhUiRHuYFZWrlKh97kwb2nVGHmxII0,1014
|
41
|
+
mplang/ops/base.py,sha256=HohQ0I39dda576mMhlzQp53oOWdRdUTsvn42t8QP_68,18256
|
42
|
+
mplang/ops/basic.py,sha256=Mj_4x9SUs0m0F3lzZYU-LIzDdBPyrsDvPJSrUUkIzPE,9333
|
43
43
|
mplang/ops/crypto.py,sha256=9CeFJrYmvjmgx-3WQl6jHXh8VafRpT4QBunbzsPF8Uc,3646
|
44
44
|
mplang/ops/ibis_cc.py,sha256=a5OqZVRZ1NzugQPYigdlJcGKbMZHqKh1xkiJen-LtCU,4242
|
45
45
|
mplang/ops/jax_cc.py,sha256=kVhJM8i8oPd-yPqyeaZ1hfVxcZPzNhTwjhltDh50hyY,7809
|
@@ -64,7 +64,7 @@ mplang/runtime/link_comm.py,sha256=ZHNcis8QDu2rcyyF3rhpxMiJDkczoMA_c0iZ2GDW_bA,2
|
|
64
64
|
mplang/runtime/server.py,sha256=CdmBmpbylEl7XeZj26i0rUmTrPTvl2CVdRgbtR02gcg,16543
|
65
65
|
mplang/runtime/session.py,sha256=I2711V-pPRCYibNgBhjboUURdubnL6ltCoh5RvFVabs,10641
|
66
66
|
mplang/runtime/simulation.py,sha256=1I_8dIqxivxtYnQK0ofz0oXk3lXXh3-zN0lmNFnWucA,11615
|
67
|
-
mplang/simp/__init__.py,sha256=
|
67
|
+
mplang/simp/__init__.py,sha256=_X1kpq9qhoPUL2gQRVNkSjS5jNSxYRu7-_1GSwJ9PK8,11575
|
68
68
|
mplang/simp/mpi.py,sha256=Wv_Q16TQ3rdLam6OzqXiefIGSMmagGkso09ycyOkHEs,4774
|
69
69
|
mplang/simp/random.py,sha256=7PVgWNL1j7Sf3MqT5PRiWplUu-0dyhF3Ub566iqX86M,3898
|
70
70
|
mplang/simp/smpc.py,sha256=tdH54aU4T-GIDPhpmf9NCeJC0G67PdOYc04cyUkOnwE,7119
|
@@ -73,8 +73,8 @@ mplang/utils/crypto.py,sha256=rvPomBFtznRHc3RPi6Aip9lsU8zW2oxBqGv1K3vn7Rs,1052
|
|
73
73
|
mplang/utils/func_utils.py,sha256=vCJcZmu0bEbqhOQKdpttV2_MBllIcPSN0b8U4WjNGGo,5164
|
74
74
|
mplang/utils/spu_utils.py,sha256=S3L9RBkBe2AvSuMSQQ12cBY5Y1NPthubvErSX_7nj1A,4158
|
75
75
|
mplang/utils/table_utils.py,sha256=aC-IZOKkSmFkpr3NZchLM0Wt0GOn-rg_xHBHREWBwAU,2202
|
76
|
-
mplang_nightly-0.1.
|
77
|
-
mplang_nightly-0.1.
|
78
|
-
mplang_nightly-0.1.
|
79
|
-
mplang_nightly-0.1.
|
80
|
-
mplang_nightly-0.1.
|
76
|
+
mplang_nightly-0.1.dev172.dist-info/METADATA,sha256=hjAMloWrHmrUgIZhcHEO268e1Nn2r4Kg52MXkhl927s,16547
|
77
|
+
mplang_nightly-0.1.dev172.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
78
|
+
mplang_nightly-0.1.dev172.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
|
79
|
+
mplang_nightly-0.1.dev172.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
80
|
+
mplang_nightly-0.1.dev172.dist-info/RECORD,,
|
/mplang/{api.py → host.py}
RENAMED
File without changes
|
File without changes
|
{mplang_nightly-0.1.dev170.dist-info → mplang_nightly-0.1.dev172.dist-info}/entry_points.txt
RENAMED
File without changes
|
{mplang_nightly-0.1.dev170.dist-info → mplang_nightly-0.1.dev172.dist-info}/licenses/LICENSE
RENAMED
File without changes
|