mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -28,10 +28,11 @@ from typing import Any, ParamSpec, TypeVar, cast
|
|
|
28
28
|
|
|
29
29
|
from jax.tree_util import tree_map
|
|
30
30
|
|
|
31
|
-
from mplang.core.context_mgr import cur_ctx
|
|
32
|
-
from mplang.core.
|
|
33
|
-
from mplang.core.expr.ast import (
|
|
31
|
+
from mplang.v1.core.context_mgr import cur_ctx
|
|
32
|
+
from mplang.v1.core.dtypes import BOOL
|
|
33
|
+
from mplang.v1.core.expr.ast import (
|
|
34
34
|
AccessExpr,
|
|
35
|
+
CallExpr,
|
|
35
36
|
CondExpr,
|
|
36
37
|
ConvExpr,
|
|
37
38
|
EvalExpr,
|
|
@@ -39,16 +40,13 @@ from mplang.core.expr.ast import (
|
|
|
39
40
|
ShflSExpr,
|
|
40
41
|
WhileExpr,
|
|
41
42
|
)
|
|
42
|
-
from mplang.core.interp import InterpContext, InterpVar, apply
|
|
43
|
-
from mplang.core.mask import Mask
|
|
44
|
-
from mplang.core.mpobject import MPContext, MPObject
|
|
45
|
-
from mplang.core.mptype import Rank
|
|
46
|
-
from mplang.core.pfunc import PFunction
|
|
47
|
-
from mplang.core.
|
|
48
|
-
from mplang.
|
|
49
|
-
from mplang.core.tracer import TraceContext, TraceVar, trace
|
|
50
|
-
from mplang.ops import builtin
|
|
51
|
-
from mplang.utils.func_utils import var_demorph, var_morph
|
|
43
|
+
from mplang.v1.core.interp import InterpContext, InterpVar, apply
|
|
44
|
+
from mplang.v1.core.mask import Mask
|
|
45
|
+
from mplang.v1.core.mpobject import MPContext, MPObject
|
|
46
|
+
from mplang.v1.core.mptype import Rank
|
|
47
|
+
from mplang.v1.core.pfunc import PFunction
|
|
48
|
+
from mplang.v1.core.tracer import TraceContext, TraceVar, trace
|
|
49
|
+
from mplang.v1.utils.func_utils import var_demorph, var_morph
|
|
52
50
|
|
|
53
51
|
|
|
54
52
|
def _switch_ctx(ctx: MPContext, obj: MPObject | Any) -> MPObject | Any:
|
|
@@ -87,30 +85,106 @@ P = ParamSpec("P")
|
|
|
87
85
|
R = TypeVar("R")
|
|
88
86
|
|
|
89
87
|
|
|
90
|
-
def
|
|
88
|
+
def trace_before_apply(fn: Callable[P, R], make_call: bool) -> Callable[P, R]:
|
|
91
89
|
"""A decorator to make all primitive call in trace context."""
|
|
92
90
|
|
|
93
91
|
@wraps(fn)
|
|
94
92
|
def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
|
|
95
93
|
current_ctx = cur_ctx()
|
|
96
94
|
if isinstance(current_ctx, TraceContext):
|
|
97
|
-
# If we are in a tracer context
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
95
|
+
# If we are already in a tracer context
|
|
96
|
+
if make_call:
|
|
97
|
+
# make a primitive call
|
|
98
|
+
tracer = current_ctx
|
|
99
|
+
tfn = trace(tracer.fork(), fn, *args, **kwargs)
|
|
100
|
+
is_mpobj = lambda x: isinstance(x, MPObject)
|
|
101
|
+
in_vars, in_imms, in_struct = var_morph((args, kwargs), is_mpobj)
|
|
102
|
+
assert in_struct == tfn.in_struct and in_imms == tfn.in_imms
|
|
103
|
+
arg_exprs = [arg.expr for arg in in_vars]
|
|
104
|
+
# re-capture all captured variables into current context if needed.
|
|
105
|
+
cap_exprs = [tracer.capture(var).expr for var in tfn.capture_map.keys()]
|
|
106
|
+
caller_expr = CallExpr(
|
|
107
|
+
name=fn.__name__, fn=tfn.make_expr(), args=arg_exprs + cap_exprs
|
|
108
|
+
)
|
|
109
|
+
out_vars = [
|
|
110
|
+
TraceVar(tracer, AccessExpr(caller_expr, idx))
|
|
111
|
+
for idx in range(caller_expr.num_outputs)
|
|
112
|
+
]
|
|
113
|
+
return cast(R, var_demorph(out_vars, tfn.out_imms, tfn.out_struct))
|
|
114
|
+
else:
|
|
115
|
+
# embed the function call in the current tracer context
|
|
116
|
+
# Note: switch_ctx will do the capture if needed.
|
|
117
|
+
args, kwargs = tree_map(
|
|
118
|
+
partial(_switch_ctx, current_ctx), (args, kwargs)
|
|
119
|
+
)
|
|
120
|
+
return fn(*args, **kwargs)
|
|
101
121
|
elif isinstance(current_ctx, InterpContext):
|
|
102
122
|
trace_ctx = TraceContext(current_ctx.cluster_spec, parent=current_ctx)
|
|
103
123
|
# TODO(jint): should we add trace_and_apply to improve the performance?
|
|
104
|
-
|
|
124
|
+
tfn = trace(trace_ctx, fn, *args, **kwargs)
|
|
105
125
|
# Return back to the original context.
|
|
106
|
-
return cast(R, apply(current_ctx,
|
|
126
|
+
return cast(R, apply(current_ctx, tfn, *args, **kwargs))
|
|
107
127
|
else:
|
|
108
128
|
raise ValueError(f"Unsupported context type: {type(current_ctx)}")
|
|
109
129
|
|
|
110
130
|
return wrapped
|
|
111
131
|
|
|
112
132
|
|
|
113
|
-
|
|
133
|
+
def builtin_function(fn: Callable[P, R]) -> Callable[P, R]:
|
|
134
|
+
"""Decorator to trace a Python function as an opaque primitive call (`CallExpr`).
|
|
135
|
+
|
|
136
|
+
When a function decorated with `@builtin_function` is called within a `TraceContext`, it is
|
|
137
|
+
not inlined. Instead, it is traced separately in a forked context, and a `CallExpr`
|
|
138
|
+
node is inserted into the main graph. This is useful for encapsulating complex
|
|
139
|
+
operations or third-party library calls as single, opaque nodes.
|
|
140
|
+
|
|
141
|
+
**Implementation Note:**
|
|
142
|
+
A `CallExpr` represents a call to a single inline lambda (non-recursive, as we don't
|
|
143
|
+
have Y-combinator support). This single lambda call can be treated as a "primitive call"
|
|
144
|
+
by the printer/visualizer - hence the name "primitive". The function body is captured
|
|
145
|
+
once during tracing and represented as an opaque callable unit in the expression graph,
|
|
146
|
+
maintaining a clear boundary between the caller and callee contexts.
|
|
147
|
+
|
|
148
|
+
Args:
|
|
149
|
+
fn: The function to be traced as a primitive operation.
|
|
150
|
+
|
|
151
|
+
Returns:
|
|
152
|
+
A wrapped function that creates a `CallExpr` node when called in a trace context.
|
|
153
|
+
|
|
154
|
+
Example:
|
|
155
|
+
```python
|
|
156
|
+
@builtin_function
|
|
157
|
+
def my_op(x: MPObject) -> MPObject:
|
|
158
|
+
# Complex logic traced as a single CallExpr node
|
|
159
|
+
return x + 1
|
|
160
|
+
```
|
|
161
|
+
"""
|
|
162
|
+
return trace_before_apply(fn, make_call=True)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def function(fn: Callable[P, R]) -> Callable[P, R]:
|
|
166
|
+
"""Decorator to trace a Python function by inlining its body.
|
|
167
|
+
|
|
168
|
+
When a function decorated with `@function` is called within a `TraceContext`, its
|
|
169
|
+
underlying primitive operations are expanded and inserted directly into the caller's
|
|
170
|
+
graph. This is the default tracing behavior and is suitable for most pure-Python
|
|
171
|
+
multi-party functions.
|
|
172
|
+
|
|
173
|
+
Args:
|
|
174
|
+
fn: The function to be traced and inlined.
|
|
175
|
+
|
|
176
|
+
Returns:
|
|
177
|
+
A wrapped function that inlines its operations into the caller's trace context.
|
|
178
|
+
|
|
179
|
+
Example:
|
|
180
|
+
```python
|
|
181
|
+
@function
|
|
182
|
+
def my_func(x: MPObject, y: MPObject) -> MPObject:
|
|
183
|
+
# Operations are inlined into the caller's trace
|
|
184
|
+
return x + y * constant(2)
|
|
185
|
+
```
|
|
186
|
+
"""
|
|
187
|
+
return trace_before_apply(fn, make_call=False)
|
|
114
188
|
|
|
115
189
|
|
|
116
190
|
# ============================================================================
|
|
@@ -126,18 +200,15 @@ def _tracer() -> TraceContext:
|
|
|
126
200
|
return ctx
|
|
127
201
|
|
|
128
202
|
|
|
129
|
-
@primitive
|
|
130
203
|
def psize() -> int:
|
|
131
204
|
"""Get the size of the current party world.
|
|
132
205
|
|
|
133
206
|
Returns:
|
|
134
207
|
int: The total number of parties in the current multi-party computation context.
|
|
135
208
|
"""
|
|
136
|
-
|
|
137
|
-
return ctx.world_size()
|
|
209
|
+
return cur_ctx().world_size()
|
|
138
210
|
|
|
139
211
|
|
|
140
|
-
@primitive
|
|
141
212
|
def pmask() -> Mask:
|
|
142
213
|
"""Get the current party mask in this computation context.
|
|
143
214
|
|
|
@@ -145,112 +216,10 @@ def pmask() -> Mask:
|
|
|
145
216
|
Mask: The current party mask indicating which parties are active
|
|
146
217
|
in the current computation context.
|
|
147
218
|
"""
|
|
148
|
-
|
|
149
|
-
return ctx.mask
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
@primitive
|
|
153
|
-
def prank() -> MPObject:
|
|
154
|
-
"""Multi-party get the rank (party identifier) of each party.
|
|
155
|
-
|
|
156
|
-
This function returns a scalar tensor containing the rank (party identifier)
|
|
157
|
-
for each party in the current party mask. Each party independently produces
|
|
158
|
-
its own rank value, which serves as a unique identifier within the multi-party
|
|
159
|
-
computation context.
|
|
160
|
-
|
|
161
|
-
The rank values range from 0 to world_size-1, where world_size is the total
|
|
162
|
-
number of parties in the computation. Each party's rank is private to that
|
|
163
|
-
party and represents its position in the multi-party protocol.
|
|
164
|
-
|
|
165
|
-
Returns:
|
|
166
|
-
MPObject: A variable representing a scalar tensor with:
|
|
167
|
-
- dtype: UINT64
|
|
168
|
-
- shape: () (scalar)
|
|
169
|
-
|
|
170
|
-
Note:
|
|
171
|
-
Each party in the current party mask independently produces its own rank value.
|
|
172
|
-
"""
|
|
173
|
-
pfunc, eval_args, out_tree = builtin.rank()
|
|
174
|
-
results = peval(pfunc, eval_args)
|
|
175
|
-
return out_tree.unflatten(results) # type: ignore[no-any-return]
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
@primitive
|
|
179
|
-
def prand(shape: Shape = ()) -> MPObject:
|
|
180
|
-
"""Multi-party generate a private random (uint64) tensor with the given shape.
|
|
181
|
-
|
|
182
|
-
This function creates a private random tensor where each party independently
|
|
183
|
-
generates its own local random values. Each party's random values are private
|
|
184
|
-
and unknown to other parties. The output tensor contains 64-bit unsigned
|
|
185
|
-
integers, with each party holding its own privately generated values.
|
|
186
|
-
|
|
187
|
-
Args:
|
|
188
|
-
shape: The shape of the random tensor to generate.
|
|
189
|
-
Must be a tuple of positive integers. Defaults to () for scalar.
|
|
190
|
-
|
|
191
|
-
Returns:
|
|
192
|
-
MPObject: A variable representing the generated private random tensor with:
|
|
193
|
-
- dtype: UINT64
|
|
194
|
-
- shape: As specified by the shape parameter
|
|
195
|
-
|
|
196
|
-
Note:
|
|
197
|
-
Each party in the current party mask independently generates its own
|
|
198
|
-
private random values. The randomness is local to each party and is
|
|
199
|
-
not shared or revealed to other parties.
|
|
200
|
-
"""
|
|
201
|
-
pfunc, eval_args, out_tree = builtin.prand(shape)
|
|
202
|
-
results = peval(pfunc, eval_args)
|
|
203
|
-
return out_tree.unflatten(results) # type: ignore[no-any-return]
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
@primitive
|
|
207
|
-
def constant(data: TensorLike | ScalarType | TableLike) -> MPObject:
|
|
208
|
-
"""Create a constant tensor or table from data.
|
|
209
|
-
|
|
210
|
-
This function creates a constant that can be used in multi-party
|
|
211
|
-
computations. The constant value is embedded directly into the computation
|
|
212
|
-
graph and is available to all parties in the current party mask.
|
|
213
|
-
|
|
214
|
-
Args:
|
|
215
|
-
data: The constant data to embed. Can be:
|
|
216
|
-
- A scalar value (int, float, bool)
|
|
217
|
-
- A numpy array or other tensor-like object
|
|
218
|
-
- A pandas DataFrame or other table-like object
|
|
219
|
-
- Any object that can be converted to tensor
|
|
220
|
-
|
|
221
|
-
Returns:
|
|
222
|
-
MPObject: A variable representing the constant tensor or table with:
|
|
223
|
-
- dtype: Inferred from the input data
|
|
224
|
-
- shape: Inferred from the input data (for tensors)
|
|
225
|
-
- schema: Inferred from the input data (for tables)
|
|
226
|
-
- data: The embedded constant values
|
|
227
|
-
|
|
228
|
-
Note:
|
|
229
|
-
The constant data is embedded at graph construction time and is available
|
|
230
|
-
to all parties during execution. Large constants may impact graph size.
|
|
231
|
-
|
|
232
|
-
For table-like objects (e.g., pandas DataFrame), JSON serialization is used.
|
|
233
|
-
Note that the constant primitive is not designed to carry large tables efficiently -
|
|
234
|
-
consider using dedicated table loading mechanisms for substantial datasets.
|
|
235
|
-
"""
|
|
236
|
-
pfunc, eval_args, out_tree = builtin.constant(data)
|
|
237
|
-
results = peval(pfunc, eval_args)
|
|
238
|
-
return out_tree.unflatten(results) # type: ignore[no-any-return]
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
@primitive
|
|
242
|
-
def debug_print(obj: MPObject, prefix: str = "") -> MPObject:
|
|
243
|
-
"""Print local value of obj on owning parties and pass it through.
|
|
244
|
-
|
|
245
|
-
Returns the same MPObject value to keep it alive against DCE and to
|
|
246
|
-
support usage like: x = debug_print(x, prefix="x=").
|
|
247
|
-
"""
|
|
248
|
-
pfunc, eval_args, out_tree = builtin.debug_print(obj, prefix=prefix)
|
|
249
|
-
results = peval(pfunc, eval_args)
|
|
250
|
-
return out_tree.unflatten(results) # type: ignore[no-any-return]
|
|
219
|
+
return _tracer().mask
|
|
251
220
|
|
|
252
221
|
|
|
253
|
-
@
|
|
222
|
+
@function
|
|
254
223
|
def peval(
|
|
255
224
|
pfunc: PFunction,
|
|
256
225
|
args: list[MPObject],
|
|
@@ -314,71 +283,7 @@ def peval(
|
|
|
314
283
|
return [TraceVar(ctx, res) for res in ret_exprs]
|
|
315
284
|
|
|
316
285
|
|
|
317
|
-
|
|
318
|
-
"""Set the mask of an MPObject to a new value.
|
|
319
|
-
|
|
320
|
-
This function allows changing the party mask of an existing MPObject variable.
|
|
321
|
-
The behavior depends on whether the input MPObject has a dynamic or static pmask:
|
|
322
|
-
|
|
323
|
-
**Case 1: Dynamic pmask (arg.pmask is None)**
|
|
324
|
-
- The input MPObject has a runtime-determined pmask
|
|
325
|
-
- The return value's pmask will be exactly the specified mask
|
|
326
|
-
- No validation is performed at compile time
|
|
327
|
-
|
|
328
|
-
**Case 2: Static pmask (arg.pmask is not None)**
|
|
329
|
-
- If mask is a subset of arg.pmask: return_var.pmask == arg.pmask (unchanged)
|
|
330
|
-
- If mask is NOT a subset of arg.pmask: raises ValueError at compile time
|
|
331
|
-
|
|
332
|
-
Args:
|
|
333
|
-
arg: The MPObject whose mask needs to be changed.
|
|
334
|
-
mask: The target mask to apply. Must be a valid party mask.
|
|
335
|
-
|
|
336
|
-
Returns:
|
|
337
|
-
MPObject: A new variable with the specified mask behavior:
|
|
338
|
-
- For dynamic inputs: pmask = mask
|
|
339
|
-
- For static inputs (valid subset): pmask = arg.pmask
|
|
340
|
-
|
|
341
|
-
Raises:
|
|
342
|
-
ValueError: When arg has a static pmask and mask is not a subset of arg.pmask.
|
|
343
|
-
This validation occurs at compile time during graph construction.
|
|
344
|
-
|
|
345
|
-
Examples:
|
|
346
|
-
**Example 1: Dynamic pmask - mask assignment**
|
|
347
|
-
P0 P1 P2
|
|
348
|
-
-- -- --
|
|
349
|
-
Input: ? ? ? (pmask=None, runtime-determined)
|
|
350
|
-
mask: [0,2] (target mask)
|
|
351
|
-
-----------------------------------------------------------
|
|
352
|
-
Output: x0 - x2 (pmask=[0,2])
|
|
353
|
-
|
|
354
|
-
**Example 2: Static pmask - valid subset**
|
|
355
|
-
P0 P1 P2
|
|
356
|
-
-- -- --
|
|
357
|
-
Input: x0 x1 x2 (pmask=[0,1,2])
|
|
358
|
-
mask: [0,2] (subset of input pmask)
|
|
359
|
-
-----------------------------------------------------------
|
|
360
|
-
Output: x0 - x2 (pmask=[0,2])
|
|
361
|
-
|
|
362
|
-
**Example 3: Static pmask - invalid subset (compile error)**
|
|
363
|
-
P0 P1 P2
|
|
364
|
-
-- -- --
|
|
365
|
-
Input: x0 - x2 (pmask=[0,2])
|
|
366
|
-
mask: [1,2] (NOT subset of [0,2])
|
|
367
|
-
-----------------------------------------------------------
|
|
368
|
-
Result: ValueError at compile time
|
|
369
|
-
|
|
370
|
-
Note:
|
|
371
|
-
This function is typically used for constraining the execution scope
|
|
372
|
-
of variables or for type casting between different pmask contexts.
|
|
373
|
-
The underlying implementation uses JAX identity function with the
|
|
374
|
-
specified execution mask.
|
|
375
|
-
"""
|
|
376
|
-
pfunc, eval_args, out_tree = builtin.identity(arg)
|
|
377
|
-
results = peval(pfunc, eval_args, mask)
|
|
378
|
-
return out_tree.unflatten(results) # type: ignore[no-any-return]
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
@primitive
|
|
286
|
+
@function
|
|
382
287
|
def uniform_cond(
|
|
383
288
|
pred: MPObject,
|
|
384
289
|
then_fn: Callable[..., Any],
|
|
@@ -393,7 +298,7 @@ def uniform_cond(
|
|
|
393
298
|
|
|
394
299
|
1. ``pred`` is a boolean scalar whose runtime value is identical for every enabled party.
|
|
395
300
|
2. At least one branch contains multi-party primitives (``seal`` / ``reveal`` /
|
|
396
|
-
``
|
|
301
|
+
``srun_jax`` / ``pshfl`` / mask transformations) whose cost or side-effects you
|
|
397
302
|
want to avoid if the branch is not taken.
|
|
398
303
|
3. You require the semantic guarantee that the *non-selected* branch does **not**
|
|
399
304
|
perform communication, allocate intermediate buffers, or leak timing/side-effects.
|
|
@@ -588,7 +493,7 @@ def uniform_cond(
|
|
|
588
493
|
return var_demorph(out_vars, then_tfn.out_imms, then_tfn.out_struct) # type: ignore[no-any-return]
|
|
589
494
|
|
|
590
495
|
|
|
591
|
-
@
|
|
496
|
+
@function
|
|
592
497
|
def while_loop(
|
|
593
498
|
cond_fn: Callable[[Any], MPObject],
|
|
594
499
|
body_fn: Callable[[Any], Any],
|
|
@@ -654,7 +559,7 @@ def while_loop(
|
|
|
654
559
|
secret-shared reduction).
|
|
655
560
|
|
|
656
561
|
cond_fn::
|
|
657
|
-
sealed_sum = smpc.reveal(smpc.
|
|
562
|
+
sealed_sum = smpc.reveal(smpc.srun_jax(lambda x: jnp.sum(x), smpc.seal(x)))
|
|
658
563
|
return sealed_sum < constant(10)
|
|
659
564
|
|
|
660
565
|
body_fn::
|
|
@@ -781,7 +686,7 @@ def while_loop(
|
|
|
781
686
|
return var_demorph(out_vars, body_tfn.out_imms, body_tfn.out_struct)
|
|
782
687
|
|
|
783
688
|
|
|
784
|
-
@
|
|
689
|
+
@function
|
|
785
690
|
def pshfl(src: MPObject, index: MPObject) -> MPObject:
|
|
786
691
|
"""Shuffle the input tensor to the specified index (dynamic version).
|
|
787
692
|
|
|
@@ -813,7 +718,7 @@ def pshfl(src: MPObject, index: MPObject) -> MPObject:
|
|
|
813
718
|
|
|
814
719
|
Raises:
|
|
815
720
|
ValueError: If the index tensor is not a scalar.
|
|
816
|
-
RuntimeError: If src[index[i]] is None for any valid index[i] (i.e
|
|
721
|
+
RuntimeError: If src[index[i]] is None for any valid index[i] (i.e.,
|
|
817
722
|
trying to fetch from a party that doesn't hold the data).
|
|
818
723
|
|
|
819
724
|
Examples:
|
|
@@ -851,7 +756,7 @@ def pshfl(src: MPObject, index: MPObject) -> MPObject:
|
|
|
851
756
|
return TraceVar(_tracer(), shfl_expr)
|
|
852
757
|
|
|
853
758
|
|
|
854
|
-
@
|
|
759
|
+
@function
|
|
855
760
|
def pshfl_s(src_val: MPObject, pmask: Mask, src_ranks: list[Rank]) -> MPObject:
|
|
856
761
|
"""Shuffle the input tensor to the specified rank, static version.
|
|
857
762
|
|
|
@@ -910,7 +815,7 @@ def pshfl_s(src_val: MPObject, pmask: Mask, src_ranks: list[Rank]) -> MPObject:
|
|
|
910
815
|
return TraceVar(_tracer(), shfl_s_expr)
|
|
911
816
|
|
|
912
817
|
|
|
913
|
-
@
|
|
818
|
+
@function
|
|
914
819
|
def pconv(vars: list[MPObject]) -> MPObject:
|
|
915
820
|
"""Combine multiple variables that share the same dtype and shape into one.
|
|
916
821
|
|
mplang/{core → v1/core}/table.py
RENAMED
|
@@ -18,16 +18,16 @@ from collections.abc import Iterator
|
|
|
18
18
|
from dataclasses import dataclass, field
|
|
19
19
|
from typing import Any, Protocol, runtime_checkable
|
|
20
20
|
|
|
21
|
-
from mplang.core.
|
|
21
|
+
from mplang.v1.core.dtypes import DType
|
|
22
22
|
|
|
23
23
|
__all__ = ["TableLike", "TableType"]
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
@runtime_checkable
|
|
27
|
-
class
|
|
27
|
+
class PandasTableLike(Protocol):
|
|
28
28
|
"""
|
|
29
29
|
Protocol for objects structurally resembling tables from common libraries
|
|
30
|
-
(pandas DataFrame,
|
|
30
|
+
(pandas DataFrame, polars DataFrame, etc.), focusing on dtypes and columns attributes.
|
|
31
31
|
"""
|
|
32
32
|
|
|
33
33
|
@property
|
|
@@ -37,6 +37,26 @@ class TableLike(Protocol):
|
|
|
37
37
|
def columns(self) -> Any: ...
|
|
38
38
|
|
|
39
39
|
|
|
40
|
+
@runtime_checkable
|
|
41
|
+
class ArrowSchema(Protocol):
|
|
42
|
+
@property
|
|
43
|
+
def names(self) -> list[str]: ...
|
|
44
|
+
@property
|
|
45
|
+
def types(self) -> list[Any]: ...
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@runtime_checkable
|
|
49
|
+
class ArrowTableLike(Protocol):
|
|
50
|
+
@property
|
|
51
|
+
def column_names(self) -> list[str]: ...
|
|
52
|
+
|
|
53
|
+
@property
|
|
54
|
+
def schema(self) -> ArrowSchema: ...
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
TableLike = PandasTableLike | ArrowTableLike
|
|
58
|
+
|
|
59
|
+
|
|
40
60
|
@dataclass(frozen=True)
|
|
41
61
|
class TableType:
|
|
42
62
|
"""Table schema: ordered list of column name-type pairs.
|
|
@@ -109,11 +129,19 @@ class TableType:
|
|
|
109
129
|
Returns:
|
|
110
130
|
TableType instance
|
|
111
131
|
"""
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
132
|
+
if isinstance(table, PandasTableLike):
|
|
133
|
+
columns = [
|
|
134
|
+
(name, DType.from_any(dtype))
|
|
135
|
+
for name, dtype in zip(table.columns, table.dtypes, strict=True)
|
|
136
|
+
]
|
|
137
|
+
return cls(tuple(columns))
|
|
138
|
+
elif isinstance(table, ArrowTableLike):
|
|
139
|
+
schema = table.schema
|
|
140
|
+
columns = [
|
|
141
|
+
(name, DType.from_any(dtype))
|
|
142
|
+
for name, dtype in zip(schema.names, schema.types, strict=True)
|
|
143
|
+
]
|
|
144
|
+
return cls(tuple(columns))
|
|
117
145
|
|
|
118
146
|
def column_names(self) -> tuple[str, ...]:
|
|
119
147
|
"""Get all column names."""
|
|
@@ -60,15 +60,15 @@ from collections.abc import Callable
|
|
|
60
60
|
from dataclasses import dataclass
|
|
61
61
|
from typing import Any, cast
|
|
62
62
|
|
|
63
|
-
from mplang.core.cluster import ClusterSpec
|
|
64
|
-
from mplang.core.context_mgr import with_ctx
|
|
65
|
-
from mplang.core.expr.ast import Expr, FuncDefExpr, TupleExpr, VariableExpr
|
|
66
|
-
from mplang.core.expr.printer import Printer
|
|
67
|
-
from mplang.core.mask import Mask
|
|
68
|
-
from mplang.core.mpobject import MPContext, MPObject
|
|
69
|
-
from mplang.core.mptype import MPType
|
|
70
|
-
from mplang.core.pfunc import get_fn_name
|
|
71
|
-
from mplang.utils.func_utils import MorphStruct, var_demorph, var_morph
|
|
63
|
+
from mplang.v1.core.cluster import ClusterSpec
|
|
64
|
+
from mplang.v1.core.context_mgr import with_ctx
|
|
65
|
+
from mplang.v1.core.expr.ast import Expr, FuncDefExpr, TupleExpr, VariableExpr
|
|
66
|
+
from mplang.v1.core.expr.printer import Printer
|
|
67
|
+
from mplang.v1.core.mask import Mask
|
|
68
|
+
from mplang.v1.core.mpobject import MPContext, MPObject
|
|
69
|
+
from mplang.v1.core.mptype import MPType
|
|
70
|
+
from mplang.v1.core.pfunc import get_fn_name
|
|
71
|
+
from mplang.v1.utils.func_utils import MorphStruct, var_demorph, var_morph
|
|
72
72
|
|
|
73
73
|
|
|
74
74
|
class VarNamer:
|
mplang/{api.py → v1/host.py}
RENAMED
|
@@ -19,7 +19,8 @@ from typing import Any
|
|
|
19
19
|
|
|
20
20
|
from jax.tree_util import tree_map
|
|
21
21
|
|
|
22
|
-
from mplang.core import (
|
|
22
|
+
from mplang.v1.core import (
|
|
23
|
+
ClusterSpec,
|
|
23
24
|
InterpContext,
|
|
24
25
|
MPContext,
|
|
25
26
|
MPObject,
|
|
@@ -27,8 +28,7 @@ from mplang.core import (
|
|
|
27
28
|
TracedFunction,
|
|
28
29
|
trace,
|
|
29
30
|
)
|
|
30
|
-
from mplang.core.
|
|
31
|
-
from mplang.core.context_mgr import cur_ctx, with_ctx
|
|
31
|
+
from mplang.v1.core.context_mgr import cur_ctx, with_ctx
|
|
32
32
|
|
|
33
33
|
|
|
34
34
|
def evaluate(
|
|
@@ -38,6 +38,16 @@ def evaluate(
|
|
|
38
38
|
|
|
39
39
|
This function accepts arbitrary types as it's designed to handle
|
|
40
40
|
any multi-party computation function and arguments.
|
|
41
|
+
|
|
42
|
+
Args:
|
|
43
|
+
interp: The interpreter context for evaluating the multi-party function.
|
|
44
|
+
mpfn: The multi-party function to evaluate.
|
|
45
|
+
*args: Positional arguments to pass to the function.
|
|
46
|
+
**kwargs: Keyword arguments to pass to the function.
|
|
47
|
+
|
|
48
|
+
Returns:
|
|
49
|
+
Any: The result of evaluating the multi-party function, which can be
|
|
50
|
+
any type depending on the function's return type.
|
|
41
51
|
"""
|
|
42
52
|
assert isinstance(interp, InterpContext), f"Expect InterpContext, got {interp}"
|
|
43
53
|
with with_ctx(interp):
|
|
@@ -49,6 +59,16 @@ def fetch(interp: InterpContext | None, objs: Any) -> Any: # type: ignore[misc]
|
|
|
49
59
|
|
|
50
60
|
This function uses tree_map to handle arbitrary nested structures,
|
|
51
61
|
so it needs to accept and return Any type.
|
|
62
|
+
|
|
63
|
+
Args:
|
|
64
|
+
interp: The interpreter context for fetching results. If None, uses the
|
|
65
|
+
current context from cur_ctx().
|
|
66
|
+
objs: The objects containing MPObject instances to fetch. Can be any
|
|
67
|
+
nested structure.
|
|
68
|
+
|
|
69
|
+
Returns:
|
|
70
|
+
Any: The fetched results with the same structure as the input objects,
|
|
71
|
+
but with MPObject instances replaced by their computed values.
|
|
52
72
|
"""
|
|
53
73
|
ctx = interp or cur_ctx()
|
|
54
74
|
assert isinstance(ctx, InterpContext), f"Expect MPExecutor, got {ctx}"
|
|
@@ -56,11 +76,11 @@ def fetch(interp: InterpContext | None, objs: Any) -> Any: # type: ignore[misc]
|
|
|
56
76
|
evaluated = evaluate(ctx, lambda x: x, objs)
|
|
57
77
|
|
|
58
78
|
def fetch_impl(arg: MPObject | Any) -> Any:
|
|
59
|
-
if isinstance(arg, MPObject):
|
|
60
|
-
return ctx.fetch(arg)
|
|
61
|
-
else:
|
|
79
|
+
if not isinstance(arg, MPObject):
|
|
62
80
|
return arg
|
|
63
81
|
|
|
82
|
+
return ctx.fetch(arg)
|
|
83
|
+
|
|
64
84
|
return tree_map(fetch_impl, evaluated)
|
|
65
85
|
|
|
66
86
|
|
|
@@ -94,5 +114,17 @@ class CompileOptions(MPContext):
|
|
|
94
114
|
def compile(
|
|
95
115
|
mctx: MPContext, fn: Callable[..., Any], *args: Any, **kwargs: Any
|
|
96
116
|
) -> TracedFunction:
|
|
117
|
+
"""Compile a multi-party function into a TracedFunction.
|
|
118
|
+
|
|
119
|
+
Args:
|
|
120
|
+
mctx: The multi-party context for compilation.
|
|
121
|
+
fn: The function to compile.
|
|
122
|
+
*args: Positional arguments to pass during compilation.
|
|
123
|
+
**kwargs: Keyword arguments to pass during compilation.
|
|
124
|
+
|
|
125
|
+
Returns:
|
|
126
|
+
TracedFunction: The compiled function representation that can be
|
|
127
|
+
evaluated in multi-party contexts.
|
|
128
|
+
"""
|
|
97
129
|
trace_ctx = TraceContext(mctx.cluster_spec)
|
|
98
130
|
return trace(trace_ctx, fn, *args, **kwargs)
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from mplang.v1.kernels.value import (
|
|
16
|
+
BytesBlob,
|
|
17
|
+
TableValue,
|
|
18
|
+
TensorValue,
|
|
19
|
+
Value,
|
|
20
|
+
ValueDecodeError,
|
|
21
|
+
ValueError,
|
|
22
|
+
decode_value,
|
|
23
|
+
encode_value,
|
|
24
|
+
is_value_envelope,
|
|
25
|
+
list_value_kinds,
|
|
26
|
+
register_value,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
__all__ = [
|
|
30
|
+
"BytesBlob",
|
|
31
|
+
"TableValue",
|
|
32
|
+
"TensorValue",
|
|
33
|
+
"Value",
|
|
34
|
+
"ValueDecodeError",
|
|
35
|
+
"ValueError",
|
|
36
|
+
"decode_value",
|
|
37
|
+
"encode_value",
|
|
38
|
+
"is_value_envelope",
|
|
39
|
+
"list_value_kinds",
|
|
40
|
+
"register_value",
|
|
41
|
+
]
|