mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
mplang/v1/core/expr/evaluator.py
DELETED
|
@@ -1,581 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""
|
|
16
|
-
Expression evaluation engines for MPLang expressions.
|
|
17
|
-
|
|
18
|
-
- IterativeEvaluator: non-recursive dataflow executor.
|
|
19
|
-
- RecursiveEvaluator: visitor-based executor.
|
|
20
|
-
- EvalSemantic: shared helpers for both engines.
|
|
21
|
-
- IEvaluator: minimal evaluation interface.
|
|
22
|
-
- create_evaluator(rank, env, comm, runtime, kind): factory returning an IEvaluator.
|
|
23
|
-
"""
|
|
24
|
-
|
|
25
|
-
from __future__ import annotations
|
|
26
|
-
|
|
27
|
-
from dataclasses import dataclass
|
|
28
|
-
from typing import Any, Protocol
|
|
29
|
-
|
|
30
|
-
from mplang.v1.core.comm import ICommunicator
|
|
31
|
-
from mplang.v1.core.expr.ast import (
|
|
32
|
-
AccessExpr,
|
|
33
|
-
CallExpr,
|
|
34
|
-
CondExpr,
|
|
35
|
-
ConvExpr,
|
|
36
|
-
EvalExpr,
|
|
37
|
-
Expr,
|
|
38
|
-
FuncDefExpr,
|
|
39
|
-
ShflExpr,
|
|
40
|
-
ShflSExpr,
|
|
41
|
-
TupleExpr,
|
|
42
|
-
VariableExpr,
|
|
43
|
-
WhileExpr,
|
|
44
|
-
)
|
|
45
|
-
from mplang.v1.core.expr.visitor import ExprVisitor
|
|
46
|
-
from mplang.v1.core.expr.walk import walk_dataflow
|
|
47
|
-
from mplang.v1.core.mask import Mask
|
|
48
|
-
from mplang.v1.core.pfunc import PFunction
|
|
49
|
-
from mplang.v1.kernels.context import RuntimeContext
|
|
50
|
-
from mplang.v1.kernels.value import Value
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
class IEvaluator(Protocol):
    """Structural interface implemented by every evaluator engine.

    Besides ``evaluate``, implementations expose a ``runtime`` attribute so
    callers (e.g. simulation or resource setup code) can seed backend state
    through ``evaluator.runtime.run_kernel(...)`` before evaluation.
    """

    # Runtime context that kernels are executed against.
    runtime: RuntimeContext

    def evaluate(self, root: Expr, env: dict[str, Any] | None = None) -> list[Any]:
        """Evaluate ``root`` and return its output values as a list.

        ``env`` is an optional environment override for variable lookups.
        """
        ...
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
@dataclass
class EvalSemantic:
    """Node-level evaluation semantics shared by both evaluator engines.

    A small dataclass carrying the per-party execution context. The helper
    methods implement kernel dispatch, converge, static/dynamic shuffles and
    predicate checks on top of that context using only the communicator's
    ``send``/``recv``/``new_id`` primitives.
    """

    rank: int  # this party's rank (used in diagnostics)
    env: dict[str, Any]  # variable name -> bound value
    comm: ICommunicator  # point-to-point communicator for this party
    runtime: RuntimeContext  # context used to run kernels

    # ------------------------------ Shared helpers (semantics) ------------------------------
    def _should_run(self, rmask: Mask | None, args: list[Any]) -> bool:
        """Decide whether this party executes a node.

        With an explicit ``rmask`` the answer is whether this party's rank is a
        member; without one, the node runs only if no argument is ``None``.
        """
        if rmask is None:
            return not any(arg is None for arg in args)
        return self.comm.rank in Mask(rmask)

    def _exec_pfunc(self, pfunc: PFunction, args: list[Any]) -> list[Any]:
        """Execute ``pfunc`` on ``args`` via the runtime's kernel dispatcher."""
        runtime = self.runtime
        return runtime.run_kernel(pfunc, args)

    def _eval_eval_node(self, expr: EvalExpr, arg_vals: list[Any]) -> list[Any]:
        """Evaluate an EvalExpr, returning ``None`` placeholders when skipped."""
        assert isinstance(expr.pfunc, PFunction)
        if self._should_run(expr.rmask, arg_vals):
            return self._exec_pfunc(expr.pfunc, arg_vals)
        placeholder_count = len(expr.mptypes)
        return [None] * placeholder_count

    def _eval_conv_node(self, vars_vals: list[Any]) -> list[Any]:
        """Converge several optional values into at most one present value.

        Returns ``[None]`` when nothing is present and ``[value]`` when exactly
        one value is present.

        Raises:
            ValueError: If more than one value is present.
        """
        assert len(vars_vals) > 0, "pconv called with empty vars list."
        present = [val for val in vars_vals if val is not None]
        if len(present) > 1:
            raise ValueError(f"pconv called with multiple vars={present}.")
        return [present[0] if present else None]

    def _eval_shfl_s_node(self, expr: ShflSExpr, src_value: Any) -> list[Any]:
        """Static shuffle: ``src_ranks[i]`` sends its value to the i-th pmask rank.

        Returns the received value on destination ranks and ``[None]`` elsewhere.
        """
        src_ranks = expr.src_ranks
        dst_ranks = list(Mask(expr.pmask))
        assert len(src_ranks) == len(dst_ranks)
        channel = self.comm.new_id()
        routes = list(zip(src_ranks, dst_ranks, strict=True))
        me = self.comm.rank
        # All sends are issued before any receive.
        for src, dst in routes:
            if src == me:
                self.comm.send(dst, channel, src_value)
        received = [self.comm.recv(src, channel) for src, dst in routes if dst == me]
        if me not in dst_ranks:
            assert len(received) == 0
            return [None]
        assert len(received) == 1
        return received

    def _eval_shfl_node(self, expr: ShflExpr, data: Any, index: Any) -> list[Any]:
        """Dynamic shuffle: each party receives ``data`` from the rank its ``index`` names.

        Parties whose index is ``None`` receive nothing and get ``[None]``.
        ``expr`` is accepted for interface symmetry and is not read.
        """
        comm = self.comm
        me = comm.rank
        world = comm.world_size

        # Phase 1: all-gather every party's index with pairwise send/recv.
        channel = comm.new_id()
        for peer in range(world):
            if peer != me:
                comm.send(peer, channel, index)
        gathered = [index if peer == me else comm.recv(peer, channel) for peer in range(world)]

        # Phase 2: turn gathered indices into sorted (source, destination) transfers.
        wanted: list[int | None] = [self._as_optional_int(v) for v in gathered]
        transfers: list[tuple[int, int]] = sorted(
            (src, dst) for dst, src in enumerate(wanted) if src is not None
        )

        # Phase 3: move the payload along each transfer (all sends first).
        channel = comm.new_id()
        for src, dst in transfers:
            if src == me:
                comm.send(dst, channel, data)
        incoming = None
        for src, dst in transfers:
            if dst == me:
                incoming = comm.recv(src, channel)
        return [incoming]

    @staticmethod
    def _as_optional_int(val: Any) -> int | None:
        """Return ``val`` (after unwrapping a Value) as an ``int``, or ``None``.

        Conversion uses ``int(...)``, so Python ints/floats and numpy scalar
        types are accepted.
        """
        unwrapped = EvalSemantic._unwrap_value(val)
        return None if unwrapped is None else int(unwrapped)

    def _simple_allgather(self, value: Any) -> list[Any]:
        """Gather one value from every rank, using only send/recv.

        Every rank sends its value to every other rank, which costs O(P^2)
        messages in total; intended for small party counts and small control
        metadata such as a single bool.

        Returns:
            A list of length ``world_size`` ordered by rank.
        """
        world = self.comm.world_size
        payload = self._unwrap_value(value)
        if world == 1:
            return [payload]
        channel = self.comm.new_id()
        me = self.comm.rank
        for dst in range(world):
            if dst != me:
                self.comm.send(dst, channel, payload)
        return [payload if src == me else self.comm.recv(src, channel) for src in range(world)]

    def _verify_uniform_predicate(self, pred: Any) -> None:
        """Check that every party holds the same boolean predicate.

        Raises:
            ValueError: If the gather is empty or the values disagree.
        """
        # Runtime uniformity check (O(P^2) send/recv emulation).
        flag = pred.to_bool() if isinstance(pred, Value) else bool(self._unwrap_value(pred))
        gathered = self._simple_allgather(flag)
        if not gathered:
            raise ValueError("uniform_cond: empty gather for predicate")
        head, *rest = gathered
        if any(v != head for v in rest):
            raise ValueError(
                "uniform_cond: predicate is not uniform across parties"
            )

    # ------------------------------ While helpers ------------------------------
    def _check_while_predicate(self, cond_result: list[Any]) -> Any:
        """Validate a while-loop condition result and return it as a bool.

        Raises:
            AssertionError: If the condition produced other than one value.
            RuntimeError: If the single value is ``None``.
        """
        assert len(cond_result) == 1, (
            f"Condition function must return a single value, got {cond_result}"
        )
        cond_val = cond_result[0]
        if cond_val is None:
            raise RuntimeError(
                "while_loop condition produced None on rank "
                f"{self.rank}; ensure the predicate yields a boolean for every party."
            )
        return (
            cond_val.to_bool()
            if isinstance(cond_val, Value)
            else bool(self._unwrap_value(cond_val))
        )

    @staticmethod
    def _unwrap_value(value: Any) -> Any:
        """Convert a Value to its numpy/python payload where possible.

        Non-Value inputs (including ``None``) are returned unchanged. A Value
        with a callable ``to_numpy`` is converted; single-element arrays are
        reduced to a Python scalar via ``item()``.
        """
        if value is None or not isinstance(value, Value):
            return value
        to_numpy = getattr(value, "to_numpy", None)
        if not callable(to_numpy):
            return value
        arr = to_numpy()
        import numpy as np

        if isinstance(arr, np.ndarray) and arr.size == 1:
            return arr.item()
        return arr
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
class RecursiveEvaluator(EvalSemantic, ExprVisitor):
    """Recursive visitor-based evaluator.

    Each expression node is evaluated at most once per evaluator instance;
    results are memoised by node identity. Nested regions (function bodies,
    conditional branches, loop condition/body) are evaluated by forked child
    evaluators that share the communicator and runtime but extend the
    variable environment.
    """

    def __init__(
        self,
        rank: int,
        env: dict[str, Any],
        comm: ICommunicator,
        runtime: RuntimeContext,
    ) -> None:
        super().__init__(rank, env, comm, runtime)
        self._cache: dict[int, Any] = {}  # Cache based on expr id
        # Strong references to every cached expression. id() values are only
        # unique among *live* objects, so without pinning, a freed expression's
        # id could be recycled by a new node and produce a stale cache hit.
        self._pinned: dict[int, Expr] = {}

    def _get_var(self, name: str) -> Any:
        """Look up ``name`` in the environment.

        Raises:
            ValueError: If ``name`` is not bound.
        """
        if name not in self.env:
            raise ValueError(f"Variable '{name}' not found in evaluator environment")
        return self.env[name]

    def _value(self, expr: Expr) -> Any:
        """Evaluate a single-output expression and return its only value.

        Raises:
            ValueError: If ``expr`` does not declare exactly one output.
        """
        values = self._values(expr)
        if len(expr.mptypes) != 1:
            raise ValueError(
                f"Expected single value for expression {expr}, got {len(values)} values"
                f" (expression declares {len(expr.mptypes)} outputs)"
            )
        return values[0]

    def _values(self, expr: Expr) -> list[Any]:
        """Evaluate ``expr`` (memoised by identity) and return its value list.

        Raises:
            ValueError: If the visitor did not produce a list.
        """
        expr_id = id(expr)
        if expr_id not in self._cache:
            self._cache[expr_id] = expr.accept(self)
            # Keep expr alive for as long as its id() is used as a cache key.
            self._pinned[expr_id] = expr
        values = self._cache[expr_id]
        if not isinstance(values, list):
            raise ValueError(f"got {type(values)} for expression {expr}")
        return values

    # Internal helper to create a new evaluator with extended env for nested regions
    def _fork(self, sub_bindings: dict[str, Any]) -> RecursiveEvaluator:
        merged_env = {**self.env, **sub_bindings}
        # Create a child evaluator sharing the same runtime (no new backend state).
        return RecursiveEvaluator(self.rank, merged_env, self.comm, self.runtime)

    def visit_eval(self, expr: EvalExpr) -> Any:
        """Evaluate function call expression."""
        args = [self._value(arg) for arg in expr.args]
        return self._eval_eval_node(expr, args)

    def visit_variable(self, expr: VariableExpr) -> Any:
        """Evaluate variable expression - just look up in environment.

        No distinction between captured variables and parameters at this level.
        All variables are just names to be resolved in the current environment.
        """
        value = self._get_var(expr.name)
        # Ensure consistency: all visit methods should return a list
        return [value]

    def visit_tuple(self, expr: TupleExpr) -> Any:
        """Evaluate tuple expression."""
        return [self._value(arg) for arg in expr.args]

    def visit_cond(self, expr: CondExpr) -> Any:
        """Evaluate conditional expression (uniform/global semantics).

        Current behavior:
        * Assumes predicate is already uniform (same value on every enabled party),
          optionally verified at runtime when ``expr.verify_uniform`` is set.
        * Only the selected branch is executed locally.
        * If the predicate is absent on this party, returns [None] placeholders.

        Future optimization notes:
        * Current uniform verification uses an O(P^2) manual all-gather. Replace
          with a communicator-level boolean all-reduce (AND + broadcast) when available.
        * Add optional static uniform inference (data provenance) to elide the
          runtime check when predicate uniformity is provable at trace time.
        """
        pred_val = self._value(expr.pred)
        if pred_val is None:
            return [None] * len(expr.mptypes)

        if expr.verify_uniform:
            self._verify_uniform_predicate(pred_val)

        # Convert to bool using Value.to_bool() if available
        if isinstance(pred_val, Value):
            pred = pred_val.to_bool()
        else:
            pred = bool(self._unwrap_value(pred_val))

        # Only evaluate the selected branch locally. The branch CallExpr is a
        # short-lived temporary, so dispatch it directly instead of through the
        # id()-keyed cache (caching it would be useless and pins garbage).
        if pred:
            branch = CallExpr("then", expr.then_fn, expr.args)
        else:
            branch = CallExpr("else", expr.else_fn, expr.args)
        result = branch.accept(self)
        if not isinstance(result, list):
            raise ValueError(f"got {type(result)} for expression {branch}")
        return result

    def visit_call(self, expr: CallExpr) -> Any:
        """Bind call arguments to the callee's params and evaluate its body."""
        args = [self._value(arg) for arg in expr.args]
        assert isinstance(expr.fn, FuncDefExpr)
        sub_env = dict(zip(expr.fn.params, args, strict=True))
        sub_evaluator = self._fork(sub_env)
        return expr.fn.body.accept(sub_evaluator)

    def visit_while(self, expr: WhileExpr) -> Any:
        """Evaluate while loop expression.

        The loop state starts as the evaluated args; each iteration the body's
        outputs replace the leading state slots, and the first
        ``len(body_fn.mptypes)`` state values are returned.
        """
        # Start with initial state
        state = [self._value(arg) for arg in expr.args]

        while True:
            # Call condition function
            cond_env = dict(zip(expr.cond_fn.params, state, strict=True))
            cond_evaluator = self._fork(cond_env)
            cond_result = expr.cond_fn.body.accept(cond_evaluator)
            cond_value = self._check_while_predicate(cond_result)
            if not cond_value:
                break

            # Call body function with same arguments
            body_env = dict(zip(expr.body_fn.params, state, strict=True))
            body_evaluator = self._fork(body_env)
            new_state = expr.body_fn.body.accept(body_evaluator)

            assert len(new_state) == len(expr.body_fn.mptypes)
            assert len(new_state) <= len(state)

            state = new_state + state[len(new_state) :]

        # Return in the same format as original arguments
        return state[0 : len(expr.body_fn.mptypes)]

    def visit_conv(self, expr: ConvExpr) -> Any:
        """Evaluate converge expression."""
        vals = [self._value(arg) for arg in expr.vars]
        return self._eval_conv_node(vals)

    def visit_shfl_s(self, expr: ShflSExpr) -> Any:
        """Evaluate static shuffle expression."""
        value = self._value(expr.src_val)
        return self._eval_shfl_s_node(expr, value)

    def visit_shfl(self, expr: ShflExpr) -> Any:
        """Evaluate dynamic shuffle expression."""
        data = self._value(expr.src)
        index = self._value(expr.index)
        return self._eval_shfl_node(expr, data, index)

    def visit_access(self, expr: AccessExpr) -> Any:
        """Evaluate access expression.

        Raises:
            IndexError: If ``expr.index`` is negative or past the end.
        """
        # Evaluate the expression and access the specified index
        result = self._values(expr.src)

        if expr.index < 0 or expr.index >= len(result):
            raise IndexError(
                f"Index {expr.index} out of range for list of length {len(result)}"
            )
        return [result[expr.index]]  # Ensure we return a list

    def visit_func_def(self, expr: FuncDefExpr) -> Any:
        raise RuntimeError("FuncDefExpr should not be directly evaluated")

    # IEvaluator API: return list of values
    def evaluate(self, root: Expr, env: dict[str, Any] | None = None) -> list[Any]:
        """Evaluate ``root``; a given ``env`` is used by a fresh sibling evaluator.

        Raises:
            ValueError: If the root does not evaluate to a list.
        """
        if env is None:
            res = root.accept(self)
        else:
            # Spawn a sibling evaluator with override env but same runtime.
            res = root.accept(
                RecursiveEvaluator(self.rank, env, self.comm, self.runtime)
            )
        if not isinstance(res, list):
            raise ValueError(f"got {type(res)} for expression {root}")
        return res
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
class IterativeEvaluator(EvalSemantic):
    """Dataflow evaluator that traverses each region iteratively.

    Nodes within one region are visited via ``walk_dataflow`` with the
    iterative DFS post-order traversal, so long dataflow chains do not grow
    the Python stack. Nested regions (Call/Cond/While bodies) are evaluated by
    a nested call to ``_iter_eval_graph``, so Python recursion depth grows
    with region nesting only.
    """

    def __init__(
        self,
        rank: int,
        env: dict[str, Any],
        comm: ICommunicator,
        runtime: RuntimeContext,
    ) -> None:
        super().__init__(rank, env, comm, runtime)

    @staticmethod
    def _first(vals: list[Any]) -> Any:
        """Return the first element of a result list (or a non-list as-is)."""
        if not isinstance(vals, list):
            return vals
        if len(vals) == 0:
            return None
        return vals[0]

    def _merge_state(self, old: list[Any], new: list[Any]) -> list[Any]:
        """Overwrite the leading slots of loop state ``old`` with ``new``."""
        assert len(new) <= len(old)
        return new + old[len(new) :]

    def _eval_region(
        self, fn: FuncDefExpr, arg_vals: list[Any], env: dict[str, Any]
    ) -> list[Any]:
        """Evaluate ``fn.body`` with ``fn.params`` bound to ``arg_vals`` over ``env``."""
        sub_env = dict(zip(fn.params, arg_vals, strict=True))
        return self._iter_eval_graph(fn.body, {**env, **sub_env})

    def _iter_eval_graph(self, root: Expr, env: dict[str, Any]) -> list[Any]:
        """Evaluate the region rooted at ``root`` against ``env``.

        Raises:
            ValueError: On an unbound variable or a non-list root result.
            IndexError: If an AccessExpr index is out of range.
            NotImplementedError: For unsupported expression node types.
        """
        symbols: dict[int, list[Any]] = {}  # id(node) -> output values
        for node in walk_dataflow(root, traversal="dfs_post_iter"):
            if isinstance(node, VariableExpr):
                if node.name not in env:
                    raise ValueError(
                        f"Variable '{node.name}' not found in evaluator environment"
                    )
                symbols[id(node)] = [env[node.name]]
            elif isinstance(node, TupleExpr):
                symbols[id(node)] = [self._first(symbols[id(a)]) for a in node.args]
            elif isinstance(node, AccessExpr):
                src_vals = symbols[id(node.src)]
                # Same bounds contract as RecursiveEvaluator.visit_access:
                # negative indices must not silently wrap around.
                if node.index < 0 or node.index >= len(src_vals):
                    raise IndexError(
                        f"Index {node.index} out of range for list of length {len(src_vals)}"
                    )
                symbols[id(node)] = [src_vals[node.index]]
            elif isinstance(node, CallExpr):
                arg_vals = [self._first(symbols[id(a)]) for a in node.args]
                assert isinstance(node.fn, FuncDefExpr)
                symbols[id(node)] = self._eval_region(node.fn, arg_vals, env)
            elif isinstance(node, CondExpr):
                pred_val = self._first(symbols[id(node.pred)])
                arg_vals = [self._first(symbols[id(a)]) for a in node.args]
                if pred_val is None:
                    symbols[id(node)] = [None] * len(node.mptypes)
                else:
                    # Optional uniform verification identical to recursive evaluator (DRY helper).
                    if node.verify_uniform:
                        self._verify_uniform_predicate(pred_val)
                    # Convert to bool using Value.to_bool() if available
                    if isinstance(pred_val, Value):
                        pred = pred_val.to_bool()
                    else:
                        pred = bool(self._unwrap_value(pred_val))
                    # Only the selected branch is evaluated locally.
                    branch = node.then_fn if pred else node.else_fn
                    symbols[id(node)] = self._eval_region(branch, arg_vals, env)
            elif isinstance(node, WhileExpr):
                state = [self._first(symbols[id(a)]) for a in node.args]
                while True:
                    cond_vals = self._eval_region(node.cond_fn, state, env)
                    cond_val = self._check_while_predicate(cond_vals)
                    if not bool(cond_val):
                        break
                    new_state = self._eval_region(node.body_fn, state, env)
                    state = self._merge_state(state, new_state)
                symbols[id(node)] = state[0 : len(node.body_fn.mptypes)]
            elif isinstance(node, EvalExpr):
                arg_vals = [self._first(symbols[id(a)]) for a in node.args]
                symbols[id(node)] = self._eval_eval_node(node, arg_vals)
            elif isinstance(node, ConvExpr):
                vars_vals = [self._first(symbols[id(v)]) for v in node.vars]
                symbols[id(node)] = self._eval_conv_node(vars_vals)
            elif isinstance(node, ShflSExpr):
                value = self._first(symbols[id(node.src_val)])
                symbols[id(node)] = self._eval_shfl_s_node(node, value)
            elif isinstance(node, ShflExpr):
                data = self._first(symbols[id(node.src)])
                index = self._first(symbols[id(node.index)])
                symbols[id(node)] = self._eval_shfl_node(node, data, index)
            elif isinstance(node, FuncDefExpr):
                # Definition nodes are not evaluated; placeholder to satisfy walkers
                symbols[id(node)] = node.body.mptypes
            else:
                raise NotImplementedError(
                    f"Unsupported expr in iterative eval: {type(node)}"
                )
        res = symbols[id(root)]
        if not isinstance(res, list):
            raise ValueError(f"got {type(res)} for expression {root}")
        return res

    def evaluate(self, root: Expr, env: dict[str, Any] | None = None) -> list[Any]:
        """Evaluate an expression graph with iterative dataflow traversal.

        - Traverses each region's dataflow using iterative DFS-postorder.
        - For control flow/functional regions (Call/Cond/While), evaluates the
          region body with a child environment via a nested call; recursion
          depth therefore scales with region nesting, not graph depth.

        Args:
            root: The root expression to evaluate.
            env: Optional environment override for VariableExpr lookups.

        Returns:
            A list of computed output values for the root expression.
        """
        cur_env = self.env if env is None else env
        return self._iter_eval_graph(root, cur_env)
556
|
-
|
|
557
|
-
|
|
558
|
-
def create_evaluator(
    rank: int,
    env: dict[str, Any],
    comm: ICommunicator,
    runtime: RuntimeContext,
    kind: str | None = "iterative",
) -> IEvaluator:
    """Build an evaluator engine of the requested ``kind``.

    Args:
        rank: Party rank.
        env: Initial variable environment.
        comm: Communicator for this party.
        runtime: Runtime context used to execute kernels.
        kind: ``"iterative"`` (the default; ``None`` is accepted as an alias
            for backward compatibility) or ``"recursive"``.

    Returns:
        An IEvaluator instance of the requested kind.

    Raises:
        ValueError: If ``kind`` names no known engine.
    """
    wants_iterative = kind is None or kind == "iterative"
    if not wants_iterative:
        if kind != "recursive":
            raise ValueError(f"Unknown evaluator kind: {kind}")
        return RecursiveEvaluator(rank, env, comm, runtime)
    return IterativeEvaluator(rank, env, comm, runtime)
|