mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -248,6 +248,7 @@ class ClusterSpec:
|
|
|
248
248
|
world_size: int,
|
|
249
249
|
*,
|
|
250
250
|
endpoints: list[str] | None = None,
|
|
251
|
+
spu_world_size: int | None = None,
|
|
251
252
|
spu_protocol: str = "SEMI2K",
|
|
252
253
|
spu_field: str = "FM128",
|
|
253
254
|
runtime_version: str = "simulated",
|
|
@@ -325,10 +326,14 @@ class ClusterSpec:
|
|
|
325
326
|
|
|
326
327
|
# Shared SPU device
|
|
327
328
|
if enable_spu_device:
|
|
329
|
+
if spu_world_size is None:
|
|
330
|
+
spu_world_size = world_size
|
|
331
|
+
spu_members = [nodes[f"node{i}"] for i in range(spu_world_size)]
|
|
332
|
+
|
|
328
333
|
devices["SP0"] = Device(
|
|
329
334
|
name="SP0",
|
|
330
335
|
kind="SPU",
|
|
331
|
-
members=
|
|
336
|
+
members=spu_members,
|
|
332
337
|
config={
|
|
333
338
|
"protocol": spu_protocol,
|
|
334
339
|
"field": spu_field,
|
mplang/{core → v1/core}/comm.py
RENAMED
|
@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING
|
|
|
20
20
|
|
|
21
21
|
if TYPE_CHECKING:
|
|
22
22
|
# Imported only for typing to avoid import cycles at runtime.
|
|
23
|
-
from mplang.core.mpobject import MPContext
|
|
23
|
+
from mplang.v1.core.mpobject import MPContext
|
|
24
24
|
|
|
25
25
|
# The global working context.
|
|
26
26
|
_g_ctx: MPContext | None = None
|
|
@@ -177,6 +177,13 @@ class DType:
|
|
|
177
177
|
# TypeError if it's not a pandas dtype we can handle
|
|
178
178
|
pass
|
|
179
179
|
|
|
180
|
+
try:
|
|
181
|
+
return cls._from_arrow_dtype(dtype_like)
|
|
182
|
+
except (ImportError, TypeError):
|
|
183
|
+
# ImportError if pyarrow is not installed
|
|
184
|
+
# TypeError if it's not a pyarrow dtype we can handle
|
|
185
|
+
pass
|
|
186
|
+
|
|
180
187
|
if isinstance(dtype_like, type) and dtype_like in (bool, int, float, complex):
|
|
181
188
|
return cls.from_python_type(dtype_like)
|
|
182
189
|
elif hasattr(dtype_like, "dtype") and not isinstance(dtype_like, type):
|
|
@@ -225,6 +232,37 @@ class DType:
|
|
|
225
232
|
|
|
226
233
|
raise TypeError(f"Unsupported pandas dtype: {dtype_like}")
|
|
227
234
|
|
|
235
|
+
@classmethod
|
|
236
|
+
def _from_arrow_dtype(cls, dtype_like: Any) -> DType:
|
|
237
|
+
try:
|
|
238
|
+
import pyarrow as pa
|
|
239
|
+
except ImportError:
|
|
240
|
+
raise ImportError("pyarrow not available") from None
|
|
241
|
+
|
|
242
|
+
if not isinstance(dtype_like, pa.DataType):
|
|
243
|
+
raise TypeError("Not a pyarrow dtype")
|
|
244
|
+
|
|
245
|
+
ARROW_DTYPE_MAPPING = {
|
|
246
|
+
pa.bool_(): BOOL,
|
|
247
|
+
pa.int8(): INT8,
|
|
248
|
+
pa.int16(): INT16,
|
|
249
|
+
pa.int32(): INT32,
|
|
250
|
+
pa.int64(): INT64,
|
|
251
|
+
pa.uint8(): UINT8,
|
|
252
|
+
pa.uint16(): UINT16,
|
|
253
|
+
pa.uint32(): UINT32,
|
|
254
|
+
pa.uint64(): UINT64,
|
|
255
|
+
pa.float16(): FLOAT16,
|
|
256
|
+
pa.float32(): FLOAT32,
|
|
257
|
+
pa.float64(): FLOAT64,
|
|
258
|
+
pa.string(): STRING,
|
|
259
|
+
pa.large_string(): STRING,
|
|
260
|
+
}
|
|
261
|
+
result = ARROW_DTYPE_MAPPING.get(dtype_like)
|
|
262
|
+
if result is not None:
|
|
263
|
+
return result
|
|
264
|
+
raise TypeError(f"Unsupported arrow dtype: {dtype_like}")
|
|
265
|
+
|
|
228
266
|
def to_numpy(self) -> np.dtype:
|
|
229
267
|
"""Convert custom DType to NumPy dtype."""
|
|
230
268
|
return np.dtype(self.name)
|
|
@@ -20,7 +20,7 @@ multi-party computation graphs using the visitor pattern.
|
|
|
20
20
|
"""
|
|
21
21
|
|
|
22
22
|
# Core expression types
|
|
23
|
-
from mplang.core.expr.ast import (
|
|
23
|
+
from mplang.v1.core.expr.ast import (
|
|
24
24
|
AccessExpr,
|
|
25
25
|
CallExpr,
|
|
26
26
|
CondExpr,
|
|
@@ -36,12 +36,12 @@ from mplang.core.expr.ast import (
|
|
|
36
36
|
)
|
|
37
37
|
|
|
38
38
|
# Built-in evaluator engines
|
|
39
|
-
from mplang.core.expr.evaluator import IEvaluator, create_evaluator
|
|
40
|
-
from mplang.core.expr.printer import Printer
|
|
41
|
-
from mplang.core.expr.transformer import ExprTransformer
|
|
39
|
+
from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator
|
|
40
|
+
from mplang.v1.core.expr.printer import Printer
|
|
41
|
+
from mplang.v1.core.expr.transformer import ExprTransformer
|
|
42
42
|
|
|
43
43
|
# Utility functions
|
|
44
|
-
from mplang.core.expr.utils import (
|
|
44
|
+
from mplang.v1.core.expr.utils import (
|
|
45
45
|
deduce_mask,
|
|
46
46
|
ensure_scalar,
|
|
47
47
|
ensure_tensorlist_equal,
|
|
@@ -49,8 +49,8 @@ from mplang.core.expr.utils import (
|
|
|
49
49
|
)
|
|
50
50
|
|
|
51
51
|
# Visitor pattern interface
|
|
52
|
-
from mplang.core.expr.visitor import ExprVisitor
|
|
53
|
-
from mplang.core.expr.walk import walk, walk_dataflow, walk_structural
|
|
52
|
+
from mplang.v1.core.expr.visitor import ExprVisitor
|
|
53
|
+
from mplang.v1.core.expr.walk import walk, walk_dataflow, walk_structural
|
|
54
54
|
|
|
55
55
|
__all__ = [
|
|
56
56
|
"AccessExpr",
|
|
@@ -26,15 +26,15 @@ import logging
|
|
|
26
26
|
from abc import ABC, abstractmethod
|
|
27
27
|
from typing import TYPE_CHECKING, Any
|
|
28
28
|
|
|
29
|
-
from mplang.core.expr.utils import deduce_mask
|
|
30
|
-
from mplang.core.mask import Mask
|
|
31
|
-
from mplang.core.mptype import MPType, Rank
|
|
32
|
-
from mplang.core.pfunc import PFunction
|
|
33
|
-
from mplang.core.table import TableType
|
|
34
|
-
from mplang.core.tensor import TensorType
|
|
29
|
+
from mplang.v1.core.expr.utils import deduce_mask
|
|
30
|
+
from mplang.v1.core.mask import Mask
|
|
31
|
+
from mplang.v1.core.mptype import MPType, Rank
|
|
32
|
+
from mplang.v1.core.pfunc import PFunction
|
|
33
|
+
from mplang.v1.core.table import TableType
|
|
34
|
+
from mplang.v1.core.tensor import TensorType
|
|
35
35
|
|
|
36
36
|
if TYPE_CHECKING:
|
|
37
|
-
from mplang.core.expr.visitor import ExprVisitor
|
|
37
|
+
from mplang.v1.core.expr.visitor import ExprVisitor
|
|
38
38
|
|
|
39
39
|
|
|
40
40
|
class Expr(ABC):
|
|
@@ -286,8 +286,8 @@ class ConvExpr(Expr):
|
|
|
286
286
|
# Validate dtype / shape consistency.
|
|
287
287
|
first = types[0]
|
|
288
288
|
for c in types[1:]:
|
|
289
|
-
if
|
|
290
|
-
raise TypeError(f"Inconsistent
|
|
289
|
+
if c.raw_type() != first.raw_type():
|
|
290
|
+
raise TypeError(f"Inconsistent type in pconv: {c} vs {first}")
|
|
291
291
|
|
|
292
292
|
# Deduce the pmask by intersecting all pmasks.
|
|
293
293
|
pmasks = [t.pmask for t in types]
|
|
@@ -316,7 +316,7 @@ class ConvExpr(Expr):
|
|
|
316
316
|
else:
|
|
317
317
|
out_pmask = None
|
|
318
318
|
|
|
319
|
-
return [MPType
|
|
319
|
+
return [MPType(first.raw_type(), out_pmask, first.attrs)]
|
|
320
320
|
|
|
321
321
|
def accept(self, visitor: ExprVisitor) -> Any:
|
|
322
322
|
return visitor.visit_conv(self)
|
|
@@ -398,9 +398,7 @@ class ShflSExpr(Expr):
|
|
|
398
398
|
def _compute_mptypes(self) -> list[MPType]:
|
|
399
399
|
# The types are the same as the source value, but with a new pmask.
|
|
400
400
|
src_type = self.src_val.mptype
|
|
401
|
-
return [
|
|
402
|
-
MPType.tensor(src_type.dtype, src_type.shape, self.pmask, **src_type.attrs)
|
|
403
|
-
]
|
|
401
|
+
return [MPType(src_type._type, self.pmask, src_type.attrs)]
|
|
404
402
|
|
|
405
403
|
def accept(self, visitor: ExprVisitor) -> Any:
|
|
406
404
|
return visitor.visit_shfl_s(self)
|
|
@@ -27,8 +27,8 @@ from __future__ import annotations
|
|
|
27
27
|
from dataclasses import dataclass
|
|
28
28
|
from typing import Any, Protocol
|
|
29
29
|
|
|
30
|
-
from mplang.core.comm import ICommunicator
|
|
31
|
-
from mplang.core.expr.ast import (
|
|
30
|
+
from mplang.v1.core.comm import ICommunicator
|
|
31
|
+
from mplang.v1.core.expr.ast import (
|
|
32
32
|
AccessExpr,
|
|
33
33
|
CallExpr,
|
|
34
34
|
CondExpr,
|
|
@@ -42,12 +42,12 @@ from mplang.core.expr.ast import (
|
|
|
42
42
|
VariableExpr,
|
|
43
43
|
WhileExpr,
|
|
44
44
|
)
|
|
45
|
-
from mplang.core.expr.visitor import ExprVisitor
|
|
46
|
-
from mplang.core.expr.walk import walk_dataflow
|
|
47
|
-
from mplang.core.mask import Mask
|
|
48
|
-
from mplang.core.pfunc import PFunction
|
|
49
|
-
from mplang.kernels.context import RuntimeContext
|
|
50
|
-
from mplang.kernels.value import Value
|
|
45
|
+
from mplang.v1.core.expr.visitor import ExprVisitor
|
|
46
|
+
from mplang.v1.core.expr.walk import walk_dataflow
|
|
47
|
+
from mplang.v1.core.mask import Mask
|
|
48
|
+
from mplang.v1.core.pfunc import PFunction
|
|
49
|
+
from mplang.v1.kernels.context import RuntimeContext
|
|
50
|
+
from mplang.v1.kernels.value import Value
|
|
51
51
|
|
|
52
52
|
|
|
53
53
|
class IEvaluator(Protocol):
|
|
@@ -20,8 +20,8 @@ from __future__ import annotations
|
|
|
20
20
|
|
|
21
21
|
from typing import Any
|
|
22
22
|
|
|
23
|
-
from mplang.core.dtypes import DType
|
|
24
|
-
from mplang.core.expr.ast import (
|
|
23
|
+
from mplang.v1.core.dtypes import DType
|
|
24
|
+
from mplang.v1.core.expr.ast import (
|
|
25
25
|
AccessExpr,
|
|
26
26
|
CallExpr,
|
|
27
27
|
CondExpr,
|
|
@@ -35,10 +35,10 @@ from mplang.core.expr.ast import (
|
|
|
35
35
|
VariableExpr,
|
|
36
36
|
WhileExpr,
|
|
37
37
|
)
|
|
38
|
-
from mplang.core.expr.visitor import ExprVisitor
|
|
39
|
-
from mplang.core.mptype import MPType
|
|
40
|
-
from mplang.core.pfunc import PFunction
|
|
41
|
-
from mplang.core.tensor import Shape, TensorType
|
|
38
|
+
from mplang.v1.core.expr.visitor import ExprVisitor
|
|
39
|
+
from mplang.v1.core.mptype import MPType
|
|
40
|
+
from mplang.v1.core.pfunc import PFunction
|
|
41
|
+
from mplang.v1.core.tensor import Shape, TensorType
|
|
42
42
|
|
|
43
43
|
|
|
44
44
|
class Printer(ExprVisitor):
|
|
@@ -18,7 +18,7 @@ Expression transformer based on visitor pattern.
|
|
|
18
18
|
|
|
19
19
|
from collections.abc import Callable
|
|
20
20
|
|
|
21
|
-
from mplang.core.expr.ast import (
|
|
21
|
+
from mplang.v1.core.expr.ast import (
|
|
22
22
|
AccessExpr,
|
|
23
23
|
CallExpr,
|
|
24
24
|
CondExpr,
|
|
@@ -32,7 +32,7 @@ from mplang.core.expr.ast import (
|
|
|
32
32
|
VariableExpr,
|
|
33
33
|
WhileExpr,
|
|
34
34
|
)
|
|
35
|
-
from mplang.core.expr.visitor import ExprVisitor
|
|
35
|
+
from mplang.v1.core.expr.visitor import ExprVisitor
|
|
36
36
|
|
|
37
37
|
|
|
38
38
|
class ExprTransformer(ExprVisitor):
|
|
@@ -18,8 +18,8 @@ Utility functions for expression system.
|
|
|
18
18
|
|
|
19
19
|
from collections.abc import Sequence
|
|
20
20
|
|
|
21
|
-
from mplang.core.mask import Mask
|
|
22
|
-
from mplang.core.mptype import TensorLike
|
|
21
|
+
from mplang.v1.core.mask import Mask
|
|
22
|
+
from mplang.v1.core.mptype import TensorLike
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
def type_equal(*args: TensorLike) -> bool:
|
|
@@ -25,12 +25,12 @@ from abc import abstractmethod
|
|
|
25
25
|
from collections.abc import Sequence
|
|
26
26
|
from typing import Any, cast
|
|
27
27
|
|
|
28
|
-
from mplang.core.cluster import ClusterSpec
|
|
29
|
-
from mplang.core.expr.ast import Expr, VariableExpr
|
|
30
|
-
from mplang.core.mpobject import MPContext, MPObject
|
|
31
|
-
from mplang.core.mptype import MPType, TensorLike
|
|
32
|
-
from mplang.core.tracer import TracedFunction
|
|
33
|
-
from mplang.utils.func_utils import var_demorph, var_morph
|
|
28
|
+
from mplang.v1.core.cluster import ClusterSpec
|
|
29
|
+
from mplang.v1.core.expr.ast import Expr, VariableExpr
|
|
30
|
+
from mplang.v1.core.mpobject import MPContext, MPObject
|
|
31
|
+
from mplang.v1.core.mptype import MPType, TensorLike
|
|
32
|
+
from mplang.v1.core.tracer import TracedFunction
|
|
33
|
+
from mplang.v1.utils.func_utils import var_demorph, var_morph
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
# TODO(jint): Should we use inheritance or composition here?
|
mplang/{core → v1/core}/mpir.py
RENAMED
|
@@ -32,9 +32,9 @@ from typing import Any
|
|
|
32
32
|
import numpy as np
|
|
33
33
|
import spu.libspu as spu_api
|
|
34
34
|
|
|
35
|
-
from mplang.core.dtypes import DATE, JSON, STRING, TIME, TIMESTAMP, DType
|
|
36
|
-
from mplang.core.expr import Expr, FuncDefExpr
|
|
37
|
-
from mplang.core.expr.ast import (
|
|
35
|
+
from mplang.v1.core.dtypes import DATE, JSON, STRING, TIME, TIMESTAMP, DType
|
|
36
|
+
from mplang.v1.core.expr import Expr, FuncDefExpr
|
|
37
|
+
from mplang.v1.core.expr.ast import (
|
|
38
38
|
AccessExpr,
|
|
39
39
|
CallExpr,
|
|
40
40
|
CondExpr,
|
|
@@ -46,13 +46,13 @@ from mplang.core.expr.ast import (
|
|
|
46
46
|
VariableExpr,
|
|
47
47
|
WhileExpr,
|
|
48
48
|
)
|
|
49
|
-
from mplang.core.expr.walk import walk
|
|
50
|
-
from mplang.core.mask import Mask
|
|
51
|
-
from mplang.core.mptype import MPType
|
|
52
|
-
from mplang.core.pfunc import PFunction
|
|
53
|
-
from mplang.core.table import TableType
|
|
54
|
-
from mplang.core.tensor import TensorType
|
|
55
|
-
from mplang.protos.v1alpha1 import mpir_pb2
|
|
49
|
+
from mplang.v1.core.expr.walk import walk
|
|
50
|
+
from mplang.v1.core.mask import Mask
|
|
51
|
+
from mplang.v1.core.mptype import MPType
|
|
52
|
+
from mplang.v1.core.pfunc import PFunction
|
|
53
|
+
from mplang.v1.core.table import TableType
|
|
54
|
+
from mplang.v1.core.tensor import TensorType
|
|
55
|
+
from mplang.v1.protos.v1alpha1 import mpir_pb2
|
|
56
56
|
|
|
57
57
|
# Single mapping table for dtype conversion
|
|
58
58
|
DTYPE_MAPPING = {
|
|
@@ -217,7 +217,9 @@ def attr_to_proto(py_value: Any) -> mpir_pb2.AttrProto:
|
|
|
217
217
|
# Serialize attrs dictionary
|
|
218
218
|
if py_value.attrs:
|
|
219
219
|
for attr_name, attr_value in py_value.attrs.items():
|
|
220
|
-
|
|
220
|
+
# Skip None-valued attributes to align with top-level attr handling
|
|
221
|
+
if attr_value is not None:
|
|
222
|
+
attr_proto.func.attrs[attr_name].CopyFrom(attr_to_proto(attr_value))
|
|
221
223
|
|
|
222
224
|
# Note: We don't serialize ins_info and outs_info since they can be
|
|
223
225
|
# inferred from the input expressions during deserialization
|
|
@@ -17,14 +17,14 @@ from __future__ import annotations
|
|
|
17
17
|
from abc import ABC, abstractmethod
|
|
18
18
|
from typing import TYPE_CHECKING, Any
|
|
19
19
|
|
|
20
|
-
from mplang.core.dtypes import DType
|
|
21
|
-
from mplang.core.mask import Mask
|
|
22
|
-
from mplang.core.mptype import MPType
|
|
23
|
-
from mplang.core.table import TableType
|
|
24
|
-
from mplang.core.tensor import Shape
|
|
20
|
+
from mplang.v1.core.dtypes import DType
|
|
21
|
+
from mplang.v1.core.mask import Mask
|
|
22
|
+
from mplang.v1.core.mptype import MPType
|
|
23
|
+
from mplang.v1.core.table import TableType
|
|
24
|
+
from mplang.v1.core.tensor import Shape
|
|
25
25
|
|
|
26
26
|
if TYPE_CHECKING:
|
|
27
|
-
from mplang.core.cluster import ClusterSpec
|
|
27
|
+
from mplang.v1.core.cluster import ClusterSpec
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
class MPContext:
|
|
@@ -20,12 +20,12 @@ from typing import TYPE_CHECKING, Any
|
|
|
20
20
|
import numpy as np
|
|
21
21
|
|
|
22
22
|
if TYPE_CHECKING:
|
|
23
|
-
from mplang.core.mpobject import MPObject
|
|
23
|
+
from mplang.v1.core.mpobject import MPObject
|
|
24
24
|
|
|
25
|
-
from mplang.core.dtypes import STRING, DType
|
|
26
|
-
from mplang.core.mask import Mask
|
|
27
|
-
from mplang.core.table import TableLike, TableType
|
|
28
|
-
from mplang.core.tensor import ScalarType, Shape, TensorLike, TensorType
|
|
25
|
+
from mplang.v1.core.dtypes import STRING, DType
|
|
26
|
+
from mplang.v1.core.mask import Mask
|
|
27
|
+
from mplang.v1.core.table import TableLike, TableType
|
|
28
|
+
from mplang.v1.core.tensor import ScalarType, Shape, TensorLike, TensorType
|
|
29
29
|
|
|
30
30
|
# basic type aliases
|
|
31
31
|
Rank = int
|
|
@@ -195,6 +195,10 @@ class MPType:
|
|
|
195
195
|
information about the object."""
|
|
196
196
|
return self._attrs
|
|
197
197
|
|
|
198
|
+
def raw_type(self) -> TensorType | TableType:
|
|
199
|
+
"""Get the raw type information (TensorType or TableType)."""
|
|
200
|
+
return self._type
|
|
201
|
+
|
|
198
202
|
def set_attr(self, key: str, value: Any) -> None:
|
|
199
203
|
"""Set an attribute for this type."""
|
|
200
204
|
self._attrs[key] = value
|
|
@@ -252,9 +256,8 @@ class MPType:
|
|
|
252
256
|
if not isinstance(other, MPType):
|
|
253
257
|
return False
|
|
254
258
|
return (
|
|
255
|
-
self._type == other._type
|
|
256
|
-
and self.
|
|
257
|
-
and self._attrs == other._attrs
|
|
259
|
+
self._type == other._type and self._pmask == other._pmask
|
|
260
|
+
# and self._attrs == other._attrs # TODO(jint): attrs should be optional
|
|
258
261
|
)
|
|
259
262
|
|
|
260
263
|
def __hash__(self) -> int:
|
|
@@ -270,7 +273,7 @@ class MPType:
|
|
|
270
273
|
def isInstance(self, obj: MPObject) -> bool:
|
|
271
274
|
"""Check if the given object is an instance of this MPType."""
|
|
272
275
|
# Import here to avoid circular import
|
|
273
|
-
from mplang.core.mpobject import MPObject
|
|
276
|
+
from mplang.v1.core.mpobject import MPObject
|
|
274
277
|
|
|
275
278
|
if not isinstance(obj, MPObject):
|
|
276
279
|
return False
|
|
@@ -373,7 +376,7 @@ class MPType:
|
|
|
373
376
|
import pandas as pd
|
|
374
377
|
|
|
375
378
|
if isinstance(obj, pd.DataFrame):
|
|
376
|
-
from mplang.core.dtypes import DType
|
|
379
|
+
from mplang.v1.core.dtypes import DType
|
|
377
380
|
|
|
378
381
|
schema_dict = {}
|
|
379
382
|
for col_name in obj.columns:
|
mplang/{core → v1/core}/pfunc.py
RENAMED
|
@@ -19,8 +19,8 @@ from collections.abc import Sequence
|
|
|
19
19
|
from types import MappingProxyType
|
|
20
20
|
from typing import Any
|
|
21
21
|
|
|
22
|
-
from mplang.core.table import TableType
|
|
23
|
-
from mplang.core.tensor import TensorType
|
|
22
|
+
from mplang.v1.core.table import TableType
|
|
23
|
+
from mplang.v1.core.tensor import TensorType
|
|
24
24
|
|
|
25
25
|
__all__ = [
|
|
26
26
|
"PFunction",
|
|
@@ -28,9 +28,9 @@ from typing import Any, ParamSpec, TypeVar, cast
|
|
|
28
28
|
|
|
29
29
|
from jax.tree_util import tree_map
|
|
30
30
|
|
|
31
|
-
from mplang.core.context_mgr import cur_ctx
|
|
32
|
-
from mplang.core.dtypes import BOOL
|
|
33
|
-
from mplang.core.expr.ast import (
|
|
31
|
+
from mplang.v1.core.context_mgr import cur_ctx
|
|
32
|
+
from mplang.v1.core.dtypes import BOOL
|
|
33
|
+
from mplang.v1.core.expr.ast import (
|
|
34
34
|
AccessExpr,
|
|
35
35
|
CallExpr,
|
|
36
36
|
CondExpr,
|
|
@@ -40,13 +40,13 @@ from mplang.core.expr.ast import (
|
|
|
40
40
|
ShflSExpr,
|
|
41
41
|
WhileExpr,
|
|
42
42
|
)
|
|
43
|
-
from mplang.core.interp import InterpContext, InterpVar, apply
|
|
44
|
-
from mplang.core.mask import Mask
|
|
45
|
-
from mplang.core.mpobject import MPContext, MPObject
|
|
46
|
-
from mplang.core.mptype import Rank
|
|
47
|
-
from mplang.core.pfunc import PFunction
|
|
48
|
-
from mplang.core.tracer import TraceContext, TraceVar, trace
|
|
49
|
-
from mplang.utils.func_utils import var_demorph, var_morph
|
|
43
|
+
from mplang.v1.core.interp import InterpContext, InterpVar, apply
|
|
44
|
+
from mplang.v1.core.mask import Mask
|
|
45
|
+
from mplang.v1.core.mpobject import MPContext, MPObject
|
|
46
|
+
from mplang.v1.core.mptype import Rank
|
|
47
|
+
from mplang.v1.core.pfunc import PFunction
|
|
48
|
+
from mplang.v1.core.tracer import TraceContext, TraceVar, trace
|
|
49
|
+
from mplang.v1.utils.func_utils import var_demorph, var_morph
|
|
50
50
|
|
|
51
51
|
|
|
52
52
|
def _switch_ctx(ctx: MPContext, obj: MPObject | Any) -> MPObject | Any:
|
|
@@ -298,7 +298,7 @@ def uniform_cond(
|
|
|
298
298
|
|
|
299
299
|
1. ``pred`` is a boolean scalar whose runtime value is identical for every enabled party.
|
|
300
300
|
2. At least one branch contains multi-party primitives (``seal`` / ``reveal`` /
|
|
301
|
-
``
|
|
301
|
+
``srun_jax`` / ``pshfl`` / mask transformations) whose cost or side-effects you
|
|
302
302
|
want to avoid if the branch is not taken.
|
|
303
303
|
3. You require the semantic guarantee that the *non-selected* branch does **not**
|
|
304
304
|
perform communication, allocate intermediate buffers, or leak timing/side-effects.
|
|
@@ -559,7 +559,7 @@ def while_loop(
|
|
|
559
559
|
secret-shared reduction).
|
|
560
560
|
|
|
561
561
|
cond_fn::
|
|
562
|
-
sealed_sum = smpc.reveal(smpc.
|
|
562
|
+
sealed_sum = smpc.reveal(smpc.srun_jax(lambda x: jnp.sum(x), smpc.seal(x)))
|
|
563
563
|
return sealed_sum < constant(10)
|
|
564
564
|
|
|
565
565
|
body_fn::
|
mplang/{core → v1/core}/table.py
RENAMED
|
@@ -18,16 +18,16 @@ from collections.abc import Iterator
|
|
|
18
18
|
from dataclasses import dataclass, field
|
|
19
19
|
from typing import Any, Protocol, runtime_checkable
|
|
20
20
|
|
|
21
|
-
from mplang.core.dtypes import DType
|
|
21
|
+
from mplang.v1.core.dtypes import DType
|
|
22
22
|
|
|
23
23
|
__all__ = ["TableLike", "TableType"]
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
@runtime_checkable
|
|
27
|
-
class
|
|
27
|
+
class PandasTableLike(Protocol):
|
|
28
28
|
"""
|
|
29
29
|
Protocol for objects structurally resembling tables from common libraries
|
|
30
|
-
(pandas DataFrame,
|
|
30
|
+
(pandas DataFrame, polars DataFrame, etc.), focusing on dtypes and columns attributes.
|
|
31
31
|
"""
|
|
32
32
|
|
|
33
33
|
@property
|
|
@@ -37,6 +37,26 @@ class TableLike(Protocol):
|
|
|
37
37
|
def columns(self) -> Any: ...
|
|
38
38
|
|
|
39
39
|
|
|
40
|
+
@runtime_checkable
|
|
41
|
+
class ArrowSchema(Protocol):
|
|
42
|
+
@property
|
|
43
|
+
def names(self) -> list[str]: ...
|
|
44
|
+
@property
|
|
45
|
+
def types(self) -> list[Any]: ...
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@runtime_checkable
|
|
49
|
+
class ArrowTableLike(Protocol):
|
|
50
|
+
@property
|
|
51
|
+
def column_names(self) -> list[str]: ...
|
|
52
|
+
|
|
53
|
+
@property
|
|
54
|
+
def schema(self) -> ArrowSchema: ...
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
TableLike = PandasTableLike | ArrowTableLike
|
|
58
|
+
|
|
59
|
+
|
|
40
60
|
@dataclass(frozen=True)
|
|
41
61
|
class TableType:
|
|
42
62
|
"""Table schema: ordered list of column name-type pairs.
|
|
@@ -109,11 +129,19 @@ class TableType:
|
|
|
109
129
|
Returns:
|
|
110
130
|
TableType instance
|
|
111
131
|
"""
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
132
|
+
if isinstance(table, PandasTableLike):
|
|
133
|
+
columns = [
|
|
134
|
+
(name, DType.from_any(dtype))
|
|
135
|
+
for name, dtype in zip(table.columns, table.dtypes, strict=True)
|
|
136
|
+
]
|
|
137
|
+
return cls(tuple(columns))
|
|
138
|
+
elif isinstance(table, ArrowTableLike):
|
|
139
|
+
schema = table.schema
|
|
140
|
+
columns = [
|
|
141
|
+
(name, DType.from_any(dtype))
|
|
142
|
+
for name, dtype in zip(schema.names, schema.types, strict=True)
|
|
143
|
+
]
|
|
144
|
+
return cls(tuple(columns))
|
|
117
145
|
|
|
118
146
|
def column_names(self) -> tuple[str, ...]:
|
|
119
147
|
"""Get all column names."""
|
|
@@ -60,15 +60,15 @@ from collections.abc import Callable
|
|
|
60
60
|
from dataclasses import dataclass
|
|
61
61
|
from typing import Any, cast
|
|
62
62
|
|
|
63
|
-
from mplang.core.cluster import ClusterSpec
|
|
64
|
-
from mplang.core.context_mgr import with_ctx
|
|
65
|
-
from mplang.core.expr.ast import Expr, FuncDefExpr, TupleExpr, VariableExpr
|
|
66
|
-
from mplang.core.expr.printer import Printer
|
|
67
|
-
from mplang.core.mask import Mask
|
|
68
|
-
from mplang.core.mpobject import MPContext, MPObject
|
|
69
|
-
from mplang.core.mptype import MPType
|
|
70
|
-
from mplang.core.pfunc import get_fn_name
|
|
71
|
-
from mplang.utils.func_utils import MorphStruct, var_demorph, var_morph
|
|
63
|
+
from mplang.v1.core.cluster import ClusterSpec
|
|
64
|
+
from mplang.v1.core.context_mgr import with_ctx
|
|
65
|
+
from mplang.v1.core.expr.ast import Expr, FuncDefExpr, TupleExpr, VariableExpr
|
|
66
|
+
from mplang.v1.core.expr.printer import Printer
|
|
67
|
+
from mplang.v1.core.mask import Mask
|
|
68
|
+
from mplang.v1.core.mpobject import MPContext, MPObject
|
|
69
|
+
from mplang.v1.core.mptype import MPType
|
|
70
|
+
from mplang.v1.core.pfunc import get_fn_name
|
|
71
|
+
from mplang.v1.utils.func_utils import MorphStruct, var_demorph, var_morph
|
|
72
72
|
|
|
73
73
|
|
|
74
74
|
class VarNamer:
|
mplang/{host.py → v1/host.py}
RENAMED
|
@@ -19,7 +19,7 @@ from typing import Any
|
|
|
19
19
|
|
|
20
20
|
from jax.tree_util import tree_map
|
|
21
21
|
|
|
22
|
-
from mplang.core import (
|
|
22
|
+
from mplang.v1.core import (
|
|
23
23
|
ClusterSpec,
|
|
24
24
|
InterpContext,
|
|
25
25
|
MPContext,
|
|
@@ -28,7 +28,7 @@ from mplang.core import (
|
|
|
28
28
|
TracedFunction,
|
|
29
29
|
trace,
|
|
30
30
|
)
|
|
31
|
-
from mplang.core.context_mgr import cur_ctx, with_ctx
|
|
31
|
+
from mplang.v1.core.context_mgr import cur_ctx, with_ctx
|
|
32
32
|
|
|
33
33
|
|
|
34
34
|
def evaluate(
|
|
@@ -76,11 +76,11 @@ def fetch(interp: InterpContext | None, objs: Any) -> Any: # type: ignore[misc]
|
|
|
76
76
|
evaluated = evaluate(ctx, lambda x: x, objs)
|
|
77
77
|
|
|
78
78
|
def fetch_impl(arg: MPObject | Any) -> Any:
|
|
79
|
-
if isinstance(arg, MPObject):
|
|
80
|
-
return ctx.fetch(arg)
|
|
81
|
-
else:
|
|
79
|
+
if not isinstance(arg, MPObject):
|
|
82
80
|
return arg
|
|
83
81
|
|
|
82
|
+
return ctx.fetch(arg)
|
|
83
|
+
|
|
84
84
|
return tree_map(fetch_impl, evaluated)
|
|
85
85
|
|
|
86
86
|
|