mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
@@ -1,581 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- Expression evaluation engines for MPLang expressions.
17
-
18
- - IterativeEvaluator: non-recursive dataflow executor.
19
- - RecursiveEvaluator: visitor-based executor.
20
- - EvalSemantic: shared helpers for both engines.
21
- - IEvaluator: minimal evaluation interface.
22
- - evaluator(kind, ...): factory returning an IEvaluator.
23
- """
24
-
25
- from __future__ import annotations
26
-
27
- from dataclasses import dataclass
28
- from typing import Any, Protocol
29
-
30
- from mplang.v1.core.comm import ICommunicator
31
- from mplang.v1.core.expr.ast import (
32
- AccessExpr,
33
- CallExpr,
34
- CondExpr,
35
- ConvExpr,
36
- EvalExpr,
37
- Expr,
38
- FuncDefExpr,
39
- ShflExpr,
40
- ShflSExpr,
41
- TupleExpr,
42
- VariableExpr,
43
- WhileExpr,
44
- )
45
- from mplang.v1.core.expr.visitor import ExprVisitor
46
- from mplang.v1.core.expr.walk import walk_dataflow
47
- from mplang.v1.core.mask import Mask
48
- from mplang.v1.core.pfunc import PFunction
49
- from mplang.v1.kernels.context import RuntimeContext
50
- from mplang.v1.kernels.value import Value
51
-
52
-
53
class IEvaluator(Protocol):
    """Structural interface satisfied by every MPLang expression evaluator.

    The ``runtime`` attribute is part of the interface so that callers (for
    example simulation or resource setup code) can seed backend state through
    ``evaluator.runtime.run_kernel(...)`` before evaluating graphs.
    """

    # Runtime context through which this party's kernels are executed.
    runtime: RuntimeContext

    def evaluate(self, root: Expr, env: dict[str, Any] | None = None) -> list[Any]:
        """Evaluate ``root`` and return its output values as a list.

        ``env``, when given, replaces the environment used to resolve
        variable references.
        """
        ...
63
-
64
-
65
@dataclass
class EvalSemantic:
    """Node-level evaluation semantics shared by both evaluator engines.

    Carries the per-party execution context and implements the semantics of
    eval, conv, static/dynamic shuffle and uniform-predicate checks, which the
    recursive and iterative evaluators delegate to.
    """

    rank: int  # this party's rank
    env: dict[str, Any]  # variable name -> bound value
    comm: ICommunicator  # point-to-point send/recv channel to the other parties
    runtime: RuntimeContext  # context used to run kernels

    # ---------------------------- shared node semantics ----------------------------
    def _should_run(self, rmask: Mask | None, args: list[Any]) -> bool:
        """Decide whether this party executes a node.

        With an explicit ``rmask`` the answer depends only on whether this
        party's communicator rank is in the mask; without one, the node runs
        only when every argument is available (non-None).
        """
        if rmask is None:
            return not any(a is None for a in args)
        return self.comm.rank in Mask(rmask)

    def _exec_pfunc(self, pfunc: PFunction, args: list[Any]) -> list[Any]:
        """Run ``pfunc`` on ``args`` via ``self.runtime.run_kernel``."""
        return self.runtime.run_kernel(pfunc, args)

    def _eval_eval_node(self, expr: EvalExpr, arg_vals: list[Any]) -> list[Any]:
        """Execute an EvalExpr, or return None placeholders when this party skips it."""
        assert isinstance(expr.pfunc, PFunction)
        if self._should_run(expr.rmask, arg_vals):
            return self._exec_pfunc(expr.pfunc, arg_vals)
        return [None] * len(expr.mptypes)

    def _eval_conv_node(self, vars_vals: list[Any]) -> list[Any]:
        """Merge the candidates of a pconv; at most one may be non-None.

        Returns ``[value]`` for the single present candidate, ``[None]`` when
        none is present, and raises ValueError when several are present.
        """
        assert len(vars_vals) > 0, "pconv called with empty vars list."
        present = [v for v in vars_vals if v is not None]
        if len(present) > 1:
            raise ValueError(f"pconv called with multiple vars={present}.")
        return [present[0] if present else None]

    def _eval_shfl_s_node(self, expr: ShflSExpr, src_value: Any) -> list[Any]:
        """Static shuffle: ``expr.src_ranks[i]`` sends to the i-th rank of ``pmask``.

        All sends are issued before any receive, on one fresh channel id.
        Destination parties get ``[received_value]``; all others get ``[None]``.
        """
        src_ranks = expr.src_ranks
        dst_ranks = list(Mask(expr.pmask))
        assert len(src_ranks) == len(dst_ranks)
        cid = self.comm.new_id()
        me = self.comm.rank
        for src, dst in zip(src_ranks, dst_ranks, strict=True):
            if src == me:
                self.comm.send(dst, cid, src_value)
        received = [
            self.comm.recv(src, cid)
            for src, dst in zip(src_ranks, dst_ranks, strict=True)
            if dst == me
        ]
        if me not in dst_ranks:
            assert len(received) == 0
            return [None]
        assert len(received) == 1
        return received

    def _eval_shfl_node(self, expr: ShflExpr, data: Any, index: Any) -> list[Any]:
        """Dynamic shuffle: each party's ``index`` names the rank it pulls data from.

        Phase 1 all-gathers every party's index with raw send/recv. Phase 2
        sends ``data`` along each (src, dst) pair in sorted order. A party whose
        index is None, or that nobody sends to, gets ``[None]``.
        """
        world = self.comm.world_size
        me = self.comm.rank
        # Phase 1: fan-out own index, then fan-in peers' indices in rank order.
        gather_cid = self.comm.new_id()
        for peer in range(world):
            if peer != me:
                self.comm.send(peer, gather_cid, index)
        indices = [
            index if peer == me else self.comm.recv(peer, gather_cid)
            for peer in range(world)
        ]
        # Phase 2: route data from each requested source rank to its requester.
        pairs = sorted(
            (src, dst)
            for dst, src in enumerate(self._as_optional_int(v) for v in indices)
            if src is not None
        )
        route_cid = self.comm.new_id()
        for src, dst in pairs:
            if src == me:
                self.comm.send(dst, route_cid, data)
        received = None
        for src, dst in pairs:
            if dst == me:
                received = self.comm.recv(src, route_cid)
        return [received]

    @staticmethod
    def _as_optional_int(val: Any) -> int | None:
        """Coerce ``val`` to ``int``, keeping None as None.

        The value is unwrapped first (see ``_unwrap_value``). Conversion then
        uses ``int()``, which also accepts numpy scalars through ``__int__``.
        """
        unwrapped = EvalSemantic._unwrap_value(val)
        return None if unwrapped is None else int(unwrapped)

    def _simple_allgather(self, value: Any) -> list[Any]:
        """Gather one value from every party using only point-to-point send/recv.

        Each party sends its (unwrapped) value to every peer and then receives
        from each peer in rank order, for O(P^2) messages overall. This suits
        small worlds and tiny control payloads such as a single bool.

        Returns:
            A list of length ``world_size`` whose i-th entry is rank i's value.
        """
        world = self.comm.world_size
        value = self._unwrap_value(value)
        if world == 1:
            # Single party: nothing to exchange.
            return [value]
        cid = self.comm.new_id()
        me = self.comm.rank
        gathered: list[Any] = [None] * world  # type: ignore
        gathered[me] = value
        for peer in range(world):
            if peer != me:
                self.comm.send(peer, cid, value)
        for peer in range(world):
            if peer != me:
                gathered[peer] = self.comm.recv(peer, cid)
        return gathered

    def _verify_uniform_predicate(self, pred: Any) -> None:
        """Raise ValueError unless every party holds the same truth value for ``pred``.

        Uses the O(P^2) ``_simple_allgather`` emulation.
        """
        if isinstance(pred, Value):
            local = pred.to_bool()
        else:
            local = bool(self._unwrap_value(pred))
        everyone = self._simple_allgather(local)
        if not everyone:
            raise ValueError("uniform_cond: empty gather for predicate")
        head = everyone[0]
        if any(v != head for v in everyone[1:]):
            raise ValueError(
                "uniform_cond: predicate is not uniform across parties"
            )

    # -------------------------------- while helpers --------------------------------
    def _check_while_predicate(self, cond_result: list[Any]) -> Any:
        """Validate a while-loop condition result and return it as a bool.

        Raises:
            AssertionError: If the condition function returned != 1 value.
            RuntimeError: If the single predicate value is None on this party.
        """
        assert len(cond_result) == 1, (
            f"Condition function must return a single value, got {cond_result}"
        )
        cond_val = cond_result[0]
        if cond_val is None:
            raise RuntimeError(
                "while_loop condition produced None on rank "
                f"{self.rank}; ensure the predicate yields a boolean for every party."
            )
        if isinstance(cond_val, Value):
            return cond_val.to_bool()
        return bool(self._unwrap_value(cond_val))

    @staticmethod
    def _unwrap_value(value: Any) -> Any:
        """Turn a ``Value`` into its numpy/Python payload where possible.

        Non-Value inputs (including None) are returned unchanged. A Value
        without a callable ``to_numpy`` is returned unchanged too. Otherwise a
        single-element ndarray collapses to a Python scalar via ``.item()``,
        and any other ``to_numpy()`` result is returned as-is.
        """
        if value is None or not isinstance(value, Value):
            return value
        to_numpy = getattr(value, "to_numpy", None)
        if not callable(to_numpy):
            return value
        arr = to_numpy()
        import numpy as np

        if isinstance(arr, np.ndarray) and arr.size == 1:
            return arr.item()
        return arr
249
-
250
-
251
class RecursiveEvaluator(EvalSemantic, ExprVisitor):
    """Recursive visitor-based evaluator.

    Sub-expression results are memoized per evaluator instance, keyed by
    ``id(expr)``. Nested regions (function, branch and loop bodies) are
    evaluated by forked child evaluators that have an extended environment and
    their own empty cache.
    """

    def __init__(
        self,
        rank: int,
        env: dict[str, Any],
        comm: ICommunicator,
        runtime: RuntimeContext,
    ) -> None:
        super().__init__(rank, env, comm, runtime)
        # Memoized visit results keyed by id(expr). An id is only unique while
        # its object is alive, so only nodes of the graph under evaluation may
        # be cached here, never short-lived temporaries.
        self._cache: dict[int, Any] = {}

    def _get_var(self, name: str) -> Any:
        """Look up ``name`` in the environment.

        Raises:
            ValueError: If ``name`` is not bound.
        """
        if name not in self.env:
            raise ValueError(f"Variable '{name}' not found in evaluator environment")
        return self.env[name]

    def _value(self, expr: Expr) -> Any:
        """Evaluate a single-output expression (memoized) and return that value.

        Raises:
            ValueError: If ``expr`` does not declare exactly one output type.
        """
        values = self._values(expr)
        if len(expr.mptypes) != 1:
            raise ValueError(
                f"Expected single value for expression {expr}, got {len(values)} values"
            )
        return values[0]

    def _values(self, expr: Expr) -> list[Any]:
        """Evaluate ``expr`` (memoized by ``id(expr)``) and return its outputs.

        Only pass nodes that outlive this evaluator's use of the cache.

        Raises:
            ValueError: If visiting ``expr`` did not produce a list.
        """
        expr_id = id(expr)
        if expr_id not in self._cache:
            self._cache[expr_id] = expr.accept(self)
        values = self._cache[expr_id]
        if not isinstance(values, list):
            raise ValueError(f"got {type(values)} for expression {expr}")
        return values

    def _fork(self, sub_bindings: dict[str, Any]) -> RecursiveEvaluator:
        """Create a child evaluator for a nested region.

        The child's environment is ``self.env`` extended (and overridden) by
        ``sub_bindings``. It shares rank, communicator and runtime (no new
        backend state) and starts with an empty cache.
        """
        merged_env = {**self.env, **sub_bindings}
        return RecursiveEvaluator(self.rank, merged_env, self.comm, self.runtime)

    def visit_eval(self, expr: EvalExpr) -> Any:
        """Evaluate function call expression."""
        args = [self._value(arg) for arg in expr.args]
        return self._eval_eval_node(expr, args)

    def visit_variable(self, expr: VariableExpr) -> Any:
        """Evaluate a variable by looking it up in the current environment.

        No distinction between captured variables and parameters at this level.
        All variables are just names to be resolved in the current environment.
        """
        value = self._get_var(expr.name)
        # Every visit method returns a list.
        return [value]

    def visit_tuple(self, expr: TupleExpr) -> Any:
        """Evaluate tuple expression."""
        results = [self._value(arg) for arg in expr.args]
        return results

    def visit_cond(self, expr: CondExpr) -> Any:
        """Evaluate conditional expression (uniform/global semantics).

        Current behavior:
          * Assumes predicate is already uniform (same value on every enabled party).
          * Only the selected branch is executed locally.
          * If the predicate is None on this party, returns [None] placeholders.

        Future optimization notes:
          * Current uniform verification uses an O(P^2) manual all-gather. Replace
            with a communicator-level boolean all-reduce (AND + broadcast) when available.
          * Add optional static uniform inference (data provenance) to elide the
            runtime check when predicate uniformity is provable at trace time.
        """
        pred_val = self._value(expr.pred)
        if pred_val is None:
            return [None] * len(expr.mptypes)

        if expr.verify_uniform:
            self._verify_uniform_predicate(pred_val)

        # Convert to bool using Value.to_bool() if available
        if isinstance(pred_val, Value):
            pred = pred_val.to_bool()
        else:
            pred = bool(self._unwrap_value(pred_val))

        if pred:
            branch = CallExpr("then", expr.then_fn, expr.args)
        else:
            branch = CallExpr("else", expr.else_fn, expr.args)
        # The synthesized CallExpr is a temporary. It is evaluated directly rather
        # than through ``_values`` so its id() never enters ``self._cache``: once
        # it is freed, another temporary could reuse that id and would wrongly
        # receive this branch's cached results.
        result = branch.accept(self)
        if not isinstance(result, list):
            raise ValueError(f"got {type(result)} for expression {branch}")
        return result

    def visit_call(self, expr: CallExpr) -> Any:
        """Evaluate a call by running the callee body in a forked evaluator."""
        args = [self._value(arg) for arg in expr.args]
        assert isinstance(expr.fn, FuncDefExpr)
        sub_env = dict(zip(expr.fn.params, args, strict=True))
        sub_evaluator = self._fork(sub_env)
        return expr.fn.body.accept(sub_evaluator)

    def visit_while(self, expr: WhileExpr) -> Any:
        """Evaluate a while loop.

        Each iteration evaluates the condition and then the body in fresh
        forked evaluators. The body's outputs replace the leading entries of
        the loop state.
        """
        # Start with initial state
        state = [self._value(arg) for arg in expr.args]

        while True:
            cond_env = dict(zip(expr.cond_fn.params, state, strict=True))
            cond_evaluator = self._fork(cond_env)
            cond_result = expr.cond_fn.body.accept(cond_evaluator)
            cond_value = self._check_while_predicate(cond_result)
            if not cond_value:
                break

            # The body receives the same state as the condition.
            body_env = dict(zip(expr.body_fn.params, state, strict=True))
            body_evaluator = self._fork(body_env)
            new_state = expr.body_fn.body.accept(body_evaluator)

            assert len(new_state) == len(expr.body_fn.mptypes)
            assert len(new_state) <= len(state)

            state = new_state + state[len(new_state) :]

        # Return only the entries produced by the body.
        return state[0 : len(expr.body_fn.mptypes)]

    def visit_conv(self, expr: ConvExpr) -> Any:
        """Evaluate converge expression."""
        vals = [self._value(arg) for arg in expr.vars]
        return self._eval_conv_node(vals)

    def visit_shfl_s(self, expr: ShflSExpr) -> Any:
        """Evaluate static shuffle expression."""
        value = self._value(expr.src_val)
        return self._eval_shfl_s_node(expr, value)

    def visit_shfl(self, expr: ShflExpr) -> Any:
        """Evaluate dynamic shuffle expression."""
        data = self._value(expr.src)
        index = self._value(expr.index)
        return self._eval_shfl_node(expr, data, index)

    def visit_access(self, expr: AccessExpr) -> Any:
        """Select one output of a multi-output expression.

        Raises:
            IndexError: If ``expr.index`` is negative or past the end.
        """
        result = self._values(expr.src)

        if expr.index < 0 or expr.index >= len(result):
            raise IndexError(
                f"Index {expr.index} out of range for list of length {len(result)}"
            )
        return [result[expr.index]]  # Ensure we return a list

    def visit_func_def(self, expr: FuncDefExpr) -> Any:
        raise RuntimeError("FuncDefExpr should not be directly evaluated")

    # IEvaluator API: return list of values
    def evaluate(self, root: Expr, env: dict[str, Any] | None = None) -> list[Any]:
        """Evaluate ``root`` and return its output values.

        Every call runs on a fresh evaluator that shares this instance's rank,
        communicator and runtime. The fresh evaluator has a new, empty memo
        cache. ``env`` replaces the environment when given; otherwise
        ``self.env`` is used. Reusing ``self._cache`` across calls would be
        unsound: keys are id() values that can be recycled once a previously
        evaluated graph is freed, and results would not be recomputed.

        Raises:
            ValueError: If evaluation does not produce a list.
        """
        scope = self.env if env is None else env
        engine = RecursiveEvaluator(self.rank, scope, self.comm, self.runtime)
        res = root.accept(engine)
        if not isinstance(res, list):
            raise ValueError(f"got {type(res)} for expression {root}")
        return res
426
-
427
-
428
class IterativeEvaluator(EvalSemantic):
    """Dataflow evaluator driven by an iterative post-order walk of the graph.

    Within one graph, nodes are ordered by ``walk_dataflow`` rather than by
    Python recursion. Region bodies (Call/Cond/While) are evaluated by a
    nested ``_iter_eval_graph`` call, so recursion depth grows only with
    region nesting.
    """

    def __init__(
        self,
        rank: int,
        env: dict[str, Any],
        comm: ICommunicator,
        runtime: RuntimeContext,
    ) -> None:
        super().__init__(rank, env, comm, runtime)

    @staticmethod
    def _first(vals: list[Any]) -> Any:
        """Return the first entry of a node result (None if empty; non-lists as-is)."""
        if not isinstance(vals, list):
            return vals
        if len(vals) == 0:
            return None
        return vals[0]

    def _merge_state(self, old: list[Any], new: list[Any]) -> list[Any]:
        """Overlay ``new`` onto the leading entries of ``old`` (loop-state update)."""
        assert len(new) <= len(old)
        return new + old[len(new) :]

    def _eval_region(
        self, fn: FuncDefExpr, arg_vals: list[Any], env: dict[str, Any]
    ) -> list[Any]:
        """Evaluate ``fn.body`` with ``fn.params`` bound to ``arg_vals`` on top of ``env``."""
        sub_env = dict(zip(fn.params, arg_vals, strict=True))
        return self._iter_eval_graph(fn.body, {**env, **sub_env})

    def _iter_eval_graph(self, root: Expr, env: dict[str, Any]) -> list[Any]:
        """Evaluate ``root`` under ``env`` and return its output list.

        Nodes come from ``walk_dataflow(..., traversal="dfs_post_iter")``, so
        operands are visited before their users. Results are stored in
        ``symbols`` keyed by ``id(node)``; this is safe because every node
        belongs to the live graph rooted at ``root``.

        Raises:
            ValueError: On an unbound variable, or if the root result is not a list.
            IndexError: On an out-of-range AccessExpr index.
            NotImplementedError: On an unsupported expression type.
        """
        symbols: dict[int, list[Any]] = {}
        for node in walk_dataflow(root, traversal="dfs_post_iter"):
            if isinstance(node, VariableExpr):
                if node.name not in env:
                    raise ValueError(
                        f"Variable '{node.name}' not found in evaluator environment"
                    )
                symbols[id(node)] = [env[node.name]]
            elif isinstance(node, TupleExpr):
                symbols[id(node)] = [self._first(symbols[id(a)]) for a in node.args]
            elif isinstance(node, AccessExpr):
                src_vals = symbols[id(node.src)]
                # Explicit bounds check, matching RecursiveEvaluator.visit_access.
                # Without it a negative index would silently pick from the end.
                if node.index < 0 or node.index >= len(src_vals):
                    raise IndexError(
                        f"Index {node.index} out of range for list of length "
                        f"{len(src_vals)}"
                    )
                symbols[id(node)] = [src_vals[node.index]]
            elif isinstance(node, CallExpr):
                arg_vals = [self._first(symbols[id(a)]) for a in node.args]
                assert isinstance(node.fn, FuncDefExpr)
                symbols[id(node)] = self._eval_region(node.fn, arg_vals, env)
            elif isinstance(node, CondExpr):
                pred_val = self._first(symbols[id(node.pred)])
                arg_vals = [self._first(symbols[id(a)]) for a in node.args]
                if pred_val is None:
                    symbols[id(node)] = [None] * len(node.mptypes)
                else:
                    # Optional uniform verification, same helper as the recursive engine.
                    if node.verify_uniform:
                        self._verify_uniform_predicate(pred_val)
                    # Convert to bool using Value.to_bool() if available
                    if isinstance(pred_val, Value):
                        pred = pred_val.to_bool()
                    else:
                        pred = bool(self._unwrap_value(pred_val))
                    # Only the selected branch is evaluated locally.
                    branch = node.then_fn if pred else node.else_fn
                    symbols[id(node)] = self._eval_region(branch, arg_vals, env)
            elif isinstance(node, WhileExpr):
                state = [self._first(symbols[id(a)]) for a in node.args]
                while True:
                    cond_vals = self._eval_region(node.cond_fn, state, env)
                    if not bool(self._check_while_predicate(cond_vals)):
                        break
                    new_state = self._eval_region(node.body_fn, state, env)
                    state = self._merge_state(state, new_state)
                symbols[id(node)] = state[0 : len(node.body_fn.mptypes)]
            elif isinstance(node, EvalExpr):
                arg_vals = [self._first(symbols[id(a)]) for a in node.args]
                symbols[id(node)] = self._eval_eval_node(node, arg_vals)
            elif isinstance(node, ConvExpr):
                vars_vals = [self._first(symbols[id(v)]) for v in node.vars]
                symbols[id(node)] = self._eval_conv_node(vars_vals)
            elif isinstance(node, ShflSExpr):
                value = self._first(symbols[id(node.src_val)])
                symbols[id(node)] = self._eval_shfl_s_node(node, value)
            elif isinstance(node, ShflExpr):
                data = self._first(symbols[id(node.src)])
                index = self._first(symbols[id(node.index)])
                symbols[id(node)] = self._eval_shfl_node(node, data, index)
            elif isinstance(node, FuncDefExpr):
                # Definition nodes are not evaluated; placeholder to satisfy walkers
                symbols[id(node)] = node.body.mptypes
            else:
                raise NotImplementedError(
                    f"Unsupported expr in iterative eval: {type(node)}"
                )
        res = symbols[id(root)]
        if not isinstance(res, list):
            raise ValueError(f"got {type(res)} for expression {root}")
        return res

    def evaluate(self, root: Expr, env: dict[str, Any] | None = None) -> list[Any]:
        """Evaluate an expression graph using an iterative dataflow traversal.

        - Nodes of a graph are ordered by an iterative DFS post-order walk.
        - Control-flow/functional regions (Call/Cond/While) have their bodies
          evaluated by a nested ``_iter_eval_graph`` call with a child
          environment. Python recursion therefore grows with region nesting
          depth, not with graph size.

        Args:
            root: The root expression to evaluate.
            env: Optional environment override for VariableExpr lookups.

        Returns:
            A list of computed output values for the root expression.
        """
        cur_env = self.env if env is None else env
        return self._iter_eval_graph(root, cur_env)
556
-
557
-
558
def create_evaluator(
    rank: int,
    env: dict[str, Any],
    comm: ICommunicator,
    runtime: RuntimeContext,
    kind: str | None = "iterative",
) -> IEvaluator:
    """Factory to create an evaluator engine.

    Args:
        rank: Party rank.
        env: Initial variable environment.
        comm: Communicator for this party.
        runtime: Runtime context passed to the engine for kernel execution.
        kind: Evaluator implementation. Use ``"iterative"`` (the default) or
            ``"recursive"``. ``None`` is accepted as an alias for
            ``"iterative"`` for backward compatibility.

    Returns:
        An IEvaluator instance of the requested kind.

    Raises:
        ValueError: If ``kind`` is not one of the supported values.
    """
    # Backward compatibility: kind=None selects the default iterative engine.
    if kind in (None, "iterative"):
        return IterativeEvaluator(rank, env, comm, runtime)
    if kind == "recursive":
        return RecursiveEvaluator(rank, env, comm, runtime)
    raise ValueError(f"Unknown evaluator kind: {kind}")