mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
Core components for multi-party computation.
|
|
17
|
+
|
|
18
|
+
This package provides the fundamental building blocks for multi-party computation,
|
|
19
|
+
including type systems, tracing mechanisms, and interpreter contexts.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
# Core type system
|
|
23
|
+
# Communication interfaces & core symbols
|
|
24
|
+
from mplang.v1.core.cluster import ClusterSpec, Device, Node, RuntimeInfo
|
|
25
|
+
from mplang.v1.core.comm import (
|
|
26
|
+
CollectiveMixin,
|
|
27
|
+
CommunicatorBase,
|
|
28
|
+
ICollective,
|
|
29
|
+
ICommunicator,
|
|
30
|
+
)
|
|
31
|
+
from mplang.v1.core.context_mgr import cur_ctx, set_ctx, with_ctx
|
|
32
|
+
from mplang.v1.core.dtypes import (
|
|
33
|
+
BINARY,
|
|
34
|
+
BOOL,
|
|
35
|
+
COMPLEX64,
|
|
36
|
+
COMPLEX128,
|
|
37
|
+
DATE,
|
|
38
|
+
DECIMAL,
|
|
39
|
+
FLOAT16,
|
|
40
|
+
FLOAT32,
|
|
41
|
+
FLOAT64,
|
|
42
|
+
INT8,
|
|
43
|
+
INT16,
|
|
44
|
+
INT32,
|
|
45
|
+
INT64,
|
|
46
|
+
INTERVAL,
|
|
47
|
+
JSON,
|
|
48
|
+
STRING,
|
|
49
|
+
TIME,
|
|
50
|
+
TIMESTAMP,
|
|
51
|
+
UINT8,
|
|
52
|
+
UINT16,
|
|
53
|
+
UINT32,
|
|
54
|
+
UINT64,
|
|
55
|
+
UUID,
|
|
56
|
+
DType,
|
|
57
|
+
)
|
|
58
|
+
from mplang.v1.core.interp import InterpContext, InterpVar
|
|
59
|
+
from mplang.v1.core.mask import Mask
|
|
60
|
+
from mplang.v1.core.mpir import IrReader, IrWriter
|
|
61
|
+
from mplang.v1.core.mpobject import MPContext, MPObject
|
|
62
|
+
from mplang.v1.core.mptype import MPType, Rank, Shape
|
|
63
|
+
from mplang.v1.core.pfunc import PFunction, get_fn_name
|
|
64
|
+
|
|
65
|
+
# Import primitive-dependent symbols at the end to avoid cycles when frontend ops
|
|
66
|
+
# import from the core facade during package initialization.
|
|
67
|
+
from mplang.v1.core.primitive import (
|
|
68
|
+
builtin_function,
|
|
69
|
+
function,
|
|
70
|
+
pconv,
|
|
71
|
+
peval,
|
|
72
|
+
pmask,
|
|
73
|
+
pshfl,
|
|
74
|
+
pshfl_s,
|
|
75
|
+
psize,
|
|
76
|
+
uniform_cond,
|
|
77
|
+
while_loop,
|
|
78
|
+
)
|
|
79
|
+
from mplang.v1.core.table import TableLike, TableType
|
|
80
|
+
from mplang.v1.core.tensor import ScalarType, TensorLike, TensorType
|
|
81
|
+
from mplang.v1.core.tracer import (
|
|
82
|
+
TraceContext,
|
|
83
|
+
TracedFunction,
|
|
84
|
+
TraceVar,
|
|
85
|
+
VarNamer,
|
|
86
|
+
trace,
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
__all__ = [
|
|
90
|
+
"BINARY",
|
|
91
|
+
"BOOL",
|
|
92
|
+
"COMPLEX64",
|
|
93
|
+
"COMPLEX128",
|
|
94
|
+
"DATE",
|
|
95
|
+
"DECIMAL",
|
|
96
|
+
"FLOAT16",
|
|
97
|
+
"FLOAT32",
|
|
98
|
+
"FLOAT64",
|
|
99
|
+
"INT8",
|
|
100
|
+
"INT16",
|
|
101
|
+
"INT32",
|
|
102
|
+
"INT64",
|
|
103
|
+
"INTERVAL",
|
|
104
|
+
"JSON",
|
|
105
|
+
"STRING",
|
|
106
|
+
"TIME",
|
|
107
|
+
"TIMESTAMP",
|
|
108
|
+
"UINT8",
|
|
109
|
+
"UINT16",
|
|
110
|
+
"UINT32",
|
|
111
|
+
"UINT64",
|
|
112
|
+
"UUID",
|
|
113
|
+
"ClusterSpec",
|
|
114
|
+
"CollectiveMixin",
|
|
115
|
+
"CommunicatorBase",
|
|
116
|
+
"DType",
|
|
117
|
+
"Device",
|
|
118
|
+
"ICollective",
|
|
119
|
+
"ICommunicator",
|
|
120
|
+
"InterpContext",
|
|
121
|
+
"InterpVar",
|
|
122
|
+
"IrReader",
|
|
123
|
+
"IrWriter",
|
|
124
|
+
"MPContext",
|
|
125
|
+
"MPObject",
|
|
126
|
+
"MPType",
|
|
127
|
+
"Mask",
|
|
128
|
+
"Node",
|
|
129
|
+
"PFunction",
|
|
130
|
+
"Rank",
|
|
131
|
+
"RuntimeInfo",
|
|
132
|
+
"ScalarType",
|
|
133
|
+
"Shape",
|
|
134
|
+
"TableLike",
|
|
135
|
+
"TableType",
|
|
136
|
+
"TensorLike",
|
|
137
|
+
"TensorType",
|
|
138
|
+
"TraceContext",
|
|
139
|
+
"TraceVar",
|
|
140
|
+
"TracedFunction",
|
|
141
|
+
"VarNamer",
|
|
142
|
+
"builtin_function",
|
|
143
|
+
"cur_ctx",
|
|
144
|
+
"function",
|
|
145
|
+
"get_fn_name",
|
|
146
|
+
"pconv",
|
|
147
|
+
"peval",
|
|
148
|
+
"pmask",
|
|
149
|
+
"pshfl",
|
|
150
|
+
"pshfl_s",
|
|
151
|
+
"psize",
|
|
152
|
+
"set_ctx",
|
|
153
|
+
"trace",
|
|
154
|
+
"uniform_cond",
|
|
155
|
+
"while_loop",
|
|
156
|
+
"with_ctx",
|
|
157
|
+
]
|
|
@@ -20,6 +20,7 @@ MPLang cluster configuration.
|
|
|
20
20
|
from __future__ import annotations
|
|
21
21
|
|
|
22
22
|
from dataclasses import dataclass, field
|
|
23
|
+
from functools import cached_property
|
|
23
24
|
from typing import Any
|
|
24
25
|
|
|
25
26
|
|
|
@@ -132,11 +133,11 @@ class ClusterSpec:
|
|
|
132
133
|
"which is not defined in nodes"
|
|
133
134
|
)
|
|
134
135
|
|
|
135
|
-
# ensure
|
|
136
|
+
# ensure ppu devices have exactly one member
|
|
136
137
|
for device in self.devices.values():
|
|
137
|
-
if device.kind.lower() == "
|
|
138
|
+
if device.kind.lower() == "ppu" and len(device.members) != 1:
|
|
138
139
|
raise ValueError(
|
|
139
|
-
f"
|
|
140
|
+
f"PPU device '{device.name}' must have exactly one member"
|
|
140
141
|
)
|
|
141
142
|
|
|
142
143
|
def get_node(self, name: str) -> Node:
|
|
@@ -169,6 +170,16 @@ class ClusterSpec:
|
|
|
169
170
|
},
|
|
170
171
|
}
|
|
171
172
|
|
|
173
|
+
@cached_property
|
|
174
|
+
def endpoints(self) -> list[str]:
|
|
175
|
+
eps: list[str] = []
|
|
176
|
+
for n in sorted(
|
|
177
|
+
self.nodes.values(),
|
|
178
|
+
key=lambda x: x.rank, # type: ignore[attr-defined]
|
|
179
|
+
):
|
|
180
|
+
eps.append(n.endpoint)
|
|
181
|
+
return eps
|
|
182
|
+
|
|
172
183
|
@classmethod
|
|
173
184
|
def from_dict(cls, config: dict[str, Any]) -> ClusterSpec:
|
|
174
185
|
"""Parses a raw config dictionary and returns a validated ClusterSpec."""
|
|
@@ -237,12 +248,13 @@ class ClusterSpec:
|
|
|
237
248
|
world_size: int,
|
|
238
249
|
*,
|
|
239
250
|
endpoints: list[str] | None = None,
|
|
251
|
+
spu_world_size: int | None = None,
|
|
240
252
|
spu_protocol: str = "SEMI2K",
|
|
241
253
|
spu_field: str = "FM128",
|
|
242
254
|
runtime_version: str = "simulated",
|
|
243
255
|
runtime_platform: str = "simulated",
|
|
244
256
|
op_bindings: list[dict[str, str]] | None = None,
|
|
245
|
-
|
|
257
|
+
enable_ppu_device: bool = True,
|
|
246
258
|
enable_spu_device: bool = True,
|
|
247
259
|
) -> ClusterSpec:
|
|
248
260
|
"""Convenience constructor used heavily in tests.
|
|
@@ -263,8 +275,8 @@ class ClusterSpec:
|
|
|
263
275
|
op_bindings:
|
|
264
276
|
Optional list of length ``world_size`` supplying per-node op_bindings
|
|
265
277
|
override dicts (defaults to empty dicts).
|
|
266
|
-
|
|
267
|
-
If True (default), create one ``
|
|
278
|
+
enable_ppu_device:
|
|
279
|
+
If True (default), create one ``P{rank}`` PPU device per node.
|
|
268
280
|
enable_spu_device:
|
|
269
281
|
If True (default) create a shared SPU device named ``SP0``.
|
|
270
282
|
"""
|
|
@@ -282,9 +294,9 @@ class ClusterSpec:
|
|
|
282
294
|
f"{len(op_bindings)} != {world_size}"
|
|
283
295
|
)
|
|
284
296
|
|
|
285
|
-
if not
|
|
297
|
+
if not enable_ppu_device and not enable_spu_device:
|
|
286
298
|
raise ValueError(
|
|
287
|
-
"At least one of
|
|
299
|
+
"At least one of enable_ppu_device or enable_spu_device must be True"
|
|
288
300
|
)
|
|
289
301
|
|
|
290
302
|
nodes: dict[str, Node] = {}
|
|
@@ -303,21 +315,25 @@ class ClusterSpec:
|
|
|
303
315
|
)
|
|
304
316
|
|
|
305
317
|
devices: dict[str, Device] = {}
|
|
306
|
-
# Optional per-node
|
|
307
|
-
if
|
|
318
|
+
# Optional per-node PPU devices
|
|
319
|
+
if enable_ppu_device:
|
|
308
320
|
for i in range(world_size):
|
|
309
|
-
devices[f"
|
|
310
|
-
name=f"
|
|
311
|
-
kind="
|
|
321
|
+
devices[f"P{i}"] = Device(
|
|
322
|
+
name=f"P{i}",
|
|
323
|
+
kind="ppu",
|
|
312
324
|
members=[nodes[f"node{i}"]],
|
|
313
325
|
)
|
|
314
326
|
|
|
315
327
|
# Shared SPU device
|
|
316
328
|
if enable_spu_device:
|
|
329
|
+
if spu_world_size is None:
|
|
330
|
+
spu_world_size = world_size
|
|
331
|
+
spu_members = [nodes[f"node{i}"] for i in range(spu_world_size)]
|
|
332
|
+
|
|
317
333
|
devices["SP0"] = Device(
|
|
318
334
|
name="SP0",
|
|
319
335
|
kind="SPU",
|
|
320
|
-
members=
|
|
336
|
+
members=spu_members,
|
|
321
337
|
config={
|
|
322
338
|
"protocol": spu_protocol,
|
|
323
339
|
"field": spu_field,
|
mplang/{core → v1/core}/comm.py
RENAMED
|
@@ -19,7 +19,7 @@ import threading
|
|
|
19
19
|
from abc import ABC, abstractmethod
|
|
20
20
|
from typing import Any
|
|
21
21
|
|
|
22
|
-
from mplang.core.mask import Mask
|
|
22
|
+
from mplang.v1.core.mask import Mask
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
class ICommunicator(ABC):
|
|
@@ -48,6 +48,10 @@ class ICommunicator(ABC):
|
|
|
48
48
|
def recv(self, frm: int, key: str) -> Any:
|
|
49
49
|
"""Receive data from peer with the given key"""
|
|
50
50
|
|
|
51
|
+
@abstractmethod
|
|
52
|
+
def onSent(self, frm: int, key: str, data: Any) -> None:
|
|
53
|
+
"""Called when a key is sent to self"""
|
|
54
|
+
|
|
51
55
|
|
|
52
56
|
class ICollective(ABC):
|
|
53
57
|
"""Interface for collective communication"""
|
|
@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING
|
|
|
20
20
|
|
|
21
21
|
if TYPE_CHECKING:
|
|
22
22
|
# Imported only for typing to avoid import cycles at runtime.
|
|
23
|
-
from mplang.core.mpobject import MPContext
|
|
23
|
+
from mplang.v1.core.mpobject import MPContext
|
|
24
24
|
|
|
25
25
|
# The global working context.
|
|
26
26
|
_g_ctx: MPContext | None = None
|
|
@@ -21,7 +21,8 @@ import numpy as np
|
|
|
21
21
|
|
|
22
22
|
try:
|
|
23
23
|
# Check if JAX is available
|
|
24
|
-
import jax
|
|
24
|
+
import jax
|
|
25
|
+
import jax.numpy as jnp
|
|
25
26
|
|
|
26
27
|
_JAX_AVAILABLE = True
|
|
27
28
|
except ImportError:
|
|
@@ -140,6 +141,10 @@ class DType:
|
|
|
140
141
|
"""Convert from JAX dtype to custom DType."""
|
|
141
142
|
if not _JAX_AVAILABLE:
|
|
142
143
|
raise ImportError("JAX is not available")
|
|
144
|
+
# Special handling for PRNG KeyTy: <class jax._src.prng.KeyTy>
|
|
145
|
+
if jnp.issubdtype(jax_dtype, jax.dtypes.prng_key):
|
|
146
|
+
return cls.from_numpy(np.uint32)
|
|
147
|
+
|
|
143
148
|
# JAX dtypes are essentially NumPy dtypes
|
|
144
149
|
return cls.from_numpy(jax_dtype)
|
|
145
150
|
|
|
@@ -172,6 +177,13 @@ class DType:
|
|
|
172
177
|
# TypeError if it's not a pandas dtype we can handle
|
|
173
178
|
pass
|
|
174
179
|
|
|
180
|
+
try:
|
|
181
|
+
return cls._from_arrow_dtype(dtype_like)
|
|
182
|
+
except (ImportError, TypeError):
|
|
183
|
+
# ImportError if pyarrow is not installed
|
|
184
|
+
# TypeError if it's not a pyarrow dtype we can handle
|
|
185
|
+
pass
|
|
186
|
+
|
|
175
187
|
if isinstance(dtype_like, type) and dtype_like in (bool, int, float, complex):
|
|
176
188
|
return cls.from_python_type(dtype_like)
|
|
177
189
|
elif hasattr(dtype_like, "dtype") and not isinstance(dtype_like, type):
|
|
@@ -220,6 +232,37 @@ class DType:
|
|
|
220
232
|
|
|
221
233
|
raise TypeError(f"Unsupported pandas dtype: {dtype_like}")
|
|
222
234
|
|
|
235
|
+
@classmethod
|
|
236
|
+
def _from_arrow_dtype(cls, dtype_like: Any) -> DType:
|
|
237
|
+
try:
|
|
238
|
+
import pyarrow as pa
|
|
239
|
+
except ImportError:
|
|
240
|
+
raise ImportError("pyarrow not available") from None
|
|
241
|
+
|
|
242
|
+
if not isinstance(dtype_like, pa.DataType):
|
|
243
|
+
raise TypeError("Not a pyarrow dtype")
|
|
244
|
+
|
|
245
|
+
ARROW_DTYPE_MAPPING = {
|
|
246
|
+
pa.bool_(): BOOL,
|
|
247
|
+
pa.int8(): INT8,
|
|
248
|
+
pa.int16(): INT16,
|
|
249
|
+
pa.int32(): INT32,
|
|
250
|
+
pa.int64(): INT64,
|
|
251
|
+
pa.uint8(): UINT8,
|
|
252
|
+
pa.uint16(): UINT16,
|
|
253
|
+
pa.uint32(): UINT32,
|
|
254
|
+
pa.uint64(): UINT64,
|
|
255
|
+
pa.float16(): FLOAT16,
|
|
256
|
+
pa.float32(): FLOAT32,
|
|
257
|
+
pa.float64(): FLOAT64,
|
|
258
|
+
pa.string(): STRING,
|
|
259
|
+
pa.large_string(): STRING,
|
|
260
|
+
}
|
|
261
|
+
result = ARROW_DTYPE_MAPPING.get(dtype_like)
|
|
262
|
+
if result is not None:
|
|
263
|
+
return result
|
|
264
|
+
raise TypeError(f"Unsupported arrow dtype: {dtype_like}")
|
|
265
|
+
|
|
223
266
|
def to_numpy(self) -> np.dtype:
|
|
224
267
|
"""Convert custom DType to NumPy dtype."""
|
|
225
268
|
return np.dtype(self.name)
|
|
@@ -228,7 +271,6 @@ class DType:
|
|
|
228
271
|
"""Convert custom DType to JAX dtype."""
|
|
229
272
|
if not _JAX_AVAILABLE:
|
|
230
273
|
raise ImportError("JAX is not available")
|
|
231
|
-
import jax.numpy as jnp
|
|
232
274
|
|
|
233
275
|
return jnp.dtype(self.name)
|
|
234
276
|
|
|
@@ -20,7 +20,7 @@ multi-party computation graphs using the visitor pattern.
|
|
|
20
20
|
"""
|
|
21
21
|
|
|
22
22
|
# Core expression types
|
|
23
|
-
from mplang.core.expr.ast import (
|
|
23
|
+
from mplang.v1.core.expr.ast import (
|
|
24
24
|
AccessExpr,
|
|
25
25
|
CallExpr,
|
|
26
26
|
CondExpr,
|
|
@@ -36,12 +36,12 @@ from mplang.core.expr.ast import (
|
|
|
36
36
|
)
|
|
37
37
|
|
|
38
38
|
# Built-in evaluator engines
|
|
39
|
-
from mplang.core.expr.evaluator import IEvaluator, create_evaluator
|
|
40
|
-
from mplang.core.expr.printer import Printer
|
|
41
|
-
from mplang.core.expr.transformer import ExprTransformer
|
|
39
|
+
from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator
|
|
40
|
+
from mplang.v1.core.expr.printer import Printer
|
|
41
|
+
from mplang.v1.core.expr.transformer import ExprTransformer
|
|
42
42
|
|
|
43
43
|
# Utility functions
|
|
44
|
-
from mplang.core.expr.utils import (
|
|
44
|
+
from mplang.v1.core.expr.utils import (
|
|
45
45
|
deduce_mask,
|
|
46
46
|
ensure_scalar,
|
|
47
47
|
ensure_tensorlist_equal,
|
|
@@ -49,8 +49,8 @@ from mplang.core.expr.utils import (
|
|
|
49
49
|
)
|
|
50
50
|
|
|
51
51
|
# Visitor pattern interface
|
|
52
|
-
from mplang.core.expr.visitor import ExprVisitor
|
|
53
|
-
from mplang.core.expr.walk import walk, walk_dataflow, walk_structural
|
|
52
|
+
from mplang.v1.core.expr.visitor import ExprVisitor
|
|
53
|
+
from mplang.v1.core.expr.walk import walk, walk_dataflow, walk_structural
|
|
54
54
|
|
|
55
55
|
__all__ = [
|
|
56
56
|
"AccessExpr",
|
|
@@ -26,15 +26,15 @@ import logging
|
|
|
26
26
|
from abc import ABC, abstractmethod
|
|
27
27
|
from typing import TYPE_CHECKING, Any
|
|
28
28
|
|
|
29
|
-
from mplang.core.expr.utils import deduce_mask
|
|
30
|
-
from mplang.core.mask import Mask
|
|
31
|
-
from mplang.core.mptype import MPType, Rank
|
|
32
|
-
from mplang.core.pfunc import PFunction
|
|
33
|
-
from mplang.core.table import TableType
|
|
34
|
-
from mplang.core.tensor import TensorType
|
|
29
|
+
from mplang.v1.core.expr.utils import deduce_mask
|
|
30
|
+
from mplang.v1.core.mask import Mask
|
|
31
|
+
from mplang.v1.core.mptype import MPType, Rank
|
|
32
|
+
from mplang.v1.core.pfunc import PFunction
|
|
33
|
+
from mplang.v1.core.table import TableType
|
|
34
|
+
from mplang.v1.core.tensor import TensorType
|
|
35
35
|
|
|
36
36
|
if TYPE_CHECKING:
|
|
37
|
-
from mplang.core.expr.visitor import ExprVisitor
|
|
37
|
+
from mplang.v1.core.expr.visitor import ExprVisitor
|
|
38
38
|
|
|
39
39
|
|
|
40
40
|
class Expr(ABC):
|
|
@@ -286,8 +286,8 @@ class ConvExpr(Expr):
|
|
|
286
286
|
# Validate dtype / shape consistency.
|
|
287
287
|
first = types[0]
|
|
288
288
|
for c in types[1:]:
|
|
289
|
-
if
|
|
290
|
-
raise TypeError(f"Inconsistent
|
|
289
|
+
if c.raw_type() != first.raw_type():
|
|
290
|
+
raise TypeError(f"Inconsistent type in pconv: {c} vs {first}")
|
|
291
291
|
|
|
292
292
|
# Deduce the pmask by intersecting all pmasks.
|
|
293
293
|
pmasks = [t.pmask for t in types]
|
|
@@ -316,7 +316,7 @@ class ConvExpr(Expr):
|
|
|
316
316
|
else:
|
|
317
317
|
out_pmask = None
|
|
318
318
|
|
|
319
|
-
return [MPType
|
|
319
|
+
return [MPType(first.raw_type(), out_pmask, first.attrs)]
|
|
320
320
|
|
|
321
321
|
def accept(self, visitor: ExprVisitor) -> Any:
|
|
322
322
|
return visitor.visit_conv(self)
|
|
@@ -398,9 +398,7 @@ class ShflSExpr(Expr):
|
|
|
398
398
|
def _compute_mptypes(self) -> list[MPType]:
|
|
399
399
|
# The types are the same as the source value, but with a new pmask.
|
|
400
400
|
src_type = self.src_val.mptype
|
|
401
|
-
return [
|
|
402
|
-
MPType.tensor(src_type.dtype, src_type.shape, self.pmask, **src_type.attrs)
|
|
403
|
-
]
|
|
401
|
+
return [MPType(src_type._type, self.pmask, src_type.attrs)]
|
|
404
402
|
|
|
405
403
|
def accept(self, visitor: ExprVisitor) -> Any:
|
|
406
404
|
return visitor.visit_shfl_s(self)
|
|
@@ -528,8 +526,9 @@ class FuncDefExpr(Expr):
|
|
|
528
526
|
class CallExpr(Expr):
|
|
529
527
|
"""Expression for function call."""
|
|
530
528
|
|
|
531
|
-
def __init__(self, fn: FuncDefExpr, args: list[Expr]):
|
|
529
|
+
def __init__(self, name: str, fn: FuncDefExpr, args: list[Expr]):
|
|
532
530
|
super().__init__()
|
|
531
|
+
self.name = name
|
|
533
532
|
self.fn = fn
|
|
534
533
|
self.args = args
|
|
535
534
|
|
|
@@ -27,8 +27,8 @@ from __future__ import annotations
|
|
|
27
27
|
from dataclasses import dataclass
|
|
28
28
|
from typing import Any, Protocol
|
|
29
29
|
|
|
30
|
-
from mplang.core.comm import ICommunicator
|
|
31
|
-
from mplang.core.expr.ast import (
|
|
30
|
+
from mplang.v1.core.comm import ICommunicator
|
|
31
|
+
from mplang.v1.core.expr.ast import (
|
|
32
32
|
AccessExpr,
|
|
33
33
|
CallExpr,
|
|
34
34
|
CondExpr,
|
|
@@ -42,11 +42,12 @@ from mplang.core.expr.ast import (
|
|
|
42
42
|
VariableExpr,
|
|
43
43
|
WhileExpr,
|
|
44
44
|
)
|
|
45
|
-
from mplang.core.expr.visitor import ExprVisitor
|
|
46
|
-
from mplang.core.expr.walk import walk_dataflow
|
|
47
|
-
from mplang.core.mask import Mask
|
|
48
|
-
from mplang.core.pfunc import PFunction
|
|
49
|
-
from mplang.kernels.context import RuntimeContext
|
|
45
|
+
from mplang.v1.core.expr.visitor import ExprVisitor
|
|
46
|
+
from mplang.v1.core.expr.walk import walk_dataflow
|
|
47
|
+
from mplang.v1.core.mask import Mask
|
|
48
|
+
from mplang.v1.core.pfunc import PFunction
|
|
49
|
+
from mplang.v1.kernels.context import RuntimeContext
|
|
50
|
+
from mplang.v1.kernels.value import Value
|
|
50
51
|
|
|
51
52
|
|
|
52
53
|
class IEvaluator(Protocol):
|
|
@@ -149,12 +150,12 @@ class EvalSemantic:
|
|
|
149
150
|
def _as_optional_int(val: Any) -> int | None:
|
|
150
151
|
"""Convert a value to int if possible, preserving None.
|
|
151
152
|
|
|
152
|
-
Handles Python ints, numpy
|
|
153
|
+
Handles Python ints, floats, numpy scalar types (e.g., np.int32, np.float64), and None.
|
|
154
|
+
Uses int(val) for conversion which works with numpy scalars via __int__().
|
|
153
155
|
"""
|
|
156
|
+
val = EvalSemantic._unwrap_value(val)
|
|
154
157
|
if val is None:
|
|
155
158
|
return None
|
|
156
|
-
if hasattr(val, "item"):
|
|
157
|
-
return int(val.item())
|
|
158
159
|
return int(val)
|
|
159
160
|
|
|
160
161
|
def _simple_allgather(self, value: Any) -> list[Any]:
|
|
@@ -167,6 +168,7 @@ class EvalSemantic:
|
|
|
167
168
|
Returns a list of length world_size with entries ordered by rank.
|
|
168
169
|
"""
|
|
169
170
|
ws = self.comm.world_size
|
|
171
|
+
value = self._unwrap_value(value)
|
|
170
172
|
# Trivial fast-path
|
|
171
173
|
if ws == 1:
|
|
172
174
|
return [value]
|
|
@@ -185,7 +187,12 @@ class EvalSemantic:
|
|
|
185
187
|
|
|
186
188
|
def _verify_uniform_predicate(self, pred: Any) -> None:
|
|
187
189
|
# Runtime uniformity check (O(P^2) send/recv emulation).
|
|
188
|
-
|
|
190
|
+
# Use Value.to_bool() if available, otherwise unwrap and convert
|
|
191
|
+
if isinstance(pred, Value):
|
|
192
|
+
pred_bool = pred.to_bool()
|
|
193
|
+
else:
|
|
194
|
+
pred_bool = bool(self._unwrap_value(pred))
|
|
195
|
+
vals = self._simple_allgather(pred_bool)
|
|
189
196
|
if not vals:
|
|
190
197
|
raise ValueError("uniform_cond: empty gather for predicate")
|
|
191
198
|
first = vals[0]
|
|
@@ -209,13 +216,36 @@ class EvalSemantic:
|
|
|
209
216
|
assert len(cond_result) == 1, (
|
|
210
217
|
f"Condition function must return a single value, got {cond_result}"
|
|
211
218
|
)
|
|
212
|
-
|
|
213
|
-
if
|
|
219
|
+
cond_val = cond_result[0]
|
|
220
|
+
if cond_val is None:
|
|
214
221
|
raise RuntimeError(
|
|
215
222
|
"while_loop condition produced None on rank "
|
|
216
223
|
f"{self.rank}; ensure the predicate yields a boolean for every party."
|
|
217
224
|
)
|
|
218
|
-
|
|
225
|
+
# Use Value.to_bool() if available for cleaner conversion
|
|
226
|
+
if isinstance(cond_val, Value):
|
|
227
|
+
return cond_val.to_bool()
|
|
228
|
+
return bool(self._unwrap_value(cond_val))
|
|
229
|
+
|
|
230
|
+
@staticmethod
|
|
231
|
+
def _unwrap_value(value: Any) -> Any:
|
|
232
|
+
"""Convert Value payloads to numpy/python equivalents when possible."""
|
|
233
|
+
if value is None:
|
|
234
|
+
return None
|
|
235
|
+
|
|
236
|
+
if isinstance(value, Value):
|
|
237
|
+
# Try to_numpy first for broader compatibility
|
|
238
|
+
to_numpy = getattr(value, "to_numpy", None)
|
|
239
|
+
if callable(to_numpy):
|
|
240
|
+
arr = to_numpy()
|
|
241
|
+
import numpy as np
|
|
242
|
+
|
|
243
|
+
if isinstance(arr, np.ndarray):
|
|
244
|
+
if arr.size == 1:
|
|
245
|
+
return arr.item()
|
|
246
|
+
return arr
|
|
247
|
+
return arr
|
|
248
|
+
return value
|
|
219
249
|
|
|
220
250
|
|
|
221
251
|
class RecursiveEvaluator(EvalSemantic, ExprVisitor):
|
|
@@ -296,19 +326,25 @@ class RecursiveEvaluator(EvalSemantic, ExprVisitor):
|
|
|
296
326
|
* Add optional static uniform inference (data provenance) to elide the
|
|
297
327
|
runtime check when predicate uniformity is provable at trace time.
|
|
298
328
|
"""
|
|
299
|
-
|
|
300
|
-
if
|
|
329
|
+
pred_val = self._value(expr.pred)
|
|
330
|
+
if pred_val is None:
|
|
301
331
|
return [None] * len(expr.mptypes)
|
|
302
332
|
|
|
303
333
|
if expr.verify_uniform:
|
|
304
|
-
self._verify_uniform_predicate(
|
|
334
|
+
self._verify_uniform_predicate(pred_val)
|
|
335
|
+
|
|
336
|
+
# Convert to bool using Value.to_bool() if available
|
|
337
|
+
if isinstance(pred_val, Value):
|
|
338
|
+
pred = pred_val.to_bool()
|
|
339
|
+
else:
|
|
340
|
+
pred = bool(self._unwrap_value(pred_val))
|
|
305
341
|
|
|
306
342
|
# Only evaluate selected branch locally
|
|
307
|
-
if pred:
|
|
308
|
-
then_call = CallExpr(expr.then_fn, expr.args)
|
|
343
|
+
if bool(pred):
|
|
344
|
+
then_call = CallExpr("then", expr.then_fn, expr.args)
|
|
309
345
|
return self._values(then_call)
|
|
310
346
|
else:
|
|
311
|
-
else_call = CallExpr(expr.else_fn, expr.args)
|
|
347
|
+
else_call = CallExpr("else", expr.else_fn, expr.args)
|
|
312
348
|
return self._values(else_call)
|
|
313
349
|
|
|
314
350
|
def visit_call(self, expr: CallExpr) -> Any:
|
|
@@ -435,15 +471,20 @@ class IterativeEvaluator(EvalSemantic):
|
|
|
435
471
|
res = self._iter_eval_graph(node.fn.body, {**env, **sub_env})
|
|
436
472
|
symbols[id(node)] = res
|
|
437
473
|
elif isinstance(node, CondExpr):
|
|
438
|
-
|
|
474
|
+
pred_val = self._first(symbols[id(node.pred)])
|
|
439
475
|
arg_vals = [self._first(symbols[id(a)]) for a in node.args]
|
|
440
|
-
if
|
|
476
|
+
if pred_val is None:
|
|
441
477
|
symbols[id(node)] = [None] * len(node.mptypes)
|
|
442
478
|
else:
|
|
443
479
|
# Optional uniform verification identical to recursive evaluator (DRY helper).
|
|
444
480
|
if node.verify_uniform:
|
|
445
|
-
self._verify_uniform_predicate(
|
|
446
|
-
|
|
481
|
+
self._verify_uniform_predicate(pred_val)
|
|
482
|
+
# Convert to bool using Value.to_bool() if available
|
|
483
|
+
if isinstance(pred_val, Value):
|
|
484
|
+
pred = pred_val.to_bool()
|
|
485
|
+
else:
|
|
486
|
+
pred = bool(self._unwrap_value(pred_val))
|
|
487
|
+
if pred:
|
|
447
488
|
sub_env = dict(zip(node.then_fn.params, arg_vals, strict=True))
|
|
448
489
|
res = self._iter_eval_graph(
|
|
449
490
|
node.then_fn.body, {**env, **sub_env}
|