mplang-nightly 0.1.dev268__py3-none-any.whl → 0.1.dev270__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +46 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +19 -13
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/libs/mpc/psi/rr22.py +303 -0
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang/v2/libs/mpc/psi/rr22.py +0 -344
- mplang_nightly-0.1.dev268.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
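Taken together, the listing above describes two coordinated moves: the entire mplang/v1 tree is deleted, and the former mplang/v2 packages are promoted to the package root. As a rough illustration only, the before/after import sketch below is inferred from those renames; the module paths are taken from the file list, but whether a given project imports them this way is an assumption.

    # dev268: v2 modules lived under the mplang.v2 namespace.
    # from mplang.v2.edsl import tracer
    # from mplang.v2.libs.mpc.psi import rr22

    # dev270: the same modules sit directly under the package root.
    from mplang.edsl import tracer
    from mplang.libs.mpc.psi import rr22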
mplang/v1/core/expr/walk.py
DELETED
@@ -1,387 +0,0 @@
-# Copyright 2025 Ant Group Co., Ltd.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Pure functional traversal utilities for MPLang expression graphs.
-
-This module provides semantic-agnostic walkers over expression graphs. It exposes
-both dataflow-only traversal (deps edges) and structural traversal (deps +
-contained regions like function bodies, then/else, loop body). All traversals are
-implemented iteratively to avoid Python recursion limits.
-
-Notes
-- These walkers never evaluate expressions nor decide runtime branches. For
-  execution order, use an evaluator/driver that consults semantic rules.
-- Topological order is produced w.r.t. deps edges, not region containment.
-"""
-
-from __future__ import annotations
-
-from collections import deque
-from collections.abc import Callable, Iterable, Iterator, Sequence
-from typing import cast
-
-from mplang.v1.core.expr.ast import (
-    AccessExpr,
-    CallExpr,
-    CondExpr,
-    ConvExpr,
-    EvalExpr,
-    Expr,
-    FuncDefExpr,
-    ShflExpr,
-    ShflSExpr,
-    TupleExpr,
-    VariableExpr,
-    WhileExpr,
-)
-
-Node = Expr
-GetDeps = Callable[[Node], Iterable[Node]]
-YieldCond = Callable[[Node], bool]
-
-
-def _identity_key(n: Node) -> int:
-    """Identity-based key for hashing nodes that may not be hashable."""
-    return id(n)
-
-
-# ---------------------------- default dependency getters ----------------------------
-
-
-def dataflow_deps(node: Node) -> Iterable[Node]:
-    """Default dataflow dependencies for core Expr nodes.
-
-    This includes only inputs that must be computed to evaluate the node itself.
-    Contained regions (function/branch/loop bodies) are NOT traversed here.
-    """
-    if isinstance(node, EvalExpr):
-        return node.args
-    if isinstance(node, TupleExpr):
-        return node.args
-    if isinstance(node, CondExpr):
-        # Pure dataflow: pred and actual args to branch functions
-        return [node.pred, *node.args]
-    if isinstance(node, WhileExpr):
-        # Initial state args required; bodies are regions, not dataflow deps
-        return list(node.args)
-    if isinstance(node, ConvExpr):
-        return node.vars
-    if isinstance(node, ShflSExpr):
-        return [node.src_val]
-    if isinstance(node, ShflExpr):
-        return [node.src, node.index]
-    if isinstance(node, AccessExpr):
-        return [node.src]
-    if isinstance(node, VariableExpr):
-        return []
-    if isinstance(node, FuncDefExpr):
-        # Definition is not a value-producing node in dataflow; no deps
-        return []
-    if isinstance(node, CallExpr):
-        # Arguments are dataflow deps; function body is a region
-        return node.args
-    # Fallback: try best-effort empty
-    return []
-
-
-def _structural_region_roots(node: Node) -> Iterable[Node]:
-    """Roots of contained regions for structural traversal.
-
-    - Cond: then_fn.body, else_fn.body
-    - While: cond_fn.body, body_fn.body
-    - Call: fn.body
-    - FuncDef: body
-    """
-    if isinstance(node, CondExpr):
-        return [node.then_fn.body, node.else_fn.body]
-    if isinstance(node, WhileExpr):
-        return [node.cond_fn.body, node.body_fn.body]
-    if isinstance(node, CallExpr):
-        return [node.fn.body]
-    if isinstance(node, FuncDefExpr):
-        return [node.body]
-    return []
-
-
-# ---------------------------------- core walkers ----------------------------------
-
-
-def walk(
-    roots: Node | Sequence[Node],
-    *,
-    get_deps: GetDeps,
-    traversal: str = "dfs_post_iter",
-    yield_condition: YieldCond | None = None,
-    detect_cycles: bool = True,
-) -> Iterator[Node]:
-    """Generic pure structural walker.
-
-    Args:
-        roots: Single root or a sequence of roots to start from.
-        get_deps: Function mapping node -> iterable of dependency nodes.
-        traversal: One of {'dfs_pre_iter','dfs_post_iter','bfs','topo'}.
-        yield_condition: Optional predicate to filter yielded nodes.
-        detect_cycles: If True, raises ValueError on cycles (for DFS/topo).
-
-    Yields:
-        Nodes in the chosen traversal order, filtered by yield_condition.
-    """
-    start: list[Node]
-    if isinstance(roots, (list, tuple)):
-        start = list(roots)
-    else:
-        start = [cast(Node, roots)]
-
-    if traversal == "bfs":
-        yield from _bfs(start, get_deps, yield_condition)
-        return
-    if traversal == "dfs_pre_iter":
-        yield from _dfs_pre_iter(start, get_deps, yield_condition, detect_cycles)
-        return
-    if traversal == "dfs_post_iter":
-        yield from _dfs_post_iter(start, get_deps, yield_condition, detect_cycles)
-        return
-    if traversal == "topo":
-        yield from _topo_kahn(start, get_deps, yield_condition, detect_cycles)
-        return
-
-    raise ValueError(f"Invalid traversal type: {traversal}")
-
-
-def walk_dataflow(
-    roots: Node | Sequence[Node],
-    *,
-    traversal: str = "dfs_post_iter",
-    yield_condition: YieldCond | None = None,
-    detect_cycles: bool = True,
-) -> Iterator[Node]:
-    """Walk using default dataflow dependencies for Expr nodes."""
-    return walk(
-        roots,
-        get_deps=dataflow_deps,
-        traversal=traversal,
-        yield_condition=yield_condition,
-        detect_cycles=detect_cycles,
-    )
-
-
-def walk_structural(
-    roots: Node | Sequence[Node],
-    *,
-    traversal: str = "dfs_post_iter",
-    yield_condition: YieldCond | None = None,
-    detect_cycles: bool = True,
-) -> Iterator[Node]:
-    """Walk including region containment (function bodies, branches, loop bodies).
-
-    This augments dataflow dependencies with region roots once, so structure is
-    fully traversed without runtime branch choices or loop iteration expansion.
-    """
-
-    def deps_plus_regions(n: Node) -> Iterable[Node]:
-        yield from dataflow_deps(n)
-        yield from _structural_region_roots(n)
-
-    return walk(
-        roots,
-        get_deps=deps_plus_regions,
-        traversal=traversal,
-        yield_condition=yield_condition,
-        detect_cycles=detect_cycles,
-    )
-
-
-# -------------------------------- traversal engines --------------------------------
-
-
-def _maybe_yield(n: Node, pred: YieldCond | None) -> Iterator[Node]:
-    if pred is None or pred(n):
-        yield n
-
-
-def _bfs(
-    roots: Sequence[Node],
-    get_deps: GetDeps,
-    yield_condition: YieldCond | None,
-) -> Iterator[Node]:
-    seen: set[int] = set()
-    q: deque[Node] = deque(roots)
-    while q:
-        n = q.popleft()
-        k = _identity_key(n)
-        if k in seen:
-            continue
-        seen.add(k)
-        yield from _maybe_yield(n, yield_condition)
-        for d in get_deps(n):
-            q.append(d)
-
-
-def _dfs_pre_iter(
-    roots: Sequence[Node],
-    get_deps: GetDeps,
-    yield_condition: YieldCond | None,
-    detect_cycles: bool,
-) -> Iterator[Node]:
-    seen: set[int] = set()
-    onstack: set[int] = set()
-    stack: list[tuple[Node, int]] = []  # (node, next_child_index)
-
-    for root in roots:
-        kroot = _identity_key(root)
-        if kroot in seen:
-            continue
-        stack.append((root, 0))
-        onstack.add(kroot)
-
-        while stack:
-            node, idx = stack[-1]
-            k = _identity_key(node)
-            if k not in seen:
-                # Pre-order yield on first encounter
-                seen.add(k)
-                yield from _maybe_yield(node, yield_condition)
-
-            deps = list(get_deps(node))
-            if idx < len(deps):
-                child = deps[idx]
-                stack[-1] = (node, idx + 1)
-                kc = _identity_key(child)
-                if kc in onstack:
-                    if detect_cycles:
-                        raise ValueError("Cycle detected during dfs_pre_iter walk")
-                    # skip on cycles if not detecting
-                elif kc not in seen:
-                    stack.append((child, 0))
-                    onstack.add(kc)
-                else:
-                    # already processed child
-                    pass
-            else:
-                stack.pop()
-                onstack.discard(k)
-
-
-def _dfs_post_iter(
-    roots: Sequence[Node],
-    get_deps: GetDeps,
-    yield_condition: YieldCond | None,
-    detect_cycles: bool,
-) -> Iterator[Node]:
-    seen: set[int] = set()
-    onstack: set[int] = set()
-    done: set[int] = set()
-    stack: list[tuple[Node, int]] = []  # (node, next_child_index)
-
-    for root in roots:
-        kroot = _identity_key(root)
-        if kroot in done:
-            continue
-        stack.append((root, 0))
-        onstack.add(kroot)
-
-        while stack:
-            node, idx = stack[-1]
-            k = _identity_key(node)
-            deps = list(get_deps(node))
-            if k not in seen:
-                seen.add(k)
-
-            if idx < len(deps):
-                child = deps[idx]
-                stack[-1] = (node, idx + 1)
-                kc = _identity_key(child)
-                if kc in done:
-                    continue
-                if kc in onstack:
-                    if detect_cycles:
-                        raise ValueError("Cycle detected during dfs_post_iter walk")
-                    # else ignore to avoid infinite loop
-                    continue
-                stack.append((child, 0))
-                onstack.add(kc)
-            else:
-                # all children processed
-                stack.pop()
-                onstack.discard(k)
-                if k not in done:
-                    done.add(k)
-                    yield from _maybe_yield(node, yield_condition)
-
-
-def _collect_closure(roots: Sequence[Node], get_deps: GetDeps) -> list[Node]:
-    """Collect reachable nodes and adjacency (by identity) for topo sort.
-
-    Returns (nodes_list, parents_map) where parents_map[v] is the set of keys of
-    dependency nodes for v (edges dep -> v).
-    """
-    nodes: list[Node] = []
-    seen: set[int] = set()
-
-    q: deque[Node] = deque(roots)
-    while q:
-        n = q.popleft()
-        kn = _identity_key(n)
-        if kn in seen:
-            continue
-        seen.add(kn)
-        nodes.append(n)
-        for d in get_deps(n):
-            q.append(d)
-    return nodes
-
-
-def _topo_kahn(
-    roots: Sequence[Node],
-    get_deps: GetDeps,
-    yield_condition: YieldCond | None,
-    detect_cycles: bool,
-) -> Iterator[Node]:
-    # Build closure and in-degree from deps edges (dep -> node)
-    nodes = _collect_closure(roots, get_deps)
-
-    # reverse map: parent -> set(children)
-    children: dict[int, set[int]] = {}
-    indeg: dict[int, int] = {}
-    for n in nodes:
-        kn = _identity_key(n)
-        indeg.setdefault(kn, 0)
-        for d in get_deps(n):
-            kd = _identity_key(d)
-            children.setdefault(kd, set()).add(kn)
-            indeg[kn] = indeg.get(kn, 0) + 1
-            indeg.setdefault(kd, indeg.get(kd, 0))
-
-    # queue of zero in-degree nodes (ready after all deps)
-    q: deque[int] = deque(k for k, v in indeg.items() if v == 0)
-    key2node: dict[int, Node] = {_identity_key(n): n for n in nodes}
-    produced = 0
-    seen_keys: set[int] = set()
-
-    while q:
-        k = q.popleft()
-        if k in seen_keys:
-            continue
-        seen_keys.add(k)
-        n = key2node[k]
-        produced += 1
-        yield from _maybe_yield(n, yield_condition)
-        for c in children.get(k, set()):
-            indeg[c] -= 1
-            if indeg[c] == 0:
-                q.append(c)
-
-    if detect_cycles and produced < len(indeg):
-        raise ValueError("Cycle detected during topo walk")
mplang/v1/core/interp.py
DELETED
@@ -1,160 +0,0 @@
-# Copyright 2025 Ant Group Co., Ltd.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Interpreter context and InterpVar implementation.
-
-This module provides the interpreter context for eager evaluation and InterpVar
-which references computed values in an interpreter.
-"""
-
-from __future__ import annotations
-
-from abc import abstractmethod
-from collections.abc import Sequence
-from typing import Any, cast
-
-from mplang.v1.core.cluster import ClusterSpec
-from mplang.v1.core.expr.ast import Expr, VariableExpr
-from mplang.v1.core.mpobject import MPContext, MPObject
-from mplang.v1.core.mptype import MPType, TensorLike
-from mplang.v1.core.tracer import TracedFunction
-from mplang.v1.utils.func_utils import var_demorph, var_morph
-
-
-# TODO(jint): Should we use inheritance or composition here?
-class InterpContext(MPContext):
-    """Context for eager evaluation using an interpreter.
-
-    InterpContext executes computations immediately and stores results
-    in an underlying interpreter.
-    """
-
-    def __init__(
-        self,
-        cluster_spec: ClusterSpec,
-    ):
-        super().__init__(cluster_spec)
-
-    @abstractmethod
-    def evaluate(self, expr: Expr, bindings: dict[str, MPObject]) -> Sequence[MPObject]:
-        """Evaluate an expression in this context.
-
-        Args:
-            expr: The expression to evaluate.
-            bindings: A dictionary of variable bindings.
-
-        Returns:
-            The result of the evaluation as an MPObject.
-        """
-        raise NotImplementedError("Should be overridden in subclasses.")
-
-    @abstractmethod
-    def fetch(self, obj: MPObject) -> list[TensorLike]:
-        """Fetch the value of an MPObject from this InterpContext to the current Python interpreter.
-
-        The MPObject must have been created by this InterpContext. If the object
-        was not produced by this context, a ValueError will be raised.
-
-        Args:
-            obj: The MPObject to fetch. Must be produced by this InterpContext.
-
-        Returns:
-            A list of tensor-like values with length equal to psize(). For each party i,
-            if the i-th bit of obj.pmask is 0 (indicating party i does not hold this value),
-            the i-th element in the returned list will be None. Otherwise, it contains
-            the actual tensor value held by party i.
-
-        Raises:
-            ValueError: If obj was not produced by this InterpContext.
-        """
-        raise NotImplementedError("Should be overridden in subclasses.")
-
-
-class InterpVar(MPObject):
-    """A variable that references a value in an interpreter.
-
-    InterpVar represents a value that has been computed and exists
-    in the interpreter's variable store.
-    """
-
-    def __init__(self, ctx: InterpContext, mptype: MPType):
-        self._ctx = ctx
-        self._mptype = mptype
-
-    @property
-    def ctx(self) -> MPContext:
-        """The context this variable belongs to."""
-        return self._ctx
-
-    @property
-    def mptype(self) -> MPType:
-        """The type of this variable."""
-        # TODO: fetch type from the Interpreter and cache it.
-        return self._mptype
-
-    def __repr__(self) -> str:
-        return f"InterpVar(mptype={self.mptype})"
-
-
-def apply(ctx: InterpContext, fn: TracedFunction, *args: Any, **kwargs: Any) -> Any:
-    is_mpobj = lambda x: isinstance(x, MPObject)
-    in_args, in_imms, in_struct = var_morph((args, kwargs), is_mpobj)
-
-    # All variables must be in the same context as the function.
-    if not all(isinstance(var, InterpVar) and var.ctx is ctx for var in in_args):
-        raise ValueError("All input variables must be InterpVars in the same context.")
-
-    # Check if the function signature matches the input types.
-    if fn.in_struct != in_struct:
-        raise ValueError(f"Input structure mismatch: {fn.in_struct} != {in_struct}")
-    if fn.in_imms != in_imms:
-        # Should trigger re-trace in JAX
-        raise ValueError(f"Input immutables mismatch: {fn.in_imms} != {in_imms}")
-    if len(fn.in_vars) != len(in_args):
-        raise ValueError(f"Input types mismatch: {fn.in_vars} != {in_args}")
-    # check parameter type match
-    for param, arg in zip(fn.in_vars, in_args, strict=False):
-        if param.mptype != arg.mptype:
-            raise ValueError(
-                f"Input variable type mismatch: {param.mptype} != {arg.mptype}"
-            )
-
-    # Prepare for the captured variables, which should also be in the same context.
-    for captured, _traced in fn.capture_map.items():
-        if not isinstance(captured, InterpVar) or captured.ctx is not ctx:
-            raise ValueError(
-                f"Capture {captured} must be in this({ctx}) context, got {captured.ctx}."
-            )
-
-    arg_binding: dict[str, MPObject] = {
-        cast(VariableExpr, var.expr).name: obj
-        for var, obj in zip(fn.in_vars, in_args, strict=False)
-    }
-    capture_binding = {
-        cast(VariableExpr, var.expr).name: captured
-        for captured, var in fn.capture_map.items()
-    }
-
-    if len(fn.out_vars) == 0:
-        out_vars: list[MPObject] = []
-    else:
-        func_expr = fn.make_expr()
-        assert func_expr is not None, "Function expression should not be None."
-        out_vars = list(
-            ctx.evaluate(func_expr.body, {**arg_binding, **capture_binding})
-        )
-
-    assert isinstance(out_vars, list), f"Expected list, got {type(out_vars)}"
-    return var_demorph(out_vars, fn.out_imms, fn.out_struct)