mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -20,8 +20,8 @@ from __future__ import annotations
|
|
|
20
20
|
|
|
21
21
|
from typing import Any
|
|
22
22
|
|
|
23
|
-
from mplang.core.
|
|
24
|
-
from mplang.core.expr.ast import (
|
|
23
|
+
from mplang.v1.core.dtypes import DType
|
|
24
|
+
from mplang.v1.core.expr.ast import (
|
|
25
25
|
AccessExpr,
|
|
26
26
|
CallExpr,
|
|
27
27
|
CondExpr,
|
|
@@ -35,10 +35,10 @@ from mplang.core.expr.ast import (
|
|
|
35
35
|
VariableExpr,
|
|
36
36
|
WhileExpr,
|
|
37
37
|
)
|
|
38
|
-
from mplang.core.expr.visitor import ExprVisitor
|
|
39
|
-
from mplang.core.mptype import MPType
|
|
40
|
-
from mplang.core.pfunc import PFunction
|
|
41
|
-
from mplang.core.tensor import Shape, TensorType
|
|
38
|
+
from mplang.v1.core.expr.visitor import ExprVisitor
|
|
39
|
+
from mplang.v1.core.mptype import MPType
|
|
40
|
+
from mplang.v1.core.pfunc import PFunction
|
|
41
|
+
from mplang.v1.core.tensor import Shape, TensorType
|
|
42
42
|
|
|
43
43
|
|
|
44
44
|
class Printer(ExprVisitor):
|
|
@@ -50,11 +50,13 @@ class Printer(ExprVisitor):
|
|
|
50
50
|
compact_format: bool = True,
|
|
51
51
|
*,
|
|
52
52
|
verbose_peval: bool = False,
|
|
53
|
+
inline_pcall: bool = True,
|
|
53
54
|
):
|
|
54
55
|
super().__init__() # Initialize MemorizedVisitor
|
|
55
56
|
self.indent_size = indent_size
|
|
56
57
|
self.compact_format = compact_format
|
|
57
58
|
self.verbose_peval = verbose_peval
|
|
59
|
+
self.inline_pcall = inline_pcall
|
|
58
60
|
self._cur_indent = 0
|
|
59
61
|
self._output: list[str] = []
|
|
60
62
|
self._visited: dict[Expr, str] = {}
|
|
@@ -92,6 +94,7 @@ class Printer(ExprVisitor):
|
|
|
92
94
|
body_printer = Printer(
|
|
93
95
|
indent_size=self.indent_size,
|
|
94
96
|
compact_format=self.compact_format,
|
|
97
|
+
inline_pcall=self.inline_pcall,
|
|
95
98
|
)
|
|
96
99
|
func_def_expr.accept(body_printer)
|
|
97
100
|
regions_str += f"{indent}{r_name}: "
|
|
@@ -161,13 +164,9 @@ class Printer(ExprVisitor):
|
|
|
161
164
|
arg_names = [self._var_name(arg) for arg in expr.args]
|
|
162
165
|
fn_type = expr.pfunc.fn_type
|
|
163
166
|
|
|
164
|
-
# for well known
|
|
165
|
-
if fn_type == "
|
|
167
|
+
# for well known basic functions
|
|
168
|
+
if fn_type == "basic.constant":
|
|
166
169
|
return self._print_const(expr.pfunc, expr.mptypes)
|
|
167
|
-
elif fn_type == "builtin.rank":
|
|
168
|
-
return self._do_print("prank", [], mptypes=expr.mptypes)
|
|
169
|
-
elif fn_type == "builtin.prand":
|
|
170
|
-
return self._do_print("prand", [], mptypes=expr.mptypes)
|
|
171
170
|
|
|
172
171
|
attrs = {"fn_type": fn_type}
|
|
173
172
|
if expr.pfunc.fn_name:
|
|
@@ -209,12 +208,19 @@ class Printer(ExprVisitor):
|
|
|
209
208
|
|
|
210
209
|
def visit_call(self, expr: CallExpr) -> str:
|
|
211
210
|
arg_names = [self._var_name(arg) for arg in expr.args]
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
211
|
+
if self.inline_pcall:
|
|
212
|
+
return self._do_print(
|
|
213
|
+
expr.name,
|
|
214
|
+
arg_names,
|
|
215
|
+
mptypes=expr.mptypes,
|
|
216
|
+
)
|
|
217
|
+
else:
|
|
218
|
+
return self._do_print(
|
|
219
|
+
"pcall",
|
|
220
|
+
arg_names,
|
|
221
|
+
regions={"fn": expr.fn},
|
|
222
|
+
mptypes=expr.mptypes,
|
|
223
|
+
)
|
|
218
224
|
|
|
219
225
|
def visit_while(self, expr: WhileExpr) -> str:
|
|
220
226
|
arg_names = [self._var_name(arg) for arg in expr.args]
|
|
@@ -18,7 +18,7 @@ Expression transformer based on visitor pattern.
|
|
|
18
18
|
|
|
19
19
|
from collections.abc import Callable
|
|
20
20
|
|
|
21
|
-
from mplang.core.expr.ast import (
|
|
21
|
+
from mplang.v1.core.expr.ast import (
|
|
22
22
|
AccessExpr,
|
|
23
23
|
CallExpr,
|
|
24
24
|
CondExpr,
|
|
@@ -32,7 +32,7 @@ from mplang.core.expr.ast import (
|
|
|
32
32
|
VariableExpr,
|
|
33
33
|
WhileExpr,
|
|
34
34
|
)
|
|
35
|
-
from mplang.core.expr.visitor import ExprVisitor
|
|
35
|
+
from mplang.v1.core.expr.visitor import ExprVisitor
|
|
36
36
|
|
|
37
37
|
|
|
38
38
|
class ExprTransformer(ExprVisitor):
|
|
@@ -79,7 +79,7 @@ class ExprTransformer(ExprVisitor):
|
|
|
79
79
|
def visit_call(self, expr: CallExpr) -> Expr:
|
|
80
80
|
# Transform child expressions first
|
|
81
81
|
transformed_args = [arg.accept(self) for arg in expr.args]
|
|
82
|
-
new_expr = CallExpr(expr.fn, transformed_args)
|
|
82
|
+
new_expr = CallExpr(expr.name, expr.fn, transformed_args)
|
|
83
83
|
|
|
84
84
|
if "call" in self.trans_rules:
|
|
85
85
|
return self.trans_rules["call"](new_expr)
|
|
@@ -18,8 +18,8 @@ Utility functions for expression system.
|
|
|
18
18
|
|
|
19
19
|
from collections.abc import Sequence
|
|
20
20
|
|
|
21
|
-
from mplang.core.mask import Mask
|
|
22
|
-
from mplang.core.mptype import TensorLike
|
|
21
|
+
from mplang.v1.core.mask import Mask
|
|
22
|
+
from mplang.v1.core.mptype import TensorLike
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
def type_equal(*args: TensorLike) -> bool:
|
|
@@ -25,12 +25,12 @@ from abc import abstractmethod
|
|
|
25
25
|
from collections.abc import Sequence
|
|
26
26
|
from typing import Any, cast
|
|
27
27
|
|
|
28
|
-
from mplang.core.cluster import ClusterSpec
|
|
29
|
-
from mplang.core.expr.ast import Expr, VariableExpr
|
|
30
|
-
from mplang.core.mpobject import MPContext, MPObject
|
|
31
|
-
from mplang.core.mptype import MPType, TensorLike
|
|
32
|
-
from mplang.core.tracer import TracedFunction
|
|
33
|
-
from mplang.utils.func_utils import var_demorph, var_morph
|
|
28
|
+
from mplang.v1.core.cluster import ClusterSpec
|
|
29
|
+
from mplang.v1.core.expr.ast import Expr, VariableExpr
|
|
30
|
+
from mplang.v1.core.mpobject import MPContext, MPObject
|
|
31
|
+
from mplang.v1.core.mptype import MPType, TensorLike
|
|
32
|
+
from mplang.v1.core.tracer import TracedFunction
|
|
33
|
+
from mplang.v1.utils.func_utils import var_demorph, var_morph
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
# TODO(jint): Should we use inheritance or composition here?
|
mplang/{core → v1/core}/mpir.py
RENAMED
|
@@ -32,9 +32,9 @@ from typing import Any
|
|
|
32
32
|
import numpy as np
|
|
33
33
|
import spu.libspu as spu_api
|
|
34
34
|
|
|
35
|
-
from mplang.core.
|
|
36
|
-
from mplang.core.expr import Expr, FuncDefExpr
|
|
37
|
-
from mplang.core.expr.ast import (
|
|
35
|
+
from mplang.v1.core.dtypes import DATE, JSON, STRING, TIME, TIMESTAMP, DType
|
|
36
|
+
from mplang.v1.core.expr import Expr, FuncDefExpr
|
|
37
|
+
from mplang.v1.core.expr.ast import (
|
|
38
38
|
AccessExpr,
|
|
39
39
|
CallExpr,
|
|
40
40
|
CondExpr,
|
|
@@ -46,13 +46,13 @@ from mplang.core.expr.ast import (
|
|
|
46
46
|
VariableExpr,
|
|
47
47
|
WhileExpr,
|
|
48
48
|
)
|
|
49
|
-
from mplang.core.expr.walk import walk
|
|
50
|
-
from mplang.core.mask import Mask
|
|
51
|
-
from mplang.core.mptype import MPType
|
|
52
|
-
from mplang.core.pfunc import PFunction
|
|
53
|
-
from mplang.core.table import TableType
|
|
54
|
-
from mplang.core.tensor import TensorType
|
|
55
|
-
from mplang.protos.v1alpha1 import mpir_pb2
|
|
49
|
+
from mplang.v1.core.expr.walk import walk
|
|
50
|
+
from mplang.v1.core.mask import Mask
|
|
51
|
+
from mplang.v1.core.mptype import MPType
|
|
52
|
+
from mplang.v1.core.pfunc import PFunction
|
|
53
|
+
from mplang.v1.core.table import TableType
|
|
54
|
+
from mplang.v1.core.tensor import TensorType
|
|
55
|
+
from mplang.v1.protos.v1alpha1 import mpir_pb2
|
|
56
56
|
|
|
57
57
|
# Single mapping table for dtype conversion
|
|
58
58
|
DTYPE_MAPPING = {
|
|
@@ -204,7 +204,7 @@ def attr_to_proto(py_value: Any) -> mpir_pb2.AttrProto:
|
|
|
204
204
|
raise TypeError(f"Unsupported tuple/list type: {type(py_value)}")
|
|
205
205
|
elif isinstance(py_value, FuncDefExpr):
|
|
206
206
|
# Convert FuncDefExpr to GraphProto
|
|
207
|
-
graph =
|
|
207
|
+
graph = IrWriter().dumps(py_value)
|
|
208
208
|
attr_proto.type = mpir_pb2.AttrProto.GRAPH
|
|
209
209
|
attr_proto.graph.CopyFrom(graph)
|
|
210
210
|
elif isinstance(py_value, PFunction):
|
|
@@ -217,7 +217,9 @@ def attr_to_proto(py_value: Any) -> mpir_pb2.AttrProto:
|
|
|
217
217
|
# Serialize attrs dictionary
|
|
218
218
|
if py_value.attrs:
|
|
219
219
|
for attr_name, attr_value in py_value.attrs.items():
|
|
220
|
-
|
|
220
|
+
# Skip None-valued attributes to align with top-level attr handling
|
|
221
|
+
if attr_value is not None:
|
|
222
|
+
attr_proto.func.attrs[attr_name].CopyFrom(attr_to_proto(attr_value))
|
|
221
223
|
|
|
222
224
|
# Note: We don't serialize ins_info and outs_info since they can be
|
|
223
225
|
# inferred from the input expressions during deserialization
|
|
@@ -234,7 +236,7 @@ def attr_to_proto(py_value: Any) -> mpir_pb2.AttrProto:
|
|
|
234
236
|
return attr_proto
|
|
235
237
|
|
|
236
238
|
|
|
237
|
-
class
|
|
239
|
+
class IrWriter:
|
|
238
240
|
"""Writer for serializing Expr-based expressions to GraphProto.
|
|
239
241
|
|
|
240
242
|
This class traverses an expression tree and converts it into a serialized
|
|
@@ -491,6 +493,7 @@ class Writer:
|
|
|
491
493
|
op = self._create_node_proto(expr, "call")
|
|
492
494
|
self._add_single_expr_inputs(op, expr.fn)
|
|
493
495
|
self._add_expr_inputs(op, *expr.args)
|
|
496
|
+
self._add_attrs(op, name=expr.name)
|
|
494
497
|
self._finalize_node(op, expr)
|
|
495
498
|
elif isinstance(expr, WhileExpr):
|
|
496
499
|
op = self._create_node_proto(expr, "while")
|
|
@@ -524,7 +527,7 @@ class Writer:
|
|
|
524
527
|
raise TypeError(f"Unsupported expr type for serialization: {type(expr)}")
|
|
525
528
|
|
|
526
529
|
|
|
527
|
-
class
|
|
530
|
+
class IrReader:
|
|
528
531
|
"""Reader for deserializing GraphProto back to Expr-based expressions.
|
|
529
532
|
|
|
530
533
|
This class is responsible for converting serialized GraphProto representations
|
|
@@ -822,8 +825,12 @@ class Reader:
|
|
|
822
825
|
arg_exprs.append(self._value_cache[dep_name])
|
|
823
826
|
else:
|
|
824
827
|
raise ValueError(f"Input {input_name} not found for call node")
|
|
828
|
+
# Optional call-site name attribute
|
|
829
|
+
call_name = None
|
|
830
|
+
if "name" in node_proto.attrs:
|
|
831
|
+
call_name = self._proto_to_attr(node_proto.attrs["name"]) # type: ignore[assignment]
|
|
825
832
|
|
|
826
|
-
return CallExpr(fn_expr, arg_exprs)
|
|
833
|
+
return CallExpr(call_name or "", fn_expr, arg_exprs)
|
|
827
834
|
|
|
828
835
|
def _proto_to_mptype(self, type_proto: mpir_pb2.MPTypeProto) -> MPType:
|
|
829
836
|
"""Convert MPTypeProto to MPType."""
|
|
@@ -897,7 +904,7 @@ class Reader:
|
|
|
897
904
|
)
|
|
898
905
|
elif attr_proto.type == mpir_pb2.AttrProto.GRAPH:
|
|
899
906
|
# Handle nested expressions (for control flow)
|
|
900
|
-
reader =
|
|
907
|
+
reader = IrReader()
|
|
901
908
|
return reader.loads(attr_proto.graph)
|
|
902
909
|
else:
|
|
903
910
|
raise TypeError(f"Unsupported attribute type: {attr_proto.type}")
|
|
@@ -17,14 +17,14 @@ from __future__ import annotations
|
|
|
17
17
|
from abc import ABC, abstractmethod
|
|
18
18
|
from typing import TYPE_CHECKING, Any
|
|
19
19
|
|
|
20
|
-
from mplang.core.
|
|
21
|
-
from mplang.core.mask import Mask
|
|
22
|
-
from mplang.core.mptype import MPType
|
|
23
|
-
from mplang.core.table import TableType
|
|
24
|
-
from mplang.core.tensor import Shape
|
|
20
|
+
from mplang.v1.core.dtypes import DType
|
|
21
|
+
from mplang.v1.core.mask import Mask
|
|
22
|
+
from mplang.v1.core.mptype import MPType
|
|
23
|
+
from mplang.v1.core.table import TableType
|
|
24
|
+
from mplang.v1.core.tensor import Shape
|
|
25
25
|
|
|
26
26
|
if TYPE_CHECKING:
|
|
27
|
-
from mplang.core.cluster import ClusterSpec
|
|
27
|
+
from mplang.v1.core.cluster import ClusterSpec
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
class MPContext:
|
|
@@ -20,12 +20,12 @@ from typing import TYPE_CHECKING, Any
|
|
|
20
20
|
import numpy as np
|
|
21
21
|
|
|
22
22
|
if TYPE_CHECKING:
|
|
23
|
-
from mplang.core.mpobject import MPObject
|
|
23
|
+
from mplang.v1.core.mpobject import MPObject
|
|
24
24
|
|
|
25
|
-
from mplang.core.
|
|
26
|
-
from mplang.core.mask import Mask
|
|
27
|
-
from mplang.core.table import TableLike, TableType
|
|
28
|
-
from mplang.core.tensor import ScalarType, Shape, TensorLike, TensorType
|
|
25
|
+
from mplang.v1.core.dtypes import STRING, DType
|
|
26
|
+
from mplang.v1.core.mask import Mask
|
|
27
|
+
from mplang.v1.core.table import TableLike, TableType
|
|
28
|
+
from mplang.v1.core.tensor import ScalarType, Shape, TensorLike, TensorType
|
|
29
29
|
|
|
30
30
|
# basic type aliases
|
|
31
31
|
Rank = int
|
|
@@ -195,6 +195,10 @@ class MPType:
|
|
|
195
195
|
information about the object."""
|
|
196
196
|
return self._attrs
|
|
197
197
|
|
|
198
|
+
def raw_type(self) -> TensorType | TableType:
|
|
199
|
+
"""Get the raw type information (TensorType or TableType)."""
|
|
200
|
+
return self._type
|
|
201
|
+
|
|
198
202
|
def set_attr(self, key: str, value: Any) -> None:
|
|
199
203
|
"""Set an attribute for this type."""
|
|
200
204
|
self._attrs[key] = value
|
|
@@ -252,9 +256,8 @@ class MPType:
|
|
|
252
256
|
if not isinstance(other, MPType):
|
|
253
257
|
return False
|
|
254
258
|
return (
|
|
255
|
-
self._type == other._type
|
|
256
|
-
and self.
|
|
257
|
-
and self._attrs == other._attrs
|
|
259
|
+
self._type == other._type and self._pmask == other._pmask
|
|
260
|
+
# and self._attrs == other._attrs # TODO(jint): attrs should be optional
|
|
258
261
|
)
|
|
259
262
|
|
|
260
263
|
def __hash__(self) -> int:
|
|
@@ -270,7 +273,7 @@ class MPType:
|
|
|
270
273
|
def isInstance(self, obj: MPObject) -> bool:
|
|
271
274
|
"""Check if the given object is an instance of this MPType."""
|
|
272
275
|
# Import here to avoid circular import
|
|
273
|
-
from mplang.core.mpobject import MPObject
|
|
276
|
+
from mplang.v1.core.mpobject import MPObject
|
|
274
277
|
|
|
275
278
|
if not isinstance(obj, MPObject):
|
|
276
279
|
return False
|
|
@@ -373,7 +376,7 @@ class MPType:
|
|
|
373
376
|
import pandas as pd
|
|
374
377
|
|
|
375
378
|
if isinstance(obj, pd.DataFrame):
|
|
376
|
-
from mplang.core.
|
|
379
|
+
from mplang.v1.core.dtypes import DType
|
|
377
380
|
|
|
378
381
|
schema_dict = {}
|
|
379
382
|
for col_name in obj.columns:
|
mplang/{core → v1/core}/pfunc.py
RENAMED
|
@@ -19,8 +19,8 @@ from collections.abc import Sequence
|
|
|
19
19
|
from types import MappingProxyType
|
|
20
20
|
from typing import Any
|
|
21
21
|
|
|
22
|
-
from mplang.core.table import TableType
|
|
23
|
-
from mplang.core.tensor import TensorType
|
|
22
|
+
from mplang.v1.core.table import TableType
|
|
23
|
+
from mplang.v1.core.tensor import TensorType
|
|
24
24
|
|
|
25
25
|
__all__ = [
|
|
26
26
|
"PFunction",
|
|
@@ -33,7 +33,7 @@ class PFunction:
|
|
|
33
33
|
|
|
34
34
|
PFunction serves as a unified interface for describing single-party computations
|
|
35
35
|
in multi-party computing scenarios. It can represent both:
|
|
36
|
-
1. Built-in operations (e.g., "spu.makeshares", "
|
|
36
|
+
1. Built-in operations (e.g., "spu.makeshares", "basic.read")
|
|
37
37
|
2. User-defined programmable functions with custom code
|
|
38
38
|
|
|
39
39
|
The PFunction accepts a list of typed inputs (TensorType/TableType). For
|
|
@@ -47,7 +47,7 @@ class PFunction:
|
|
|
47
47
|
|
|
48
48
|
Args:
|
|
49
49
|
fn_type: The type/category identifier of this PFunction, indicating which
|
|
50
|
-
backend or handler should process it (e.g., "spu.makeshares", "
|
|
50
|
+
backend or handler should process it (e.g., "spu.makeshares", "basic.read",
|
|
51
51
|
"mlir.stablehlo"). This serves as a routing mechanism for execution.
|
|
52
52
|
ins_info: Type information for input parameters (TensorType or TableType)
|
|
53
53
|
outs_info: Type information for output values (TensorType or TableType)
|