mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev270__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/core/primitive.py
DELETED
|
@@ -1,877 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""
|
|
16
|
-
Primitive operations for the new expr-based implementation.
|
|
17
|
-
|
|
18
|
-
This module defines the fundamental primitive operations that form the building
|
|
19
|
-
blocks of multi-party computations. All primitives are designed to work in
|
|
20
|
-
TraceContext by default, automatically switching contexts as needed.
|
|
21
|
-
"""
|
|
22
|
-
|
|
23
|
-
from __future__ import annotations
|
|
24
|
-
|
|
25
|
-
from collections.abc import Callable
|
|
26
|
-
from functools import partial, wraps
|
|
27
|
-
from typing import Any, ParamSpec, TypeVar, cast
|
|
28
|
-
|
|
29
|
-
from jax.tree_util import tree_map
|
|
30
|
-
|
|
31
|
-
from mplang.v1.core.context_mgr import cur_ctx
|
|
32
|
-
from mplang.v1.core.dtypes import BOOL
|
|
33
|
-
from mplang.v1.core.expr.ast import (
|
|
34
|
-
AccessExpr,
|
|
35
|
-
CallExpr,
|
|
36
|
-
CondExpr,
|
|
37
|
-
ConvExpr,
|
|
38
|
-
EvalExpr,
|
|
39
|
-
ShflExpr,
|
|
40
|
-
ShflSExpr,
|
|
41
|
-
WhileExpr,
|
|
42
|
-
)
|
|
43
|
-
from mplang.v1.core.interp import InterpContext, InterpVar, apply
|
|
44
|
-
from mplang.v1.core.mask import Mask
|
|
45
|
-
from mplang.v1.core.mpobject import MPContext, MPObject
|
|
46
|
-
from mplang.v1.core.mptype import Rank
|
|
47
|
-
from mplang.v1.core.pfunc import PFunction
|
|
48
|
-
from mplang.v1.core.tracer import TraceContext, TraceVar, trace
|
|
49
|
-
from mplang.v1.utils.func_utils import var_demorph, var_morph
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
def _switch_ctx(ctx: MPContext, obj: MPObject | Any) -> MPObject | Any:
    """Move ``obj`` into the context ``ctx`` when it is an MPObject.

    Values that are not MPObjects, and MPObjects already owned by ``ctx``, are
    handed back untouched. A foreign MPObject is captured as a variable when
    ``ctx`` is a TraceContext; importing into an InterpContext is not supported.

    Args:
        ctx: The target context.
        obj: Any value; only MPObjects are converted.

    Returns:
        ``obj`` itself, or its captured counterpart inside ``ctx``.

    Raises:
        ValueError: If the world sizes of the two contexts differ, if an
            InterpVar is imported into an InterpContext, if the object kind is
            not importable, or if ``ctx`` is neither a TraceContext nor an
            InterpContext.
        NotImplementedError: If a TraceVar would have to be evaluated inside an
            InterpContext.
    """
    assert isinstance(ctx, MPContext), f"Expect MPContext, got {ctx}"

    # Plain Python values and objects that already live in `ctx` need no work.
    if not isinstance(obj, MPObject) or obj.ctx is ctx:
        return obj

    if obj.ctx.world_size() != ctx.world_size():
        # TODO(jint): strict check if source and target context are compatible.
        raise ValueError(f"{obj} world_size mismatch, expect {ctx.world_size()}.")

    if isinstance(ctx, TraceContext):
        # Record the foreign object as a captured variable of this trace.
        return ctx.capture(obj)

    if not isinstance(ctx, InterpContext):
        raise ValueError(f"Unsupported context type: {type(ctx)}")

    # Target is an interpreter context from here on.
    if isinstance(obj, InterpVar):
        raise ValueError(f"Cannot import InterpVar {obj} from {obj.ctx} to {ctx}")
    if isinstance(obj, TraceVar):
        assert isinstance(obj.ctx, TraceContext), obj
        # TODO: implement eval method in InterpContext
        raise NotImplementedError("InterpContext.eval not implemented yet")
    raise ValueError(f"Import from {obj.ctx} to {ctx} not supported")
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
# Type variables used by the decorators in this module so that a wrapped
# callable keeps its original parameter list (P) and return type (R) for
# static type checkers. Each must stay a plain single-name assignment, or
# type checkers will not recognise the declaration.
P = ParamSpec("P")
R = TypeVar("R")
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
def trace_before_apply(fn: Callable[P, R], make_call: bool) -> Callable[P, R]:
    """Wrap ``fn`` so that every call is traced before it is applied.

    The wrapper dispatches on the current context:

    * ``TraceContext``: every MPObject argument is first moved into the current
      trace via ``_switch_ctx``, which captures objects owned by other contexts.
      Then, depending on ``make_call``:

      - ``True``: ``fn`` is traced in a forked context and a single ``CallExpr``
        node is emitted into the current graph.
      - ``False``: ``fn`` is simply called, so its primitives are inlined into
        the current graph.

    * ``InterpContext``: ``fn`` is traced in a new ``TraceContext`` whose parent
      is the interpreter context. The traced function is then applied back in
      the interpreter context.

    Args:
        fn: The multi-party function to wrap.
        make_call: Whether to emit an opaque ``CallExpr`` (True) or inline the
            body (False) when called under a ``TraceContext``.

    Returns:
        A wrapper with the same signature as ``fn``.

    Raises:
        ValueError: If the current context is neither a ``TraceContext`` nor an
            ``InterpContext``, or if an argument cannot be moved into the
            current trace (see ``_switch_ctx``).
    """

    @wraps(fn)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
        current_ctx = cur_ctx()
        if isinstance(current_ctx, TraceContext):
            # The expression graph may only reference variables of the current
            # scope. Arguments owned by an outer trace (or another context) are
            # therefore captured before their `.expr` is used, in both the
            # inline and the CallExpr branch.
            args, kwargs = tree_map(
                partial(_switch_ctx, current_ctx), (args, kwargs)
            )
            if not make_call:
                # Inline: fn's primitives are recorded directly in current_ctx.
                return fn(*args, **kwargs)

            # Opaque call: trace fn in a child context, then emit one CallExpr.
            tracer = current_ctx
            tfn = trace(tracer.fork(), fn, *args, **kwargs)
            in_vars, in_imms, in_struct = var_morph(
                (args, kwargs), lambda x: isinstance(x, MPObject)
            )
            assert in_struct == tfn.in_struct and in_imms == tfn.in_imms
            arg_exprs = [arg.expr for arg in in_vars]
            # Variables the traced body captured must also be supplied from the
            # current scope; re-capture them here if needed.
            cap_exprs = [tracer.capture(var).expr for var in tfn.capture_map]
            caller_expr = CallExpr(
                name=fn.__name__, fn=tfn.make_expr(), args=arg_exprs + cap_exprs
            )
            out_vars = [
                TraceVar(tracer, AccessExpr(caller_expr, idx))
                for idx in range(caller_expr.num_outputs)
            ]
            return cast(R, var_demorph(out_vars, tfn.out_imms, tfn.out_struct))

        if isinstance(current_ctx, InterpContext):
            trace_ctx = TraceContext(current_ctx.cluster_spec, parent=current_ctx)
            # TODO(jint): should we add trace_and_apply to improve the performance?
            tfn = trace(trace_ctx, fn, *args, **kwargs)
            # Return back to the original context.
            return cast(R, apply(current_ctx, tfn, *args, **kwargs))

        raise ValueError(f"Unsupported context type: {type(current_ctx)}")

    return wrapped
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
def builtin_function(fn: Callable[P, R]) -> Callable[P, R]:
    """Trace ``fn`` as one opaque ``CallExpr`` node instead of inlining it.

    Under a ``TraceContext``, calling the decorated function does not splice its
    primitives into the caller's graph. Instead the body is traced once in a
    forked context, and a single ``CallExpr`` referencing that traced body is
    inserted. The callee therefore appears as one unit (a "primitive call") to
    printers and visualizers, with a clear caller/callee boundary.

    A ``CallExpr`` calls one inline lambda. Recursion is not expressible because
    there is no Y-combinator support.

    Args:
        fn: The function to wrap as an opaque primitive call.

    Returns:
        A wrapper with ``fn``'s signature that emits a ``CallExpr`` when called
        inside a trace.

    Example:
        ```python
        @builtin_function
        def my_op(x: MPObject) -> MPObject:
            # Everything traced here becomes the body of a single CallExpr.
            (y,) = peval(some_pfunc, [x])
            return y
        ```
    """
    opaque = trace_before_apply(fn, make_call=True)
    return opaque
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
def function(fn: Callable[P, R]) -> Callable[P, R]:
    """Trace ``fn`` by inlining its body into the caller's graph.

    Under a ``TraceContext``, the primitives executed by the decorated function
    are recorded directly in the caller's trace; no call node is created. Any
    MPObject argument owned by another context is captured into the current
    trace first. This is the default tracing behaviour and suits most plain
    multi-party Python functions.

    Args:
        fn: The function whose operations should be inlined.

    Returns:
        A wrapper with ``fn``'s signature that inlines its operations into the
        current trace.

    Example:
        ```python
        @function
        def my_func(x: MPObject, y: MPObject) -> MPObject:
            # These primitives land directly in the caller's trace.
            (z,) = peval(some_pfunc, [x, y])
            return z
        ```
    """
    inlined = trace_before_apply(fn, make_call=False)
    return inlined
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
# ============================================================================
|
|
191
|
-
# Basic Primitive Operations
|
|
192
|
-
# ============================================================================
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
def _tracer() -> TraceContext:
    """Return the active context, requiring it to be a ``TraceContext``.

    Raises:
        ValueError: If the current context is not a ``TraceContext``.
    """
    active = cur_ctx()
    if isinstance(active, TraceContext):
        return active
    raise ValueError(f"Expect tracer, got {active}")
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
def psize() -> int:
    """Return the number of parties in the current context's world.

    Returns:
        int: ``world_size()`` of the active context.
    """
    active = cur_ctx()
    return active.world_size()
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
def pmask() -> Mask:
    """Return the party mask of the current trace.

    Returns:
        Mask: The mask of the active ``TraceContext``, i.e. the parties the
        computation currently being traced runs on.

    Raises:
        ValueError: If called outside a ``TraceContext``.
    """
    tracer = _tracer()
    return tracer.mask
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
@function
def peval(
    pfunc: PFunction,
    args: list[MPObject],
    rmask: Mask | None = None,
) -> list[MPObject]:
    """Evaluate a primitive function on the active parties, SPMD style.

    Every party runs the same ``pfunc`` on its own portion of ``args``. The call
    is recorded in the current trace as a single ``EvalExpr`` node. One traced
    variable is returned per output of that node.

    Args:
        pfunc: The primitive function to evaluate.
        args: Input MPObjects. The ``@function`` wrapper has already moved them
            into the current trace.
        rmask: Optional runtime execution mask that forces which parties
            evaluate ``pfunc``. This differs from ``MPObject.pmask``, which is
            compile-time type information describing where data lives.
            - ``None`` with no ``args``: the current trace mask is used, so
              execution is never implicitly widened.
            - ``None`` with ``args``: the value is passed through unchanged for
              ``EvalExpr`` / the runtime to resolve.

    Returns:
        list[MPObject]: One traced output variable per result of the evaluation.

    Raises:
        ValueError: If the effective ``rmask`` is not a subset of the current
            trace mask. Presumably ``EvalExpr`` also raises ``ValueError`` when
            static argument pmasks are incompatible with ``rmask``. The runtime
            presumably raises ``RuntimeError`` when dynamic pmasks fail it.
            Neither check is visible in this function.

    Note:
        All parties execute the same program logic on their respective data
        partitions.
    """
    tracer = _tracer()

    if rmask is None and not args:
        # Nothing to deduce a mask from: stay on the current trace mask.
        rmask = tracer.mask

    if rmask is not None and not Mask(rmask).is_subset(tracer.mask):
        # Keep error wording for backward-compatibility with existing tests/docs
        raise ValueError(
            f"Specified rmask {rmask} is not a subset of deduced pmask {tracer.mask}"
        )

    eval_node = EvalExpr(
        pfunc, [var.expr for var in cast(list[TraceVar], args)], rmask
    )
    return [
        TraceVar(tracer, AccessExpr(eval_node, out_idx))
        for out_idx in range(eval_node.num_outputs)
    ]
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
@function
def uniform_cond(
    pred: MPObject,
    then_fn: Callable[..., Any],
    else_fn: Callable[..., Any],
    *args: Any,
    verify_uniform: bool = True,
) -> Any:
    """Global (uniform) multi-party conditional.

    Exactly one branch (``then_fn`` or ``else_fn``) is executed *globally* across
    all active parties. Use this primitive when:

    1. ``pred`` is a boolean scalar whose runtime value is identical for every enabled party.
    2. At least one branch contains multi-party primitives (``seal`` / ``reveal`` /
       ``srun_jax`` / ``pshfl`` / mask transformations) whose cost or side-effects you
       want to avoid if the branch is not taken.
    3. You require the semantic guarantee that the *non-selected* branch does **not**
       perform communication, allocate intermediate buffers, or leak timing/side-effects.

    DO NOT use this when:
      * The predicate differs per party (use party-local selection or ``jax.where``).
      * You only need elementwise / per-entry selection (use ``jax.where`` / ``peval(jax.where)``).
      * The predicate is still secret-shared and you cannot reveal it (future: oblivious branch).

    Choosing between primitives (decision guide):

    1. Use ``jax.where`` (elementwise select) WHEN:
       - You already have both candidate tensors computed (cheap or unavoidable), AND
       - You want per-element blending, OR
       - The predicate may differ per party / per element.

       Example::

           y = peval(jax.where, [mask, a, b])  # both a and b computed

    2. Use ``uniform_cond`` (this primitive) WHEN:
       - Exactly one expensive or MPC-effectful branch should run, AND
       - The predicate is (or must be) globally uniform, AND
       - You want to avoid executing the non-selected branch entirely.

       Example::

           def heavy_then(x):
               sealed = smpc.seal(x)
               return smpc.reveal(sealed) + constant(1)

           def light_else(x):
               return x - constant(1)

           pred = reveal(global_flag)  # uniform bool
           y = uniform_cond(pred, heavy_then, light_else, x)

    3. Use ``jax.lax.cond`` (inside peval) WHEN:
       - Both branches are purely local numeric compute (no MPC comms), AND
       - You accept both branches being traced & (possibly) device-compiled, OR
       - You operate fully in JAX world without multi-party side-effects.

       Example::

           # Branches are pure JAX functions
           y = peval(jax.lax.cond, [pred, fn_a, fn_b, x])

    Args:
        pred: Boolean scalar ``MPObject``; must have shape ``()`` and dtype bool. It is
            intended to be *uniform* (same logical value) across parties. If it has a
            static pmask, that pmask must equal the current trace mask. If
            ``verify_uniform`` is True, the runtime will assert uniformity.
        then_fn: Multi-party function executed when ``pred`` is True.
        else_fn: Multi-party function executed when ``pred`` is False.
        *args: MPObject arguments passed to the selected branch.
        verify_uniform: Whether to perform a runtime uniformity assertion. Disable only if
            the caller can guarantee (by construction) uniformity; disabling removes a
            safety check and may lead to undefined behavior if the predicate diverges.

    Returns:
        A PyTree of MPObjects whose structure and per-leaf MPType match the outputs of both
        branches (branches must agree exactly on MPType including pmask).

    Raises:
        TypeError: If ``pred`` is not an MPObject bool scalar, if any element of ``args``
            is not an MPObject, or if the branch outputs disagree (signature, output
            count, or per-output MPType).
        ValueError: Raised while tracing if ``pred`` has a static pmask different from
            the current trace mask. Also raised if ``verify_uniform=True`` and the
            runtime detects a non-uniform predicate.

    Security:
        ``pred`` must be public (revealed). Using a secret, non-revealed boolean would
        create a data-dependent control path (timing / communication pattern leak).
        Reveal it first, or use an oblivious selection (``jax.where``) if you cannot.

    Example (common):
        >>> pred = simp.reveal(secret_flag)  # bool scalar, now public + uniform
        >>> out = uniform_cond(pred, branch_a, branch_b, x, y)
    """
    # Explicit check instead of `assert`: asserts vanish under `python -O`, and a
    # stray non-MPObject would otherwise fail later with an obscure AttributeError.
    for idx, arg in enumerate(args):
        if not isinstance(arg, MPObject):
            raise TypeError(
                f"uniform_cond args must be MPObjects, got {type(arg).__name__} "
                f"at index {idx}"
            )

    cur_tracer = _tracer()

    # Predicate static type check (documented TypeError instead of AttributeError).
    if not isinstance(pred, MPObject):
        raise TypeError(
            f"uniform_cond predicate must be an MPObject, got {type(pred).__name__}"
        )
    pred_ty = pred.mptype
    if len(pred_ty.shape) != 0:
        raise TypeError(
            f"uniform_cond predicate must be scalar, got shape {pred_ty.shape}"
        )
    if pred_ty.dtype != BOOL:
        raise TypeError(f"uniform_cond predicate must be boolean, got {pred_ty.dtype}")

    # Static pmask rule:
    #   If the predicate has a static pmask (not None), it must equal the current
    #   trace context mask. Otherwise some parties would execute a branch without a
    #   defined predicate value (unsafe). To run on a subset, either:
    #     1. trace the entire uniform_cond under a subset TraceContext (ctx.fork(mask=...)), or
    #     2. broadcast / lift the predicate to the full mask (e.g. pshfl_s).
    #   A pmask of None means dynamic: uniformity is deferred to the runtime check
    #   (if verify_uniform=True).
    pred_pmask = pred_ty.pmask
    if pred_pmask is not None and pred_pmask != cur_tracer.mask:
        raise ValueError(
            "uniform_cond predicate static pmask mismatch: predicate pmask="
            f"{pred_pmask} trace mask={cur_tracer.mask}. Trace under a subset "
            "context (ctx.fork(mask=...)) or broadcast predicate (pshfl_s) to all parties."
        )

    # Step 1: Trace both branches in separate child contexts.
    then_tfn = trace(cur_tracer.fork(), then_fn, *args)
    else_tfn = trace(cur_tracer.fork(), else_fn, *args)

    if not then_tfn.is_signature_match(else_tfn, check_captures=False):
        # Branch outputs (structure, MPType, shape) must match exactly; a mismatch
        # is a type error per the uniform_cond contract.
        raise TypeError(
            f"uniform_cond branch output/signature mismatch: {then_tfn} vs {else_tfn}"
        )

    # Enforce identical output MPTypes (including pmask).
    if len(then_tfn.out_vars) != len(else_tfn.out_vars):
        raise TypeError(
            "uniform_cond branches must return same number of outputs: "
            f"{len(then_tfn.out_vars)} vs {len(else_tfn.out_vars)}"
        )
    for i, (tv, ev) in enumerate(
        zip(then_tfn.out_vars, else_tfn.out_vars, strict=True)
    ):
        if tv.mptype != ev.mptype:
            raise TypeError(
                "uniform_cond branch output MPType mismatch at index "
                f"{i}: {tv.mptype} vs {ev.mptype}"
            )

    # Step 2: Handle variable captures from outer scopes.
    #
    # Union of the variables captured by either branch, in first-seen order.
    # Example: then_fn captures (a, b), else_fn captures (a, c) -> [a, b, c].
    all_captures = list(then_tfn.capture_map | else_tfn.capture_map)

    # The branches may capture variables from outer scopes, but an expression may
    # only receive parameters from the current scope:
    #
    #     outer_scope  [var_a, var_b]
    #          |
    #     cur_tracer   [pred, x]        <- we are here
    #          |
    #     then_fn / else_fn             <- may capture var_a, var_b
    #
    # Re-capture every outer variable into cur_tracer so the CondExpr can pass it.
    capture_vars = [
        var if var.ctx is cur_tracer else cur_tracer.capture(var)
        for var in all_captures
    ]
    assert all(isinstance(var, TraceVar) for var in capture_vars), capture_vars
    capture_exprs = [cast(TraceVar, var).expr for var in capture_vars]

    # Step 3: Build the conditional expression.
    pred_expr = cast(TraceVar, pred).expr
    arg_exprs = [arg.expr for arg in cast(list[TraceVar], args)]

    # Input order: [regular_args, captured_vars]
    in_exprs = arg_exprs + capture_exprs

    # Both branch functions take the same parameter list [args_params, capture_params],
    # so a branch that did not capture some variable still receives it positionally.
    then_fn_expr = then_tfn.make_expr(
        then_tfn.in_names() + then_tfn.capture_names(all_captures)
    )
    else_fn_expr = else_tfn.make_expr(
        else_tfn.in_names() + else_tfn.capture_names(all_captures)
    )

    # Step 4: Create the final conditional and its output variables.
    assert then_fn_expr is not None and else_fn_expr is not None
    fn_expr = CondExpr(
        pred_expr,
        then_fn_expr,
        else_fn_expr,
        in_exprs,
        verify_uniform=verify_uniform,
    )

    rets_expr = [AccessExpr(fn_expr, idx) for idx in range(fn_expr.num_outputs)]
    out_vars = [TraceVar(cur_tracer, res) for res in rets_expr]

    return var_demorph(out_vars, then_tfn.out_imms, then_tfn.out_struct)  # type: ignore[no-any-return]
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
@function
def while_loop(
    cond_fn: Callable[[Any], MPObject],
    body_fn: Callable[[Any], Any],
    init: Any,
) -> Any:
    """Multi-party while loop with condition and body functions.

    This function implements iterative computation in multi-party settings using
    a while loop construct. The loop continues executing as long as the condition
    function returns true, with all parties maintaining synchronization throughout
    the iteration process.

    The condition function must return a scalar boolean value, and the body function
    must have the same input and output signature to enable proper iteration. Both
    functions operate on the loop variable, which is updated in each iteration.

    Args:
        cond_fn: A multi-party function that evaluates the loop condition. It
            takes the loop state (same PyTree as ``init``) and must return a
            single scalar boolean MPObject.
        body_fn: A multi-party function that represents the loop body. It takes
            the loop state (same PyTree as ``init``) and must return a PyTree of
            MPObjects with the same number of leaves, each with exactly the same
            MPType (dtype, shape and pmask) as the corresponding input leaf.
        init: The initial loop state: a PyTree whose leaves are all MPObjects
            (at least one). It is passed to both cond_fn and body_fn in the
            first iteration.

    Returns:
        The final loop state after the loop terminates, rebuilt with the PyTree
        structure returned by ``body_fn``. Because every leaf's MPType must be
        identical across iterations, each output leaf has the same MPType
        (including pmask) as the corresponding ``init`` leaf.

    Raises:
        ValueError: If ``init`` contains no MPObject leaves, if ``cond_fn`` does
            not return exactly one MPObject, if the predicate's static pmask
            differs from the current trace context mask, or if ``body_fn``
            returns a different number of MPObject leaves than ``init`` has.
        TypeError: If ``init`` or the ``body_fn`` result contains
            non-MPObject (Python/immediate) leaves, if the predicate is not a
            scalar or not of boolean dtype, or if any ``body_fn`` output leaf
            MPType differs from the corresponding input leaf MPType.

    Examples:
        **Scenario 1 – Local (non-synchronized) predicate**

        Each party decides *independently* when to leave the loop.

        cond_fn: ``lambda x: x < 10``
        body_fn: ``lambda x: x + constant(1)``
        init: party-local values ``[0, 5]``

            Iterations        P0   P1
            --------------------------------
            start              0    5
            after 1st iter     1    6
            after 5th iter     5   10   ← P1 is done
            after 10th iter   10   10   ← P0 is done

        The parties stop at different iterations yet converge to the same final
        value ``[10, 10]``. Such patterns are usually implemented more
        efficiently via ``peval(jax.while_loop, …)``.

        **Scenario 2 – Globally synchronized predicate**

        All parties evaluate *exactly* the same boolean each round (e.g. via a
        secret-shared reduction).

        cond_fn::

            sealed_sum = smpc.reveal(smpc.srun_jax(lambda x: jnp.sum(x), smpc.seal(x)))
            return sealed_sum < constant(10)

        body_fn::

            return x + prank()  # every party adds its own rank

        Iterations (rank 0 & rank 1 example):

            Iteration         P0 (rank 0)   P1 (rank 1)   sealed_sum   predicate
            -------------------------------------------------------------------
            start             0             5             5            True
            after 1st iter    0             6             6            True
            after 2nd iter    0             7             7            True
            after 3rd iter    0             8             8            True
            after 4th iter    0             9             9            True
            after 5th iter    0             10            10           False  ← loop exits *simultaneously*

        Because the predicate is identical for every party at every step, they
        enter and exit the loop together. Supporting such globally
        synchronized control flow is the primary reason this primitive exists
        (plain ``jax.while_loop`` cannot express it).

    Note:
        Control-flow execution domain (who runs cond/body) follows the outer context's
        mask; we do not shrink the tracer at trace time based on state pmasks. Value
        visibility and real participation are enforced per-op by argument pmask
        intersection (and optional rmask). The loop state MPType (including pmask)
        must remain identical across iterations. Both functions can capture variables
        from outer scopes. This implementation is similar to JAX while_loop but
        adapted for multi-party computation.
    """
    cur_tracer = _tracer()

    # Flatten init into loop-carried MPObject leaves; immediate (non-MPObject)
    # leaves cannot be carried through WhileExpr, so they are rejected.
    is_mpobj = lambda x: isinstance(x, MPObject)  # noqa: E731
    init_vars, init_imms, _init_struct = var_morph(init, is_mpobj)

    if len(init_vars) == 0:
        raise ValueError("while_loop requires at least one MPObject in init state")
    if len(init_imms) != 0:
        raise TypeError(
            "while_loop init must be a PyTree of MPObjects (no Python/immediate leaves)"
        )

    # Trace cond and body in separate forks of the current context so that each
    # gets its own capture map of outer-scope variables.
    cond_tracer = cur_tracer.fork()
    cond_tfn = trace(cond_tracer, cond_fn, init)

    body_tracer = cur_tracer.fork()
    body_tfn = trace(body_tracer, body_fn, init)

    # Validate cond returns a single value
    if len(cond_tfn.out_vars) != 1:
        raise ValueError(
            f"Condition function must return a single boolean variable: got {len(cond_tfn.out_vars)} outputs"
        )
    cond_out_var = cond_tfn.out_vars[0]
    if len(cond_out_var.mptype.shape) != 0:
        raise TypeError(
            f"Condition function must return a scalar, but got shape {cond_out_var.mptype.shape}"
        )
    # Enforce boolean dtype for condition
    if cond_out_var.mptype.dtype != BOOL:
        raise TypeError(
            f"Condition function must return a boolean scalar, got dtype {cond_out_var.mptype.dtype}"
        )

    # Static pmask rule:
    # If the predicate's pmask is statically known it must match the trace context
    # mask. Otherwise some parties in this context would lack a boolean to drive
    # control flow (previously could lead to hang via None). To restrict to a subset:
    #   1. Trace the entire while_loop under a subset context (ctx.fork(mask=submask)), or
    #   2. Broadcast predicate to full mask (e.g. pshfl_s) before while_loop.
    # Dynamic predicates (pmask=None) are allowed; runtime guard (evaluator) raises
    # if any participating party observes None.
    pred_pmask = cond_out_var.mptype.pmask
    if pred_pmask is not None and pred_pmask != cur_tracer.mask:
        raise ValueError(
            "while_loop predicate static pmask mismatch: predicate pmask="
            f"{pred_pmask} trace mask={cur_tracer.mask}. Trace under subset context "
            "or broadcast predicate to all parties."
        )

    # Validate body returns same number of leaves and same MPType per leaf
    if len(body_tfn.out_vars) != len(cond_tfn.in_vars):
        raise ValueError(
            "Body function must return the same number of MPObject leaves as the init state"
        )
    for i, (out_v, in_v) in enumerate(
        zip(body_tfn.out_vars, cond_tfn.in_vars, strict=True)
    ):
        if out_v.mptype != in_v.mptype:
            raise TypeError(
                f"Body output leaf {i} type mismatch: {out_v.mptype} vs {in_v.mptype}"
            )
    # Immediate leaves in the body result are not part of the WhileExpr state:
    # they would be re-inserted as constants by var_demorph while never being
    # fed back into the next iteration, silently changing the loop signature.
    # Mirror the init restriction above. (Checked last so earlier error
    # conditions keep raising the same exception types.)
    if len(body_tfn.out_imms) != 0:
        raise TypeError(
            "while_loop body must return a PyTree of MPObjects (no Python/immediate leaves)"
        )

    # Handle variable captures from outer scopes (union of both functions).
    # Variables not already owned by cur_tracer are re-captured into it so the
    # WhileExpr arguments all live in the current scope.
    all_captures = list((cond_tfn.capture_map | body_tfn.capture_map).keys())
    capture_vars = [
        var if var.ctx is cur_tracer else cur_tracer.capture(var)
        for var in all_captures
    ]
    assert all(isinstance(var, TraceVar) for var in capture_vars), capture_vars

    # Build WhileExpr with all state leaves followed by captures
    state_exprs = [cast(TraceVar, v).expr for v in init_vars]
    capture_exprs = [cast(TraceVar, var).expr for var in capture_vars]

    # Parameter list of both region functions = [state params, capture params],
    # matching the argument order passed to WhileExpr below.
    cond_fn_expr = cond_tfn.make_expr(
        cond_tfn.in_names() + cond_tfn.capture_names(all_captures)
    )
    body_fn_expr = body_tfn.make_expr(
        body_tfn.in_names() + body_tfn.capture_names(all_captures)
    )

    assert cond_fn_expr is not None and body_fn_expr is not None
    all_args = [*state_exprs, *capture_exprs]
    out_expr = WhileExpr(cond_fn_expr, body_fn_expr, all_args)

    # Materialize outputs and reconstruct the original PyTree of init (args part)
    rets_expr = [AccessExpr(out_expr, idx) for idx in range(out_expr.num_outputs)]
    out_vars = [TraceVar(cur_tracer, res) for res in rets_expr]

    # Reconstruct the Python return using the body function's output structure
    # This preserves the exact PyTree the body returns (matching JAX semantics).
    return var_demorph(out_vars, body_tfn.out_imms, body_tfn.out_struct)
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
@function
def pshfl(src: MPObject, index: MPObject) -> MPObject:
    """Shuffle ``src`` across parties using a runtime (dynamic) index.

    Every party that holds a value of ``index`` reads it as the rank of the
    party whose ``src`` value it wants. It receives that value as its output.
    This function only builds a ``ShflExpr`` node in the current trace. Any
    validation described below happens when that expression is constructed or
    evaluated, not in this function body.

    Semantics:
        - A party whose ``index`` value is None (runtime pmask bit unset)
          receives None.
        - If a party asks for ``src`` from a rank that does not hold ``src``
          (``src[index[i]]`` is None), the runtime raises an exception.
        - Therefore every non-None ``index[i]`` must name a party that actually
          holds ``src``.

    Args:
        src: The tensor to redistribute. It must be held by every party that
            some ``index`` value refers to.
        index: A scalar tensor. On each party it holds the source rank to pull
            ``src`` from.

    Returns:
        MPObject: The redistributed tensor. Its pmask is inherited from
        ``index.pmask``; parties whose index is None get None.

    Raises:
        ValueError: If ``index`` is not a scalar.
        RuntimeError: If, for some non-None ``index[i]``, ``src[index[i]]`` is
            None (the referenced party does not hold the data).

    Examples:
        **Example 1: Basic dynamic shuffle**

                     P0    P1    P2
                     --    --    --
            Input:   x0    -     x2
            Index:   -     0     -      (P1's index is 0, fetches from P0)
            -----------------------------------------------------------
            Output:  -     x0    -

        **Example 2: Cross shuffle**

                     P0    P1    P2
                     --    --    --
            Input:   x0    x1    x2
            Index:   2     0     1      (P0←P2, P1←P0, P2←P1)
            -----------------------------------------------------------
            Output:  x2    x0    x1

        **Example 3: Error case - invalid source**

                     P0    P1    P2
                     --    --    --
            Input:   x0    -     -      (only P0 has data)
            Index:   -     1     -      (P1 tries to fetch from P1, which has no data)
            -----------------------------------------------------------
            Result:  RuntimeError
    """
    # Build the shuffle node first, then bind it to the active tracer.
    node = ShflExpr(cast(TraceVar, src).expr, cast(TraceVar, index).expr)
    return TraceVar(_tracer(), node)
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
@function
def pshfl_s(src_val: MPObject, pmask: Mask, src_ranks: list[Rank]) -> MPObject:
    """Shuffle ``src_val`` across parties using a static (trace-time) mapping.

    ``pmask`` selects the parties that will hold the output. Walking the set
    bits of ``pmask`` in ascending order, the i-th selected party receives the
    value of ``src_val`` held by rank ``src_ranks[i]``. This function only
    builds a ``ShflSExpr`` node in the current trace. The consistency checks
    listed under Raises are performed by that expression, not in this body.

    Args:
        src_val: The tensor to redistribute.
        pmask: Bit mask of the parties that will hold the output. Only parties
            whose bit is set receive a value.
        src_ranks: Source rank for each selected party, in the order of the
            set bits of ``pmask``.

    Returns:
        MPObject: The redistributed tensor, laid out according to the
        ``pmask``/``src_ranks`` mapping.

    Raises:
        ValueError: If a rank in ``src_ranks`` is not in ``src_val.pmask``, or
            if ``len(src_ranks)`` differs from the number of set bits in
            ``pmask``.

    Examples:
        In the diagrams, ``pmask=[...]`` lists the ranks whose bits are set.
        "Logical Index" shows, for each output party, the rank it copies from.

        **Example 1: Basic shuffle from P1 to P0**

                           P0    P1    P2
                           --    --    --
            Input:         -     x1    -
            Logical Index: 1     -     -     ; pmask=[0], src_ranks=[1]
            -----------------------------------------------------------
            Output:        x1    -     -

        **Example 2: Multiple party shuffle**

                           P0    P1    P2    P3
                           --    --    --    --
            Input:         x0    x1    -     x3
            Logical Index: -     0     -     3     ; pmask=[1,3], src_ranks=[0,3]
            -----------------------------------------------------------
            Output:        -     x0    -     x3

        **Example 3: Cross shuffle**

                           P0    P1    P2
                           --    --    --
            Input:         x0    x1    x2
            Logical Index: 2     0     1     ; pmask=[0,1,2], src_ranks=[2,0,1]
            -----------------------------------------------------------
            Output:        x2    x0    x1
    """
    # Construct the static-shuffle node, then wrap it in the active tracer.
    shuffled = ShflSExpr(cast(TraceVar, src_val).expr, pmask, src_ranks)
    return TraceVar(_tracer(), shuffled)
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
@function
def pconv(vars: list[MPObject]) -> MPObject:
    """Merge same-typed variables held by disjoint party sets into one variable.

    All inputs must share the same dtype and shape. Their pmasks (the parties
    that actually hold them) must not intersect. The result is held by the
    union of the input pmasks; each party contributes the value it holds.
    This function only builds a ``ConvExpr`` node in the current trace. The
    checks listed under Raises are performed by that expression or by the
    runtime, not in this body.

    Args:
        vars: MPObjects with identical dtype and shape and pairwise disjoint
            pmasks.

    Returns:
        MPObject: One variable combining every input. Its pmask is the union
        of the input pmasks.

    Raises:
        ValueError: If ``vars`` is empty, or if input pmasks intersect
            (conflicting ownership of the same data partitions).
        TypeError: If the inputs do not share the same dtype and shape.

    Examples:
        **Example 1 – merge two disjoint variables**

                     P0    P1    P2
                     --    --    --
            x0:      x0    -     -
            x1:      -     x1    -
            ---------------------------------
            Output:  x0    x1    -

        **Example 2 – merge three parties**

                     P0    P1    P2
                     --    --    --
            a:       a0    -     -
            b:       -     b1    -
            c:       -     -     c2
            ---------------------------------
            Output:  a0    b1    c2

        **Example 3 – error (overlapping pmask)**

                     P0    P1
                     --    --
            u:       u0    -
            v:       v0    -      ← overlap on P0
            ---------------------------------
            pconv([u, v])  # raises ValueError

    Note:
        Typically used to unify data held by different parties into a single
        logical object.
    """
    # Build the convergence node from the inputs' expressions, then bind it to
    # the active tracer.
    merged = ConvExpr([cast(TraceVar, item).expr for item in vars])
    return TraceVar(_tracer(), merged)
|