mplang-nightly 0.1.dev268__py3-none-any.whl → 0.1.dev270__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +46 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +19 -13
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/libs/mpc/psi/rr22.py +303 -0
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang/v2/libs/mpc/psi/rr22.py +0 -344
  162. mplang_nightly-0.1.dev268.dist-info/RECORD +0 -180
  163. /mplang/{v2/backends → backends}/channel.py +0 -0
  164. /mplang/{v2/edsl → edsl}/README.md +0 -0
  165. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  166. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  167. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  168. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  169. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  171. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  172. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  175. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  177. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  178. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  179. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
  180. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
  181. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
@@ -1,387 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- Pure functional traversal utilities for MPLang expression graphs.
17
-
18
- This module provides semantic-agnostic walkers over expression graphs. It exposes
19
- both dataflow-only traversal (deps edges) and structural traversal (deps +
20
- contained regions like function bodies, then/else, loop body). All traversals are
21
- implemented iteratively to avoid Python recursion limits.
22
-
23
- Notes
24
- - These walkers never evaluate expressions nor decide runtime branches. For
25
- execution order, use an evaluator/driver that consults semantic rules.
26
- - Topological order is produced w.r.t. deps edges, not region containment.
27
- """
28
-
29
- from __future__ import annotations
30
-
31
- from collections import deque
32
- from collections.abc import Callable, Iterable, Iterator, Sequence
33
- from typing import cast
34
-
35
- from mplang.v1.core.expr.ast import (
36
- AccessExpr,
37
- CallExpr,
38
- CondExpr,
39
- ConvExpr,
40
- EvalExpr,
41
- Expr,
42
- FuncDefExpr,
43
- ShflExpr,
44
- ShflSExpr,
45
- TupleExpr,
46
- VariableExpr,
47
- WhileExpr,
48
- )
49
-
50
# Every walker in this module operates on core expression nodes.
Node = Expr
# Predicate applied to each visited node; only nodes it accepts are yielded.
YieldCond = Callable[[Node], bool]
# Returns the nodes a given node points to (the edges a walker follows).
GetDeps = Callable[[Node], Iterable[Node]]
54
-
55
- def _identity_key(n: Node) -> int:
56
- """Identity-based key for hashing nodes that may not be hashable."""
57
- return id(n)
58
-
59
-
60
- # ---------------------------- default dependency getters ----------------------------
61
-
62
-
63
- def dataflow_deps(node: Node) -> Iterable[Node]:
64
- """Default dataflow dependencies for core Expr nodes.
65
-
66
- This includes only inputs that must be computed to evaluate the node itself.
67
- Contained regions (function/branch/loop bodies) are NOT traversed here.
68
- """
69
- if isinstance(node, EvalExpr):
70
- return node.args
71
- if isinstance(node, TupleExpr):
72
- return node.args
73
- if isinstance(node, CondExpr):
74
- # Pure dataflow: pred and actual args to branch functions
75
- return [node.pred, *node.args]
76
- if isinstance(node, WhileExpr):
77
- # Initial state args required; bodies are regions, not dataflow deps
78
- return list(node.args)
79
- if isinstance(node, ConvExpr):
80
- return node.vars
81
- if isinstance(node, ShflSExpr):
82
- return [node.src_val]
83
- if isinstance(node, ShflExpr):
84
- return [node.src, node.index]
85
- if isinstance(node, AccessExpr):
86
- return [node.src]
87
- if isinstance(node, VariableExpr):
88
- return []
89
- if isinstance(node, FuncDefExpr):
90
- # Definition is not a value-producing node in dataflow; no deps
91
- return []
92
- if isinstance(node, CallExpr):
93
- # Arguments are dataflow deps; function body is a region
94
- return node.args
95
- # Fallback: try best-effort empty
96
- return []
97
-
98
-
99
- def _structural_region_roots(node: Node) -> Iterable[Node]:
100
- """Roots of contained regions for structural traversal.
101
-
102
- - Cond: then_fn.body, else_fn.body
103
- - While: cond_fn.body, body_fn.body
104
- - Call: fn.body
105
- - FuncDef: body
106
- """
107
- if isinstance(node, CondExpr):
108
- return [node.then_fn.body, node.else_fn.body]
109
- if isinstance(node, WhileExpr):
110
- return [node.cond_fn.body, node.body_fn.body]
111
- if isinstance(node, CallExpr):
112
- return [node.fn.body]
113
- if isinstance(node, FuncDefExpr):
114
- return [node.body]
115
- return []
116
-
117
-
118
- # ---------------------------------- core walkers ----------------------------------
119
-
120
-
121
- def walk(
122
- roots: Node | Sequence[Node],
123
- *,
124
- get_deps: GetDeps,
125
- traversal: str = "dfs_post_iter",
126
- yield_condition: YieldCond | None = None,
127
- detect_cycles: bool = True,
128
- ) -> Iterator[Node]:
129
- """Generic pure structural walker.
130
-
131
- Args:
132
- roots: Single root or a sequence of roots to start from.
133
- get_deps: Function mapping node -> iterable of dependency nodes.
134
- traversal: One of {'dfs_pre_iter','dfs_post_iter','bfs','topo'}.
135
- yield_condition: Optional predicate to filter yielded nodes.
136
- detect_cycles: If True, raises ValueError on cycles (for DFS/topo).
137
-
138
- Yields:
139
- Nodes in the chosen traversal order, filtered by yield_condition.
140
- """
141
- start: list[Node]
142
- if isinstance(roots, (list, tuple)):
143
- start = list(roots)
144
- else:
145
- start = [cast(Node, roots)]
146
-
147
- if traversal == "bfs":
148
- yield from _bfs(start, get_deps, yield_condition)
149
- return
150
- if traversal == "dfs_pre_iter":
151
- yield from _dfs_pre_iter(start, get_deps, yield_condition, detect_cycles)
152
- return
153
- if traversal == "dfs_post_iter":
154
- yield from _dfs_post_iter(start, get_deps, yield_condition, detect_cycles)
155
- return
156
- if traversal == "topo":
157
- yield from _topo_kahn(start, get_deps, yield_condition, detect_cycles)
158
- return
159
-
160
- raise ValueError(f"Invalid traversal type: {traversal}")
161
-
162
-
163
def walk_dataflow(
    roots: Node | Sequence[Node],
    *,
    traversal: str = "dfs_post_iter",
    yield_condition: YieldCond | None = None,
    detect_cycles: bool = True,
) -> Iterator[Node]:
    """Walk ``roots`` following only dataflow edges (``dataflow_deps``).

    Region bodies (function, branch and loop bodies) are not entered. All
    keyword options are forwarded unchanged to ``walk``.
    """
    walker = walk(
        roots,
        traversal=traversal,
        yield_condition=yield_condition,
        detect_cycles=detect_cycles,
        get_deps=dataflow_deps,
    )
    return walker
178
-
179
-
180
def walk_structural(
    roots: Node | Sequence[Node],
    *,
    traversal: str = "dfs_post_iter",
    yield_condition: YieldCond | None = None,
    detect_cycles: bool = True,
) -> Iterator[Node]:
    """Walk dataflow edges plus contained regions (bodies of functions,
    branches and loops).

    Each region is visited once as structure: no branch is chosen and loops
    are not unrolled, because nothing is evaluated.
    """

    def deps_and_region_roots(current: Node) -> Iterable[Node]:
        # Dataflow inputs first, then the roots of any contained regions.
        for source in (dataflow_deps, _structural_region_roots):
            yield from source(current)

    return walk(
        roots,
        get_deps=deps_and_region_roots,
        traversal=traversal,
        yield_condition=yield_condition,
        detect_cycles=detect_cycles,
    )
204
-
205
-
206
- # -------------------------------- traversal engines --------------------------------
207
-
208
-
209
- def _maybe_yield(n: Node, pred: YieldCond | None) -> Iterator[Node]:
210
- if pred is None or pred(n):
211
- yield n
212
-
213
-
214
def _bfs(
    roots: Sequence[Node],
    get_deps: GetDeps,
    yield_condition: YieldCond | None,
) -> Iterator[Node]:
    """Breadth-first walk yielding each reachable node once (by identity).

    Nodes are marked as seen when they are enqueued rather than when they are
    dequeued. Because the queue is FIFO, the yield order is the same as
    dequeue-time deduplication, but each node now enters the queue at most
    once, so the queue no longer grows with one entry per edge.
    ``get_deps`` is called once per yielded node, after the node is yielded.
    """
    seen: set[int] = set()
    q: deque[Node] = deque()
    for root in roots:
        kroot = _identity_key(root)
        if kroot not in seen:
            seen.add(kroot)
            q.append(root)

    while q:
        n = q.popleft()
        yield from _maybe_yield(n, yield_condition)
        for d in get_deps(n):
            kd = _identity_key(d)
            if kd not in seen:
                seen.add(kd)
                q.append(d)
230
-
231
-
232
def _dfs_pre_iter(
    roots: Sequence[Node],
    get_deps: GetDeps,
    yield_condition: YieldCond | None,
    detect_cycles: bool,
) -> Iterator[Node]:
    """Iterative depth-first walk yielding each node before its dependencies.

    Each reachable node is yielded at most once (identity-based). A dependency
    that is still on the current DFS path is a back edge. It raises ValueError
    when ``detect_cycles`` is True and is skipped otherwise.

    ``get_deps`` is called exactly once per visited node, right after the node
    is yielded. The resulting list is kept in the stack frame instead of being
    recomputed every time the node returns to the top of the stack.
    """
    seen: set[int] = set()
    onstack: set[int] = set()
    # Frame: (node key, node's deps materialized once, next child index).
    stack: list[tuple[int, list[Node], int]] = []

    for root in roots:
        kroot = _identity_key(root)
        if kroot in seen:
            continue
        # Pre-order: yield on first encounter, then expand.
        seen.add(kroot)
        yield from _maybe_yield(root, yield_condition)
        stack.append((kroot, list(get_deps(root)), 0))
        onstack.add(kroot)

        while stack:
            k, deps, idx = stack[-1]
            if idx >= len(deps):
                # All children handled; leave the current DFS path.
                stack.pop()
                onstack.discard(k)
                continue

            stack[-1] = (k, deps, idx + 1)
            child = deps[idx]
            kc = _identity_key(child)
            if kc in onstack:
                if detect_cycles:
                    raise ValueError("Cycle detected during dfs_pre_iter walk")
                # Not detecting: skip the back edge to avoid looping forever.
                continue
            if kc in seen:
                # Reached earlier via another path; already yielded.
                continue

            seen.add(kc)
            yield from _maybe_yield(child, yield_condition)
            stack.append((kc, list(get_deps(child)), 0))
            onstack.add(kc)
275
-
276
-
277
def _dfs_post_iter(
    roots: Sequence[Node],
    get_deps: GetDeps,
    yield_condition: YieldCond | None,
    detect_cycles: bool,
) -> Iterator[Node]:
    """Iterative depth-first walk yielding each node after all its dependencies.

    Each reachable node is yielded at most once (identity-based). A dependency
    that is still on the current DFS path is a back edge. It raises ValueError
    when ``detect_cycles`` is True and is skipped otherwise.

    ``get_deps`` is called exactly once per visited node, when the node is
    pushed. The resulting list is cached in its stack frame rather than
    recomputed every time the node returns to the top of the stack.
    """
    onstack: set[int] = set()
    done: set[int] = set()
    # Frame: (node, node key, deps materialized once, next child index).
    stack: list[tuple[Node, int, list[Node], int]] = []

    for root in roots:
        kroot = _identity_key(root)
        if kroot in done:
            continue
        stack.append((root, kroot, list(get_deps(root)), 0))
        onstack.add(kroot)

        while stack:
            node, k, deps, idx = stack[-1]
            if idx < len(deps):
                stack[-1] = (node, k, deps, idx + 1)
                child = deps[idx]
                kc = _identity_key(child)
                if kc in done:
                    continue
                if kc in onstack:
                    if detect_cycles:
                        raise ValueError("Cycle detected during dfs_post_iter walk")
                    # Not detecting: ignore the back edge to avoid looping forever.
                    continue
                stack.append((child, kc, list(get_deps(child)), 0))
                onstack.add(kc)
            else:
                # All children processed: emit in post-order.
                stack.pop()
                onstack.discard(k)
                if k not in done:
                    done.add(k)
                    yield from _maybe_yield(node, yield_condition)
322
-
323
-
324
def _collect_closure(roots: Sequence[Node], get_deps: GetDeps) -> list[Node]:
    """Collect every node reachable from ``roots`` through ``get_deps``.

    Nodes are deduplicated by identity and returned in breadth-first discovery
    order (roots first). Only the node list is returned: no edge or parent map
    is built, so callers that need edges must query ``get_deps`` again.
    """
    # This is exactly an unfiltered BFS; reuse it instead of duplicating it.
    return list(_bfs(roots, get_deps, None))
344
-
345
-
346
def _topo_kahn(
    roots: Sequence[Node],
    get_deps: GetDeps,
    yield_condition: YieldCond | None,
    detect_cycles: bool,
) -> Iterator[Node]:
    """Topological walk (Kahn's algorithm) over the closure of ``roots``.

    Edges point from a dependency to the node that uses it, so every node is
    yielded only after all of its dependencies have been yielded.

    A node that lists the same dependency more than once (e.g. ``f(x, x)``)
    contributes a single edge. Counting each occurrence would leave that
    node's in-degree permanently above zero, and it would be misreported as a
    cycle.

    Raises:
        ValueError: If ``detect_cycles`` is True and some reachable nodes could
            not be ordered (the dependency graph has a cycle).
    """
    nodes = _collect_closure(roots, get_deps)
    key2node: dict[int, Node] = {_identity_key(n): n for n in nodes}

    # dep key -> ordered set (dict keys) of consumer keys. A dict keeps
    # deduplication while giving a deterministic, discovery-based order,
    # unlike iterating a set of id() values.
    children: dict[int, dict[int, None]] = {}
    indeg: dict[int, int] = {}
    for n in nodes:
        kn = _identity_key(n)
        dep_keys = dict.fromkeys(_identity_key(d) for d in get_deps(n))
        indeg[kn] = len(dep_keys)
        for kd in dep_keys:
            children.setdefault(kd, {})[kn] = None

    # Nodes without dependencies are ready immediately. Every other node is
    # enqueued exactly once, when its last dependency is emitted.
    q: deque[int] = deque(k for k, v in indeg.items() if v == 0)
    produced = 0
    while q:
        k = q.popleft()
        produced += 1
        yield from _maybe_yield(key2node[k], yield_condition)
        for c in children.get(k, ()):
            indeg[c] -= 1
            if indeg[c] == 0:
                q.append(c)

    if detect_cycles and produced < len(nodes):
        raise ValueError("Cycle detected during topo walk")
mplang/v1/core/interp.py DELETED
@@ -1,160 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- Interpreter context and InterpVar implementation.
17
-
18
- This module provides the interpreter context for eager evaluation and InterpVar
19
- which references computed values in an interpreter.
20
- """
21
-
22
- from __future__ import annotations
23
-
24
- from abc import abstractmethod
25
- from collections.abc import Sequence
26
- from typing import Any, cast
27
-
28
- from mplang.v1.core.cluster import ClusterSpec
29
- from mplang.v1.core.expr.ast import Expr, VariableExpr
30
- from mplang.v1.core.mpobject import MPContext, MPObject
31
- from mplang.v1.core.mptype import MPType, TensorLike
32
- from mplang.v1.core.tracer import TracedFunction
33
- from mplang.v1.utils.func_utils import var_demorph, var_morph
34
-
35
-
36
# TODO(jint): Should we use inheritance or composition here?
class InterpContext(MPContext):
    """Context for eager evaluation using an interpreter.

    InterpContext executes computations immediately and stores results
    in an underlying interpreter. Concrete subclasses implement ``evaluate``
    and ``fetch``.
    """

    def __init__(
        self,
        cluster_spec: ClusterSpec,
    ):
        super().__init__(cluster_spec)

    @abstractmethod
    def evaluate(self, expr: Expr, bindings: dict[str, MPObject]) -> Sequence[MPObject]:
        """Evaluate an expression in this context.

        Args:
            expr: The expression to evaluate.
            bindings: Maps variable names to the MPObjects bound to them.

        Returns:
            A sequence of MPObjects holding the results of the evaluation
            (not a single MPObject).
        """
        raise NotImplementedError("Should be overridden in subclasses.")

    # NOTE(review): the contract below allows None entries, but the annotation
    # says list[TensorLike]. Consider list[TensorLike | None] once all
    # subclasses are updated (list is invariant, so changing it here alone
    # would break their overrides under a type checker).
    @abstractmethod
    def fetch(self, obj: MPObject) -> list[TensorLike]:
        """Fetch the value of an MPObject from this InterpContext to the current Python interpreter.

        The MPObject must have been created by this InterpContext. If the object
        was not produced by this context, a ValueError will be raised.

        Args:
            obj: The MPObject to fetch. Must be produced by this InterpContext.

        Returns:
            A list of tensor-like values with length equal to psize(). For each party i,
            if the i-th bit of obj.pmask is 0 (indicating party i does not hold this value),
            the i-th element in the returned list will be None. Otherwise, it contains
            the actual tensor value held by party i.

        Raises:
            ValueError: If obj was not produced by this InterpContext.
        """
        raise NotImplementedError("Should be overridden in subclasses.")
83
-
84
-
85
class InterpVar(MPObject):
    """Handle to a value that already lives in an interpreter's variable store.

    The handle records only the owning InterpContext and the value's MPType;
    it does not hold the value itself.
    """

    def __init__(self, ctx: InterpContext, mptype: MPType):
        self._ctx, self._mptype = ctx, mptype

    @property
    def ctx(self) -> MPContext:
        """Context that owns the referenced value."""
        return self._ctx

    @property
    def mptype(self) -> MPType:
        """MP type recorded when this handle was created."""
        # TODO: fetch type from the Interpreter and cache it.
        return self._mptype

    def __repr__(self) -> str:
        return f"InterpVar(mptype={self.mptype})"
109
-
110
-
111
def apply(ctx: InterpContext, fn: TracedFunction, *args: Any, **kwargs: Any) -> Any:
    """Run a traced function eagerly on InterpVars owned by ``ctx``.

    Before evaluating, the call is checked against the trace. The checks cover
    the argument structure, the immediates, the number and MPTypes of the
    variable arguments, and that all inputs and captured variables are
    InterpVars of ``ctx``. The function body is evaluated only if it has
    outputs.

    Raises:
        ValueError: If any of the checks above fails.
    """

    def _is_mpobject(x: Any) -> bool:
        return isinstance(x, MPObject)

    in_args, in_imms, in_struct = var_morph((args, kwargs), _is_mpobject)

    # Every variable argument must be an InterpVar owned by this very context.
    if any(not isinstance(v, InterpVar) or v.ctx is not ctx for v in in_args):
        raise ValueError("All input variables must be InterpVars in the same context.")

    # The call must match what was traced.
    if fn.in_struct != in_struct:
        raise ValueError(f"Input structure mismatch: {fn.in_struct} != {in_struct}")
    if fn.in_imms != in_imms:
        # Different immediates would need a re-trace (as JAX would do).
        raise ValueError(f"Input immutables mismatch: {fn.in_imms} != {in_imms}")
    if len(fn.in_vars) != len(in_args):
        raise ValueError(f"Input types mismatch: {fn.in_vars} != {in_args}")
    for param, arg in zip(fn.in_vars, in_args, strict=False):
        if param.mptype != arg.mptype:
            raise ValueError(
                f"Input variable type mismatch: {param.mptype} != {arg.mptype}"
            )

    # Captured variables must come from this context as well.
    for captured, _ in fn.capture_map.items():
        if not isinstance(captured, InterpVar) or captured.ctx is not ctx:
            raise ValueError(
                f"Capture {captured} must be in this({ctx}) context, got {captured.ctx}."
            )

    # Name -> value bindings: parameters first, then captures (captures win
    # on a name clash).
    bindings: dict[str, MPObject] = {
        cast(VariableExpr, var.expr).name: obj
        for var, obj in zip(fn.in_vars, in_args, strict=False)
    }
    bindings.update(
        {
            cast(VariableExpr, var.expr).name: captured
            for captured, var in fn.capture_map.items()
        }
    )

    if len(fn.out_vars) > 0:
        func_expr = fn.make_expr()
        assert func_expr is not None, "Function expression should not be None."
        out_vars: list[MPObject] = list(ctx.evaluate(func_expr.body, bindings))
    else:
        out_vars = []

    assert isinstance(out_vars, list), f"Expected list, got {type(out_vars)}"
    return var_demorph(out_vars, fn.out_imms, fn.out_struct)