mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,944 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""SIMP dialect: SPMD multi-party primitives for EDSL.
|
|
16
|
+
|
|
17
|
+
Provides control flow and communication primitives:
|
|
18
|
+
- pcall_static: Party call with explicit static parties
|
|
19
|
+
- pcall_dynamic: Party call where all parties attempt execution (output always dynamic)
|
|
20
|
+
- shuffle_dynamic, shuffle: Data redistribution
|
|
21
|
+
- converge: Merge disjoint partitions
|
|
22
|
+
- uniform_cond: Uniform conditional (eager mode)
|
|
23
|
+
- while_loop: While loop (eager mode)
|
|
24
|
+
|
|
25
|
+
Primitive definition guideline:
|
|
26
|
+
- Simple ops (add, mul) → use def_abstract_eval
|
|
27
|
+
- Complex ops (control flow, fork tracer) → use def_trace
|
|
28
|
+
|
|
29
|
+
See individual primitive docstrings for detailed documentation.
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
from __future__ import annotations
|
|
33
|
+
|
|
34
|
+
from collections.abc import Callable, Sequence
|
|
35
|
+
from typing import Any, cast
|
|
36
|
+
|
|
37
|
+
from jax.tree_util import tree_flatten, tree_unflatten
|
|
38
|
+
|
|
39
|
+
import mplang.v2.edsl as el
|
|
40
|
+
import mplang.v2.edsl.typing as elt
|
|
41
|
+
|
|
42
|
+
# ---------------------------------------------------------------------------
|
|
43
|
+
# Global configuration
|
|
44
|
+
# ---------------------------------------------------------------------------
|
|
45
|
+
|
|
46
|
+
# Whether ``uniform_cond`` asks the runtime to verify that its predicate is
# uniform across parties. The current value is copied into the
# ``verify_uniform`` attribute of each ``simp.uniform_cond`` op when that op
# is traced, so changing it only affects conditionals traced afterwards.
# Set to False to skip the runtime check (e.g. in tests, or when uniformity
# is already guaranteed).
VERIFY_UNIFORM_DEFAULT: bool = True
|
|
50
|
+
|
|
51
|
+
# ---------------------------------------------------------------------------
|
|
52
|
+
# Helper utilities
|
|
53
|
+
# ---------------------------------------------------------------------------
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _validate_scalar_predicate(value: el.graph.Value, context: str) -> None:
|
|
57
|
+
"""Validate that a graph value represents a scalar predicate."""
|
|
58
|
+
shape = getattr(value.type, "shape", None)
|
|
59
|
+
if shape is not None and shape != ():
|
|
60
|
+
raise TypeError(
|
|
61
|
+
f"{context} must be scalar, got shape {shape} with type {value.type}"
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _merge_captures(*capture_lists: list[el.Object]) -> list[el.Object]:
|
|
66
|
+
"""Merge capture lists while preserving first-seen order and deduplicating by id."""
|
|
67
|
+
seen: dict[int, el.Object] = {}
|
|
68
|
+
for obj in (o for lst in capture_lists for o in lst):
|
|
69
|
+
seen.setdefault(id(obj), obj)
|
|
70
|
+
return list(seen.values())
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _deduce_parties(types: Sequence[elt.BaseType]) -> tuple[int, ...] | None:
    """Return the party ranks shared by every MP-typed entry of ``types``.

    Entries that are not ``elt.MPType`` are ignored. The result is ``None``
    when no entry is MP-typed, or when any MP-typed entry has dynamic
    parties (``parties is None``). Otherwise it is the sorted intersection of
    all party sets, which may be the empty tuple.

    Args:
        types: Types to inspect.

    Returns:
        Sorted tuple of common party ranks, or None if it cannot be deduced.
    """
    mp_parties = [tp.parties for tp in types if isinstance(tp, elt.MPType)]
    if not mp_parties:
        return None
    # A single dynamic party set makes the intersection unknowable.
    if any(p is None for p in mp_parties):
        return None

    head, *rest = mp_parties
    common = set(head).intersection(*rest)
    return tuple(sorted(common))
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class _LocalMPTracer(el.Tracer):
    """Tracer for single-party regions executed under an MP context.

    It is used to trace the body of ``pcall`` regions. Inside such a region
    the code works on per-party values, so MP-typed inputs are seen through
    their ``value_type``.
    """

    def _lift_type(self, obj: el.Object) -> elt.BaseType:
        """Return the region-local type for ``obj``.

        MP-typed objects are unwrapped to their ``value_type``. Any other
        type (e.g. ``TensorType``) is returned unchanged and treated as
        public/replicated.

        Args:
            obj: Object being lifted into the region.

        Returns:
            The unwrapped value type, or ``obj.type`` itself.
        """
        lifted = obj.type
        if isinstance(lifted, elt.MPType):
            lifted = lifted.value_type
        return cast(elt.BaseType, lifted)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
# ---------------------------------------------------------------------------
|
|
123
|
+
# Control flow (scaffold)
|
|
124
|
+
# ---------------------------------------------------------------------------
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
# Primitive handle for the uniform conditional; its trace rule follows.
uniform_cond_p: el.Primitive[Any] = el.Primitive("simp.uniform_cond")


@uniform_cond_p.def_trace
def _uniform_cond_trace(
    pred: el.Object,
    then_fn: Callable[..., Any],
    else_fn: Callable[..., Any],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Implementation for uniform_cond in trace mode.

    Uses def_trace (not def_abstract_eval) because uniform_cond is a complex
    control flow primitive that requires forking tracers for both branches.

    Both branches are traced with the same ``args``/``kwargs`` and their
    output signatures must match. The predicate, the branch arguments and
    the union of both branches' captures are lifted into the current graph.
    They are then wired into a single ``simp.uniform_cond`` op whose two
    regions are the traced branches.

    Args:
        pred: Boolean scalar TraceObject (must be uniform across all parties)
        then_fn: Callable accepting (*args, **kwargs) to execute when pred is True
        else_fn: Callable accepting (*args, **kwargs) to execute when pred is False
        *args: Positional arguments to pass to branch functions
        **kwargs: Keyword arguments to pass to branch functions

    Returns:
        Result from tracing both branches (TraceObject or tuple of TraceObjects)

    Raises:
        TypeError: If pred is not a TraceObject or not scalar, branches are
            not callable, or branch outputs have mismatched types/counts.
        RuntimeError: If the number of traced parameters differs from the
            number of argument variables of the traced branch.

    Note:
        The op's verify_uniform attribute is taken from the module-level
        VERIFY_UNIFORM_DEFAULT at trace time. To disable the runtime check
        for conditionals traced afterwards, set
        ``mplang.v2.dialects.simp.VERIFY_UNIFORM_DEFAULT = False``.

    Example:
        >>> def then_fn(x, y):
        ...     return x + y
        >>> def else_fn(x, y):
        ...     return x - y
        >>> result = uniform_cond(pred, then_fn, else_fn, x, y)
    """
    cur_ctx = el.get_current_context()
    assert isinstance(cur_ctx, el.Tracer)

    if not isinstance(pred, el.TraceObject):
        raise TypeError(f"predicate must be TraceObject, got {type(pred)}")
    _validate_scalar_predicate(pred._graph_value, "uniform_cond predicate")
    if not callable(then_fn) or not callable(else_fn):
        raise TypeError("In trace mode, both branches must be callable functions")

    then_traced = el.trace(then_fn, *args, **kwargs)
    else_traced = el.trace(else_fn, *args, **kwargs)
    if not then_traced.is_output_signature_match(else_traced):
        then_types = [v.type for v in then_traced.graph.outputs]
        else_types = [v.type for v in else_traced.graph.outputs]
        raise TypeError(
            "uniform_cond branch output signature mismatch: "
            f"then={then_types}, else={else_types}"
        )

    num_arg_vars = len(then_traced.in_var_pos)

    # Lift every traced parameter (and the predicate) into the current graph,
    # as pcall does for its inputs and as is done for captures below. Reading
    # ``_graph_value`` directly used to drop non-TraceObject params, which
    # surfaced only as the count mismatch below. It could also wire in a value
    # belonging to another tracer's graph when the argument came from an
    # enclosing trace.
    outer_arg_values = [
        cur_ctx.lift(param)._graph_value for param in then_traced.params
    ]
    pred_value = cur_ctx.lift(pred)._graph_value

    if len(outer_arg_values) != num_arg_vars:
        raise RuntimeError(
            f"uniform_cond: argument count mismatch. Expected {num_arg_vars} variables, "
            f"got {len(outer_arg_values)} from params."
        )

    # Both regions must share one input layout: [*args, *union_of_captures].
    all_captures = _merge_captures(then_traced.captured, else_traced.captured)

    then_traced.align_region_inputs(num_arg_vars, all_captures)
    else_traced.align_region_inputs(num_arg_vars, all_captures)

    capture_trace_objs = [cur_ctx.lift(obj) for obj in all_captures]
    capture_values = [obj._graph_value for obj in capture_trace_objs]

    output_types = [v.type for v in then_traced.graph.outputs]
    cond_inputs = [pred_value, *outer_arg_values, *capture_values]

    result_values = cur_ctx.graph.add_op(
        opcode="simp.uniform_cond",
        inputs=cond_inputs,
        output_types=output_types,
        attrs={"verify_uniform": VERIFY_UNIFORM_DEFAULT},
        regions=[then_traced.graph, else_traced.graph],
    )

    return cur_ctx.reconstruct_outputs(
        then_traced.out_var_pos,
        then_traced.out_imms,
        then_traced.out_tree,
        result_values,
    )
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def uniform_cond(
    pred: el.Object,
    then_fn: Callable[..., Any],
    else_fn: Callable[..., Any],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Choose between two branches using a predicate shared by all parties.

    Both branches are traced, but at runtime only the branch selected by
    ``pred`` executes.

    Args:
        pred: Boolean scalar TraceObject whose value is uniform across parties.
        then_fn: Branch evaluated when ``pred`` is True.
        else_fn: Branch evaluated when ``pred`` is False.
        *args: Extra positional arguments handed to both branches.
        **kwargs: Extra keyword arguments handed to both branches.

    Returns:
        The PyTree produced by the selected branch.

    Raises:
        TypeError: If the predicate or branches are invalid, or the branch
            outputs do not match.
    """
    bind = uniform_cond_p.bind
    return bind(pred, then_fn, else_fn, *args, **kwargs)
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
# ---------------------------------------------------------------------------
|
|
258
|
+
# While loop (scaffold)
|
|
259
|
+
# ---------------------------------------------------------------------------
|
|
260
|
+
|
|
261
|
+
# Primitive handle for the SIMP while loop; its trace rule follows.
while_loop_p: el.Primitive[Any] = el.Primitive("simp.while_loop")


@while_loop_p.def_trace
def _while_loop_trace(
    cond_fn: Callable[[Any], Any],
    body_fn: Callable[[Any], Any],
    init: Any,
) -> Any:
    """Trace-mode implementation for SIMP while_loop.

    ``cond_fn`` and ``body_fn`` are each traced once against ``init``.

    - The condition must produce exactly one output, which is validated as
      scalar when it is a graph variable.
    - The body must return one graph variable per state leaf, each with
      that leaf's exact type.

    Both traced graphs become the regions of a single ``simp.while_loop``
    op. Its inputs are the state values followed by the merged captures of
    both regions.

    Args:
        cond_fn: Loop predicate over the state PyTree.
        body_fn: State transition over the state PyTree.
        init: Initial state PyTree; every leaf must be a TraceObject.

    Returns:
        A PyTree with ``init``'s structure holding the loop's result
        TraceObjects.

    Raises:
        TypeError: On an empty or non-TraceObject state, or when the
            cond/body outputs violate the count, scalar, or type rules.
    """
    ctx = el.get_current_context()
    assert isinstance(ctx, el.Tracer)
    assert callable(cond_fn) and callable(body_fn)

    leaves, treedef = tree_flatten(init)
    assert treedef is not None
    if not leaves:
        raise TypeError("while_loop init must contain at least one Object")

    # The loop state is carried as graph values, so every leaf must be traced.
    for leaf in leaves:
        if not isinstance(leaf, el.TraceObject):
            raise TypeError(
                f"while_loop init leaves must be TraceObject, got {type(leaf)}"
            )

    cond_traced = el.trace(cond_fn, init)
    body_traced = el.trace(body_fn, init)

    # The traced params are the Object leaves of ``init`` (all TraceObjects per
    # the check above), so they double as the loop-carried state.
    state_objs = cast(list[el.TraceObject], cond_traced.params)
    state_vals = [obj._graph_value for obj in state_objs]
    state_types = [obj.type for obj in state_objs]
    n_state = len(state_objs)

    # Condition region: exactly one output, scalar if it is a graph variable.
    n_cond_out = len(cond_traced.out_var_pos) + len(cond_traced.out_imms)
    if n_cond_out != 1:
        raise TypeError(
            "while_loop cond_fn must return exactly one output, "
            f"got {n_cond_out}"
        )
    if cond_traced.out_var_pos:
        _validate_scalar_predicate(
            cond_traced.graph.outputs[0], "while_loop cond_fn output"
        )

    # Body region: one Variable per state leaf, with matching types.
    n_body_out = len(body_traced.out_var_pos) + len(body_traced.out_imms)
    if n_body_out != n_state:
        raise TypeError(
            "while_loop body_fn must return same number of values as init state: "
            f"{n_state} expected, got {n_body_out}"
        )
    body_outs = body_traced.graph.outputs
    if len(body_outs) != n_state:
        raise TypeError(
            "while_loop body_fn must return all Variables "
            "(no immediates allowed in loop state), "
            f"expected {n_state} Variables, got {len(body_outs)}"
        )
    for idx, (got, state_obj) in enumerate(zip(body_outs, state_objs, strict=True)):
        want = state_obj.type
        if got.type != want:
            raise TypeError(
                "while_loop body_fn output type mismatch at index "
                f"{idx}: {got.type} vs {want}"
            )

    # Give both regions the same input layout: [*state, *union_of_captures].
    captures = _merge_captures(cond_traced.captured, body_traced.captured)
    cond_traced.align_region_inputs(n_state, captures)
    body_traced.align_region_inputs(n_state, captures)

    lifted_captures = [ctx.lift(obj) for obj in captures]
    capture_vals = [obj._graph_value for obj in lifted_captures]

    result_vals = ctx.graph.add_op(
        opcode="simp.while_loop",
        inputs=[*state_vals, *capture_vals],
        output_types=state_types,
        regions=[cond_traced.graph, body_traced.graph],
    )

    wrapped = [el.TraceObject(val, ctx) for val in result_vals]
    return tree_unflatten(treedef, wrapped)
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
def while_loop(
    cond_fn: Callable[[Any], Any],
    body_fn: Callable[[Any], Any],
    init: Any,
) -> Any:
    """Run a SIMP while loop that stays synchronized across parties.

    Args:
        cond_fn: Maps the current loop state to a boolean scalar.
        body_fn: Maps the current loop state to the next state. The result
            must have the same PyTree structure and per-leaf types as ``init``.
        init: Initial loop state (PyTree of Objects) shared by all parties.

    Returns:
        The state reached once ``cond_fn`` evaluates to False.

    Raises:
        TypeError: If the outputs of ``cond_fn``/``body_fn`` violate the
            required shape or type constraints.
    """
    operands = (cond_fn, body_fn, init)
    return while_loop_p.bind(*operands)
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
# Core primitives with clear semantic names.
#
# Party calls: a traced single-party region executed on a set of parties.
pcall_static_p = el.Primitive[Any]("simp.pcall_static")
pcall_dynamic_p = el.Primitive[Any]("simp.pcall_dynamic")

# Data redistribution and merging of disjoint partitions. Note that the
# static shuffle is registered under the opcode name "simp.shuffle".
shuffle_dynamic_p = el.Primitive[el.Object]("simp.shuffle_dynamic")
shuffle_static_p = el.Primitive[el.Object]("simp.shuffle")
converge_p = el.Primitive[el.Object]("simp.converge")
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
@pcall_static_p.def_trace
def _pcall_static_trace(
    parties: tuple[int, ...],
    local_fn: Callable[..., Any],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Trace a local single-party region with explicit static parties.

    Args:
        parties: Required tuple of participating party ranks. Duplicates are
            removed and the ranks are sorted.
        local_fn: Callable representing the single-party function body.
        *args: Positional arguments forming a PyTree of MPObjects /
            TraceObjects / immediates passed to the region.
        **kwargs: Keyword arguments forwarded to ``local_fn``.

    Returns:
        PyTree of TraceObjects with static parties mask.

    Raises:
        TypeError: If ``local_fn`` is not callable or arguments contain invalid types.
        ValueError: If ``parties`` is None, or when explicitly provided parties
            are not covered by input parties.
    """
    cur_ctx = el.get_current_context()
    assert isinstance(cur_ctx, el.Tracer)
    # Raise the documented TypeError. An assert would raise AssertionError
    # instead and would disappear entirely under ``python -O``.
    if not callable(local_fn):
        raise TypeError(
            f"pcall_static local_fn must be callable, got {type(local_fn)}"
        )

    if parties is None:
        raise ValueError("pcall_static requires explicit parties, got None")

    # Canonical form: deduplicated and sorted.
    requested_parties = tuple(sorted(set(parties)))

    local_tracer = _LocalMPTracer()
    local_traced = local_tracer.run(local_fn, *args, **kwargs)

    # Get all input objects: params (function arguments) + captured (closures)
    # TracedFunction guarantees: graph.inputs = [*params_inputs, *captured_inputs]
    all_input_objs = local_traced.params + local_traced.captured

    # These are the outer objects' own types, so they may be MPType or a
    # non-MP type such as TensorType (which _LocalMPTracer passes through
    # as public). _deduce_parties only intersects the MP-typed ones.
    all_input_types: list[elt.BaseType] = [obj.type for obj in all_input_objs]
    deduced_parties = _deduce_parties(all_input_types)

    # When the inputs determine a party set, every requested party must be
    # among the parties of all MP-typed inputs.
    if deduced_parties is not None:
        if not set(requested_parties).issubset(set(deduced_parties)):
            raise ValueError(
                f"Requested parties {requested_parties} not covered by "
                f"input argument parties {deduced_parties}"
            )

    # Re-capture all input objects in outer context
    recaptured_objs = [cur_ctx.lift(obj) for obj in all_input_objs]
    region_inputs = [obj._graph_value for obj in recaptured_objs]
    result_types: list[elt.BaseType] = [
        elt.MPType(value.type, requested_parties)
        for value in local_traced.graph.outputs
    ]

    result_values = cur_ctx.graph.add_op(
        opcode="simp.pcall_static",
        inputs=region_inputs,
        output_types=result_types,
        attrs={
            "fn_name": local_traced.name,
            "parties": list(requested_parties),
        },
        regions=[local_traced.graph],
    )

    return cur_ctx.reconstruct_outputs(
        local_traced.out_var_pos,
        local_traced.out_imms,
        local_traced.out_tree,
        result_values,
    )
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
@pcall_dynamic_p.def_trace
def _pcall_dynamic_trace(
    local_fn: Callable[..., Any],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Trace a party call with dynamic execution.

    All parties attempt to execute. At runtime, each party executes if all
    of its inputs are present and otherwise outputs None. The output always
    has dynamic parties (None).

    Args:
        local_fn: Callable representing the single-party function body.
        *args: Positional arguments forming a PyTree of MPObjects /
            TraceObjects / immediates passed to the region.
        **kwargs: Keyword arguments forwarded to ``local_fn``.

    Returns:
        PyTree of TraceObjects with dynamic parties (None).

    Raises:
        TypeError: If ``local_fn`` is not callable or arguments contain invalid types.
    """
    cur_ctx = el.get_current_context()
    assert isinstance(cur_ctx, el.Tracer)
    # Raise the documented TypeError. An assert would raise AssertionError
    # instead and would disappear entirely under ``python -O``.
    if not callable(local_fn):
        raise TypeError(
            f"pcall_dynamic local_fn must be callable, got {type(local_fn)}"
        )

    local_tracer = _LocalMPTracer()
    local_traced = local_tracer.run(local_fn, *args, **kwargs)

    # Get all input objects: params (function arguments) + captured (closures)
    # TracedFunction guarantees: graph.inputs = [*params_inputs, *captured_inputs]
    all_input_objs = local_traced.params + local_traced.captured

    recaptured_objs = [cur_ctx.lift(obj) for obj in all_input_objs]
    region_inputs = [obj._graph_value for obj in recaptured_objs]

    # Which parties produce a result is only known at runtime, so every
    # output is typed with dynamic parties (None).
    result_types: list[elt.BaseType] = [
        elt.MPType(value.type, None) for value in local_traced.graph.outputs
    ]

    result_values = cur_ctx.graph.add_op(
        opcode="simp.pcall_dynamic",
        inputs=region_inputs,
        output_types=result_types,
        attrs={
            "fn_name": local_traced.name,
        },
        regions=[local_traced.graph],
    )

    return cur_ctx.reconstruct_outputs(
        local_traced.out_var_pos,
        local_traced.out_imms,
        local_traced.out_tree,
        result_values,
    )
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
def pcall_static(
    parties: tuple[int, ...],
    local_fn: Callable[..., Any],
    *call_args: Any,
    **call_kwargs: Any,
) -> Any:
    """Execute a function on explicitly specified parties (static).

    This primitive requires explicit party specification and always produces
    static party masks in the output. Use this when the execution parties
    are known at compile time.

    Args:
        parties: Required tuple of party ranks (must be explicit, not None).
        local_fn: Callable representing the single-party computation.
        *call_args: Positional arguments forwarded to ``local_fn``.
        **call_kwargs: Keyword arguments forwarded to ``local_fn``.

    Returns:
        Result with static parties mask matching the parties argument.

    Example:
        >>> # Compute on parties 0 and 1 (static)
        >>> result = pcall_static((0, 1), lambda x: x + 1, x)
    """
    return pcall_static_p.bind(parties, local_fn, *call_args, **call_kwargs)
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
def pcall_dynamic(
    local_fn: Callable[..., Any],
    *call_args: Any,
    **call_kwargs: Any,
) -> Any:
    """Execute a function on all parties with runtime-determined execution.

    All parties attempt to execute the function. At runtime, each party executes
    if all inputs are present, otherwise outputs None. Output always has dynamic
    party mask (None).

    Because positional arguments for ``local_fn`` are collected by
    ``*call_args``, ``local_fn`` must be passed positionally whenever
    positional call arguments follow it.

    Args:
        local_fn: Callable representing the single-party computation.
        *call_args: Positional arguments forwarded to ``local_fn``.
        **call_kwargs: Keyword arguments forwarded to ``local_fn``.

    Returns:
        Result with dynamic parties (None). At runtime, parties with all inputs
        execute, others output None.

    Example:
        >>> # All parties attempt execution based on input availability
        >>> result = pcall_dynamic(lambda x: x + 1, x)
    """
    return pcall_dynamic_p.bind(local_fn, *call_args, **call_kwargs)
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
@shuffle_dynamic_p.def_abstract_eval
def _shuffle_dynamic_ae(src_t: elt.BaseType, index_t: elt.BaseType) -> elt.BaseType:
    """Infer the result type of ``shuffle_dynamic``.

    The routing depends on index values that are only known at runtime, so
    the result can never carry a static party mask.

    Args:
        src_t: Type of the value being redistributed; must be an MPType.
        index_t: Type of the per-party source index; must be an MPType whose
            value type is scalar (``shape == ()``) or exposes no ``shape``.

    Returns:
        An MPType with ``src_t``'s value type and ``parties=None``.

    Raises:
        TypeError: If ``src_t`` or ``index_t`` is not MP-typed, or the index
            value type has a non-scalar shape.
    """
    if not isinstance(src_t, elt.MPType):
        raise TypeError(f"shuffle_dynamic requires MP-typed src, got {src_t}")
    if not isinstance(index_t, elt.MPType):
        raise TypeError(f"shuffle_dynamic requires MP-typed index, got {index_t}")

    index_vtype = index_t.value_type
    shape = getattr(index_vtype, "shape", None)
    # Value types without a ``shape`` attribute are accepted unchecked.
    is_scalar = shape is None or shape == ()
    if not is_scalar:
        raise TypeError(
            f"shuffle_dynamic index must be scalar, got shape {shape} "
            f"with type {index_vtype}"
        )

    # Runtime-determined routing => dynamic (None) party mask.
    return elt.MPType(src_t.value_type, None)
|
|
605
|
+
|
|
606
|
+
|
|
607
|
+
@shuffle_static_p.def_abstract_eval
def _shuffle_ae(src_t: elt.BaseType, routing: dict[int, int]) -> elt.BaseType:
    """Type inference for static shuffle (compile-time known data routing).

    Args:
        src_t: Source value type (must be MPType).
        routing: Dict mapping target_party -> source_rank.

    Returns:
        An MPType with ``src_t``'s value type and the static mask
        ``tuple(sorted(routing.keys()))``.

    Raises:
        TypeError: If src is not MP-typed or routing is not a dict.
        ValueError: If routing is empty, or, when ``src_t.parties`` is static
            (not None), if some source rank is not in ``src_t.parties``.
    """
    if not isinstance(src_t, elt.MPType):
        raise TypeError(f"shuffle_static requires MP-typed src, got {src_t}")

    if not isinstance(routing, dict):
        raise TypeError(f"shuffle_static requires routing dict, got {type(routing)}")

    if not routing:
        raise ValueError("shuffle_static requires non-empty routing dict")

    # Receivers are the routing keys; sorting makes the mask independent of
    # dict insertion order.
    target_parties = tuple(sorted(routing))

    # Source ranks can only be checked against a static source mask; for a
    # dynamic source (parties=None) the check is skipped here.
    src_parties = src_t.parties
    if src_parties is not None:
        for target, source in routing.items():
            if source not in src_parties:
                raise ValueError(
                    f"shuffle_static: routing[{target}]={source} not in "
                    f"src.parties {src_parties}"
                )

    # Output: static mask with target parties.
    return elt.MPType(src_t.value_type, target_parties)
|
|
645
|
+
|
|
646
|
+
|
|
647
|
+
@converge_p.def_abstract_eval
def _converge_ae(in_types: list[elt.BaseType], *, mask: int = -1) -> elt.BaseType:
    """Type inference for converge operation (merge disjoint partitions).

    Args:
        in_types: List of input types (all must be MPType with the same
            value_type).
        mask: Accepted as a keyword attribute but not used by this rule.

    Returns:
        An MPType with the shared value type and these parties:

        * ``None`` (dynamic) if any input has dynamic parties.
        * Otherwise, the sorted union of all input parties.
        * ``None`` if that union is empty.

    Raises:
        TypeError: If ``in_types`` is empty, an input is not MP-typed, or the
            inputs' value types differ.
        ValueError: If the static parties of two inputs overlap.
    """
    if not in_types:
        raise TypeError("converge requires at least one input")

    # Validate and collect in one pass; this also narrows the element type.
    mp_types: list[elt.MPType] = []
    for i, t in enumerate(in_types):
        if not isinstance(t, elt.MPType):
            raise TypeError(f"converge input {i} must be MP-typed, got {t}")
        mp_types.append(t)

    # All inputs must carry the same value type.
    first_vtype = mp_types[0].value_type
    for i, mt in enumerate(mp_types[1:], 1):
        if mt.value_type != first_vtype:
            raise TypeError(
                f"converge value type mismatch at input {i}: "
                f"{mt.value_type} vs {first_vtype}"
            )

    parties_list = [mt.parties for mt in mp_types]

    # Any dynamic input makes the merged result dynamic as well.
    if any(p is None for p in parties_list):
        return elt.MPType(first_vtype, None)

    # Every mask is static here. The filter only narrows the Optional type;
    # nothing is dropped, so indices stay aligned with ``in_types``.
    party_sets = [set(p) for p in parties_list if p is not None]

    # Pairwise disjointness check, reporting the first overlapping (i, j).
    for i, s1 in enumerate(party_sets):
        for j, s2 in enumerate(party_sets[i + 1 :], i + 1):
            overlap = s1 & s2
            if overlap:
                raise ValueError(
                    f"converge requires disjoint parties, inputs {i} and {j} "
                    f"overlap: {overlap}"
                )

    all_parties: set[int] = set()
    for s in party_sets:
        all_parties |= s

    # An empty union (all inputs statically on no party) is reported as
    # dynamic (None) rather than as an empty static mask.
    # NOTE(review): confirm this is intended.
    output_parties = tuple(sorted(all_parties)) if all_parties else None

    return elt.MPType(first_vtype, output_parties)
|
|
706
|
+
|
|
707
|
+
|
|
708
|
+
def shuffle_dynamic(src: el.Object, index: el.Object) -> el.Object:
    """Redistribute ``src`` according to per-party index values known only at runtime.

    Every party reads its own local ``index`` and fetches the data held by
    the source party that index names. The communication pattern is decided
    at runtime, so the result carries a dynamic mask (``parties=None``).
    This is the most flexible of the shuffle primitives.

    Args:
        src: MP-typed value to redistribute.
        index: MP-typed scalar naming, per party, the source party to read.

    Returns:
        The shuffled value, with ``parties=None``.

    Example:
        >>> # P0, P1, P2 each hold different index values at runtime
        >>> result = shuffle_dynamic(src, index)
        >>> # result.type.parties == None (dynamic)
    """
    shuffled = shuffle_dynamic_p.bind(src, index)
    return shuffled
|
|
731
|
+
|
|
732
|
+
|
|
733
|
+
def shuffle_static(src: el.Object, routing: dict[int, int]) -> el.Object:
    """Redistribute ``src`` along a routing table fixed at compile time.

    ``routing`` is receiver-oriented: each entry maps a target party to the
    source rank it receives from. This expresses common patterns directly:

    * Permutation: ``{0: 1, 1: 0}`` swaps parties 0 and 1.
    * Broadcast: ``{0: 1, 2: 1}`` sends one source to several targets.

    Each MP value keeps a single input and a single output. Because the
    pattern is static, the output mask is static too.

    Args:
        src: MP-typed value to redistribute.
        routing: Mapping ``target_party -> source_rank``. For example,
            ``{0: 1, 2: 0}`` means party 0 receives from rank 1 and party 2
            receives from rank 0.

    Returns:
        The shuffled value, whose parties are the sorted keys of ``routing``.

    Example:
        >>> result = shuffle_static(src, routing={0: 1})
        >>> # result.type.parties == (0,)
        >>> result = shuffle_static(src, routing={0: 1, 2: 0})
        >>> # result.type.parties == (0, 2)
    """
    # ``routing`` is passed to bind as a keyword attribute, not positionally.
    shuffled = shuffle_static_p.bind(src, routing=routing)
    return shuffled
|
|
767
|
+
|
|
768
|
+
|
|
769
|
+
def converge(*inputs: el.Object) -> el.Object:
    """Converge multiple disjoint-partitioned variables into one.

    Merges data from multiple parties into one logical variable. In the static
    case, this validates that the input parties are disjoint and produces their
    union. In the dynamic case, it propagates the dynamic mask.

    This is the fundamental operation for combining results from different
    parties.

    Args:
        *inputs: Variable number of MP-typed inputs with disjoint parties.

    Returns:
        Converged variable with the union of input parties (or None if any
        input is dynamic).

    Raises:
        TypeError: Per converge's type-inference rule, if no inputs are given,
            an input is not MP-typed, or the inputs' value types differ.
        ValueError: If static parties are not disjoint.

    Example:
        >>> # P0 has x, P1 has y (disjoint)
        >>> result = converge(x, y)
        >>> # result.type.parties == (0, 1)
    """
    # The var-positional parameter was renamed from ``vars`` to avoid shadowing
    # the builtin. Callers are unaffected: var-positional names cannot be
    # passed by keyword.
    return converge_p.bind(*inputs)
|
|
793
|
+
|
|
794
|
+
|
|
795
|
+
def _is_dataframe(data: Any) -> bool:
    """Return True if ``data`` is a pandas DataFrame (False when pandas is unavailable)."""
    try:
        import pandas as pd
    except ImportError:
        return False
    return isinstance(data, pd.DataFrame)


def constant(parties: tuple[int, ...], data: Any) -> el.Object:
    """Create a constant value distributed to specific parties.

    This helper places ``data`` on each party in ``parties`` by calling
    ``pcall_static`` with the matching dialect's ``constant`` builder.

    Dispatch rules:

    * Scalars (``int``, ``float``, ``bool``, NumPy scalars), NumPy/JAX arrays,
      and ``list``/``tuple`` go to ``tensor.constant``. Lists and tuples are
      ambiguous and default to a tensor.
    * ``dict`` and ``pandas.DataFrame`` (when pandas is installed) go to
      ``table.constant``.

    Args:
        parties: Tuple of party ranks where the constant should be placed.
        data: The constant data (scalar, array, list/tuple, dict, DataFrame).

    Returns:
        MP-typed object on ``parties`` holding either a tensor constant or a
        table constant, depending on the type of ``data``.

    Raises:
        TypeError: If ``data`` is not one of the supported types.
    """
    import jax.numpy as jnp
    import numpy as np

    from mplang.v2.dialects import table, tensor

    # These type groups are disjoint from dict/DataFrame, so checking them
    # first does not change which builder a value is routed to.
    tensor_like = (
        int,
        float,
        bool,
        np.number,
        np.bool_,
        np.ndarray,
        jnp.ndarray,
        list,
        tuple,
    )
    if isinstance(data, tensor_like):
        return cast(el.Object, pcall_static(parties, tensor.constant, data))

    if isinstance(data, dict) or _is_dataframe(data):
        return cast(el.Object, pcall_static(parties, table.constant, data))

    raise TypeError(f"Unsupported data type for simp.constant: {type(data)}")
|
|
840
|
+
|
|
841
|
+
|
|
842
|
+
# Backward compatibility aliases
def peval(
    parties: tuple[int, ...] | None,
    local_fn: Callable[..., Any],
    *call_args: Any,
    **call_kwargs: Any,
) -> Any:
    """Legacy entry point that picks a static or dynamic pcall from ``parties``.

    An explicit ``parties`` tuple dispatches to ``pcall_static``. ``None``
    dispatches to ``pcall_dynamic``, which does not take a parties argument.
    All remaining arguments are forwarded unchanged.
    """
    if parties is not None:
        return pcall_static(parties, local_fn, *call_args, **call_kwargs)
    return pcall_dynamic(local_fn, *call_args, **call_kwargs)
|
|
857
|
+
|
|
858
|
+
|
|
859
|
+
# =============================================================================
|
|
860
|
+
# Factory functions for creating configured Interpreters
|
|
861
|
+
# =============================================================================
|
|
862
|
+
|
|
863
|
+
|
|
864
|
+
def make_simulator(
    world_size: int,
    *,
    cluster_spec: Any = None,
    enable_tracing: bool = False,
    enable_profiling: bool = False,
) -> Any:
    """Build an Interpreter for local SIMP simulation.

    This is a thin wrapper around
    ``mplang.v2.backends.simp_driver.mem.make_simulator``, which per its
    contract sets up a local cluster of workers and returns an Interpreter
    with the simp dialect state attached.

    Args:
        world_size: Number of simulated parties.
        cluster_spec: Optional ClusterSpec for metadata; forwarded to the
            backend.
        enable_tracing: Forwarded to the backend to turn on execution tracing.
        enable_profiling: When True, ``mplang.v2.edsl.registry.enable_profiling()``
            is called before the simulator is built. This flag itself is not
            forwarded to the backend.

    Returns:
        The configured Interpreter returned by the backend.

    Example:
        >>> interp = simp.make_simulator(2)
        >>> with interp:
        ...     result = my_func()
    """
    # Profiling is switched on before the backend module is imported,
    # preserving the original ordering of side effects.
    if enable_profiling:
        from mplang.v2.edsl import registry as _profiling_registry

        _profiling_registry.enable_profiling()

    from mplang.v2.backends.simp_driver.mem import make_simulator as _build_simulator

    return _build_simulator(
        world_size,
        cluster_spec=cluster_spec,
        enable_tracing=enable_tracing,
    )
|
|
900
|
+
|
|
901
|
+
|
|
902
|
+
def make_driver(endpoints: list[str], *, cluster_spec: Any = None) -> Any:
    """Build an Interpreter that runs SIMP programs on remote workers.

    This is a thin wrapper around
    ``mplang.v2.backends.simp_driver.http.make_driver``, which per its
    contract creates the remote simp state and returns an Interpreter with
    the simp dialect state attached.

    Args:
        endpoints: HTTP endpoints of the workers, one per party.
        cluster_spec: Optional ClusterSpec for metadata; forwarded to the
            backend.

    Returns:
        The configured Interpreter returned by the backend.

    Example:
        >>> interp = simp.make_driver(["http://worker1:8000", "http://worker2:8000"])
        >>> with interp:
        ...     result = my_func()
    """
    from mplang.v2.backends.simp_driver.http import make_driver as _build_driver

    driver = _build_driver(endpoints, cluster_spec=cluster_spec)
    return driver
|
|
923
|
+
|
|
924
|
+
|
|
925
|
+
# Public API of this module, kept in alphabetical order. Each user-facing
# helper sits next to its ``*_p`` primitive where one exists. ``peval`` is the
# backward-compatibility alias that dispatches to pcall_static/pcall_dynamic.
__all__ = [
    "constant",
    "converge",
    "converge_p",
    "make_driver",
    "make_simulator",
    "pcall_dynamic",
    "pcall_dynamic_p",
    "pcall_static",
    "pcall_static_p",
    "peval",
    "shuffle_dynamic",
    "shuffle_dynamic_p",
    "shuffle_static",
    "shuffle_static_p",
    "uniform_cond",
    "uniform_cond_p",
    "while_loop",
    "while_loop_p",
]
|