mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,871 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Interpreter: Execute Graph IR and Eager Operations.
|
|
16
|
+
|
|
17
|
+
Interpreter is a Context that executes operations immediately.
|
|
18
|
+
It can execute both:
|
|
19
|
+
1. Graph IR (via GraphInterpreter)
|
|
20
|
+
2. Eager operations on InterpObject (via backend executors)
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
import collections
|
|
26
|
+
import concurrent.futures
|
|
27
|
+
import json
|
|
28
|
+
import os
|
|
29
|
+
import pathlib
|
|
30
|
+
import queue
|
|
31
|
+
import threading
|
|
32
|
+
import time
|
|
33
|
+
from collections.abc import Callable
|
|
34
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
35
|
+
|
|
36
|
+
from mplang.v2.edsl.context import AbstractInterpreter
|
|
37
|
+
from mplang.v2.edsl.graph import Graph
|
|
38
|
+
from mplang.v2.edsl.object import Object
|
|
39
|
+
from mplang.v2.edsl.registry import get_impl
|
|
40
|
+
from mplang.v2.edsl.typing import BaseType
|
|
41
|
+
from mplang.v2.runtime.dialect_state import DialectState
|
|
42
|
+
from mplang.v2.runtime.object_store import ObjectStore
|
|
43
|
+
|
|
44
|
+
if TYPE_CHECKING:
|
|
45
|
+
from mplang.v2.edsl.primitive import Primitive
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class ExecutionTracer:
    """Collects DAG execution events in the Chrome Tracing format.

    Events accumulate in :attr:`trace_events` and are flushed to a JSON file
    under :attr:`trace_dir` by :meth:`save_trace`. Most ``log_*`` methods are
    no-ops while :attr:`enabled` is False.
    """

    def __init__(self, enabled: bool = False, *, trace_dir: str | pathlib.Path):
        # Master switch: when False, the log_* methods return immediately.
        self.enabled = enabled
        self.start_time = 0.0
        self.end_time = 0.0
        # (elapsed-seconds, count) samples of scheduler load over time.
        self.active_tasks_samples: list[tuple[float, int]] = []
        self.queue_size_samples: list[tuple[float, int]] = []
        self.completed_ops = 0
        self.total_ops = 0
        self.trace_dir = pathlib.Path(trace_dir)

        # Chrome-trace event records plus scheduling bookkeeping.
        self.trace_events: list[dict[str, Any]] = []
        # (id(op), namespace) -> timestamp in microseconds when op was queued.
        self.op_schedule_times: dict[tuple[int, Any], float] = {}
        self.pid = os.getpid()

    def start(self) -> None:
        """Mark the beginning of the profiled execution window."""
        self.start_time = time.time()

    def stop(self, filename_prefix: str = "dag_trace") -> None:
        """Mark the end of execution and flush accumulated events to disk."""
        self.end_time = time.time()
        self.save_trace(filename_prefix)

    def sample(self, active_tasks: int, queue_size: int) -> None:
        """Record one load sample, timestamped relative to :meth:`start`."""
        elapsed = time.time() - self.start_time
        self.active_tasks_samples.append((elapsed, active_tasks))
        self.queue_size_samples.append((elapsed, queue_size))

    def log_schedule(self, op: Any, namespace: Any = None) -> None:
        """Remember when *op* entered the ready queue (microseconds)."""
        if not self.enabled:
            return
        self.op_schedule_times[(id(op), namespace)] = time.time() * 1e6

    def log_start(
        self, op: Any, pid: int | None = None, namespace: Any = None
    ) -> float:
        """Record queue latency for *op*; return its start timestamp (us).

        Returns 0.0 when tracing is disabled.
        """
        if not self.enabled:
            return 0.0
        begin_us = time.time() * 1e6
        if pid is None:
            pid = self.pid

        # If the op was previously scheduled, emit a "Queue" span covering
        # the time it spent waiting in the scheduler queue.
        sched_us = self.op_schedule_times.pop((id(op), namespace), None)
        if sched_us is not None:
            self.trace_events.append({
                "name": f"Queue: {op.opcode}",
                "cat": "scheduler",
                "ph": "X",
                "ts": sched_us,
                "dur": begin_us - sched_us,
                "pid": pid,
                "tid": "SchedulerQueue",
            })
        return begin_us

    def log_end(self, op: Any, start_ts: float, pid: int | None = None) -> None:
        """Emit a complete ("X") event for an op that finished executing."""
        if not self.enabled:
            return
        finish_us = time.time() * 1e6
        worker_tid = threading.get_ident()
        if pid is None:
            pid = self.pid

        self.trace_events.append({
            "name": op.opcode,
            "cat": "op",
            "ph": "X",
            "ts": start_ts,
            "dur": finish_us - start_ts,
            "pid": pid,
            "tid": worker_tid,
            "args": {
                "opcode": op.opcode,
            },
        })

    def log_custom_event(
        self,
        name: str,
        start_ts: float,
        end_ts: float,
        cat: str = "custom",
        args: dict[str, Any] | None = None,
    ) -> None:
        """Log a custom event with explicit start/end timestamps (in seconds)."""
        if not self.enabled:
            return
        worker_tid = threading.get_ident()

        # Chrome tracing expects microseconds.
        self.trace_events.append({
            "name": name,
            "cat": cat,
            "ph": "X",
            "ts": start_ts * 1e6,
            "dur": (end_ts - start_ts) * 1e6,
            "pid": self.pid,
            "tid": worker_tid,
            "args": args or {},
        })

    def save_trace(
        self,
        filename_prefix: str = "dag_trace",
        job_id: str | None = None,
        rank: int | None = None,
    ) -> None:
        """Write accumulated events as a Chrome-trace JSON under trace_dir.

        Best-effort: failures are reported but never propagated, so tracing
        can never break the execution it observes. Traces with fewer than
        100 events are deliberately skipped.
        """
        if not self.enabled or not self.trace_events:
            return
        try:
            if len(self.trace_events) < 100:
                return  # Skip small graphs

            # Pick a unique filename so concurrent runs don't clobber
            # each other's output.
            if job_id:
                # Format: trace_<job_id>_rank_<rank>.json
                suffix = f"_rank_{rank}" if rank is not None else ""
                filename = f"trace_{job_id}{suffix}.json"
            else:
                stamp_ms = int(time.time() * 1000)
                worker_tid = threading.get_ident()
                filename = f"{filename_prefix}_{stamp_ms}_{worker_tid}.json"

            self.trace_dir.mkdir(parents=True, exist_ok=True)
            filepath = self.trace_dir / filename

            with open(filepath, "w") as f:
                json.dump({"traceEvents": self.trace_events}, f)
            print(f"\n[Tracer] Trace saved to {filepath.absolute()}")
        except Exception as e:
            print(f"[Tracer] Failed to save trace: {e}")

    def print_summary(self) -> None:
        """Print aggregate throughput and scheduler-load statistics."""
        duration = self.end_time - self.start_time
        if duration <= 0:
            return

        def _mean(samples: list[tuple[float, int]]) -> float:
            # Average of the sampled counts; 0 when nothing was sampled.
            return sum(c for _, c in samples) / len(samples) if samples else 0

        avg_active = _mean(self.active_tasks_samples)
        max_active = (
            max(c for _, c in self.active_tasks_samples)
            if self.active_tasks_samples
            else 0
        )
        avg_queue = _mean(self.queue_size_samples)

        print("\n" + "=" * 80)
        print("DAG EXECUTION PROFILER")
        print("=" * 80)
        print(f"Total Duration: {duration:.3f}s")
        print(f"Total Ops: {self.total_ops}")
        print(f"Throughput: {self.total_ops / duration:.1f} ops/s")
        print("-" * 80)
        print(f"Active Tasks: Avg={avg_active:.1f}, Max={max_active}")
        print(f"Ready Queue: Avg={avg_queue:.1f}")
        print("=" * 80 + "\n")
|
|
224
|
+
|
|
225
|
+
class _NullTracer:
|
|
226
|
+
"""No-op tracer stub for when tracing is disabled."""
|
|
227
|
+
|
|
228
|
+
enabled = False
|
|
229
|
+
total_ops = 0
|
|
230
|
+
|
|
231
|
+
def log_schedule(self, op: Any, namespace: Any = None) -> None:
|
|
232
|
+
pass
|
|
233
|
+
|
|
234
|
+
def log_start(
|
|
235
|
+
self, op: Any, pid: int | None = None, namespace: Any = None
|
|
236
|
+
) -> float:
|
|
237
|
+
return 0.0
|
|
238
|
+
|
|
239
|
+
def log_end(self, op: Any, start_ts: float, pid: int | None = None) -> None:
|
|
240
|
+
pass
|
|
241
|
+
|
|
242
|
+
def stop(self) -> None:
|
|
243
|
+
pass
|
|
244
|
+
|
|
245
|
+
def save_trace(self, **kwargs: Any) -> None:
|
|
246
|
+
pass
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
class InterpObject(Object):
    """Interp-time object (during eager execution).

    Pairs a backend-owned runtime object (the actual data or handle) with
    the :class:`Interpreter` context that produced it. Operator overloads
    delegate to primitives, which execute immediately under that context.

    Depending on the backend, the runtime object may be:
      - FHE backend: Local TenSEAL/SEAL ciphertext
      - JAX backend: Local jax.Array
      - MP backend: Backend handle (pointer to party-side data)
      - SQL backend: DatabaseHandle
      - etc.

    Example:
        >>> # FHE backend (local execution)
        >>> x = fhe.encrypt([1, 2, 3])  # InterpObject with local ciphertext
        >>> y = fhe.encrypt([4, 5, 6])
        >>> z = x + y  # InterpObject.__add__ → add_p.bind(x, y)

        >>> # MP backend (distributed execution)
        >>> x = mp.random.uniform(shape=(10,))  # InterpObject with backend handle
        >>> y = mp.random.uniform(shape=(10,))
        >>> z = x + y  # InterpObject.__add__ → add_p.bind(x, y)
    """

    def __init__(
        self,
        runtime_obj: Any,
        obj_type: BaseType,
        interpreter: Interpreter | None = None,
    ):
        """Initialize InterpObject.

        Args:
            runtime_obj: Backend-specific runtime object (ciphertext, array, handle, etc.)
            obj_type: Type of the object (BaseType from edsl.typing)
            interpreter: Interpreter context (if None, uses default interpreter)
        """
        self._runtime_obj = runtime_obj
        self._type = obj_type
        # The owning Interpreter (Context); checked during lift().
        self._context = interpreter

    @property
    def type(self) -> BaseType:
        """Static type of this object."""
        return self._type

    @property
    def runtime_obj(self) -> Any:
        """Get the underlying runtime object (backend-specific)."""
        return self._runtime_obj

    def __repr__(self) -> str:
        shown = repr(self._runtime_obj)
        # Keep reprs short: cut long payloads at 50 characters.
        if len(shown) > 50:
            shown = shown[:47] + "..."
        return f"InterpObject({shown}, type={self.type})"
308
|
+
|
|
309
|
+
class Interpreter(AbstractInterpreter):
|
|
310
|
+
"""Execution context for eager execution.
|
|
311
|
+
|
|
312
|
+
Inherits from Context and implements bind_primitive() by executing immediately.
|
|
313
|
+
|
|
314
|
+
Responsibilities:
|
|
315
|
+
1. Execute primitives on InterpObject immediately
|
|
316
|
+
2. Delegate to backend-specific executors
|
|
317
|
+
3. Execute Graph IR (via GraphInterpreter)
|
|
318
|
+
|
|
319
|
+
Example:
|
|
320
|
+
>>> interp = Interpreter()
|
|
321
|
+
>>> x = InterpObject(np.array([1, 2, 3]), Tensor[f32, (3,)])
|
|
322
|
+
>>> y = InterpObject(np.array([4, 5, 6]), Tensor[f32, (3,)])
|
|
323
|
+
>>> z = x + y # InterpObject.__add__ → add_p.bind(x, y)
|
|
324
|
+
"""
|
|
325
|
+
|
|
326
|
+
def __init__(
|
|
327
|
+
self,
|
|
328
|
+
executor: concurrent.futures.Executor | None = None,
|
|
329
|
+
name: str = "Interpreter",
|
|
330
|
+
tracer: ExecutionTracer | None = None,
|
|
331
|
+
trace_pid: int | None = None,
|
|
332
|
+
store: ObjectStore | None = None,
|
|
333
|
+
root_dir: str | pathlib.Path | None = None,
|
|
334
|
+
handlers: dict[str, Callable[..., Any]] | None = None,
|
|
335
|
+
) -> None:
|
|
336
|
+
# Persistence Root
|
|
337
|
+
self.root_dir = (
|
|
338
|
+
pathlib.Path(root_dir)
|
|
339
|
+
if root_dir
|
|
340
|
+
else pathlib.Path(os.environ.get("MPLANG_DATA_ROOT", ".mpl"))
|
|
341
|
+
)
|
|
342
|
+
|
|
343
|
+
# Initialize Context base class (for state management)
|
|
344
|
+
super().__init__()
|
|
345
|
+
|
|
346
|
+
# Instance-level handler registry (overrides global registry)
|
|
347
|
+
self.handlers: dict[str, Callable] = handlers or {}
|
|
348
|
+
self.tracer = tracer
|
|
349
|
+
|
|
350
|
+
# GraphValue -> InterpObject cache
|
|
351
|
+
# Maps a GraphValue (IR node) to its computed InterpObject (Runtime result).
|
|
352
|
+
# This serves two purposes:
|
|
353
|
+
# 1. Caching: Avoid re-evaluating the same graph node multiple times.
|
|
354
|
+
# 2. MIMO Optimization: When one output of a multi-output op is computed,
|
|
355
|
+
# all sibling outputs are cached here to avoid re-execution.
|
|
356
|
+
self._execution_cache: dict[Any, InterpObject] = {}
|
|
357
|
+
self.executor = executor
|
|
358
|
+
self.async_ops: set[str] = set()
|
|
359
|
+
self.name = name
|
|
360
|
+
self.trace_pid = trace_pid
|
|
361
|
+
self.store: ObjectStore | None = store
|
|
362
|
+
|
|
363
|
+
def shutdown(self) -> None:
|
|
364
|
+
"""Shutdown the interpreter and release resources.
|
|
365
|
+
|
|
366
|
+
This method is idempotent and safe to call multiple times.
|
|
367
|
+
It performs the following cleanup:
|
|
368
|
+
1. Shuts down the internal executor (if any).
|
|
369
|
+
2. Stops the execution tracer (if any).
|
|
370
|
+
3. Shuts down any attached dialect states (e.g., stopping drivers).
|
|
371
|
+
"""
|
|
372
|
+
# 1. Shutdown Executor
|
|
373
|
+
if self.executor:
|
|
374
|
+
self.executor.shutdown(wait=True)
|
|
375
|
+
self.executor = None
|
|
376
|
+
|
|
377
|
+
# 2. Stop Tracer
|
|
378
|
+
if self.tracer:
|
|
379
|
+
self.tracer.stop()
|
|
380
|
+
# Don't clear self.tracer, as we might want to read stats later
|
|
381
|
+
|
|
382
|
+
# 3. Shutdown Dialect States
|
|
383
|
+
# Iterate over all attached states (e.g., drivers, cluster managers)
|
|
384
|
+
# and shut them down if they support it.
|
|
385
|
+
for state in self._states.values():
|
|
386
|
+
if hasattr(state, "shutdown") and callable(state.shutdown):
|
|
387
|
+
state.shutdown()
|
|
388
|
+
|
|
389
|
+
# =========================================================================
|
|
390
|
+
# Dialect State Management
|
|
391
|
+
# =========================================================================
|
|
392
|
+
def get_dialect_state(self, dialect: str) -> DialectState | None:
|
|
393
|
+
"""Get the state object for a specific dialect.
|
|
394
|
+
|
|
395
|
+
This is a convenience wrapper around get_state("dialect.{dialect}").
|
|
396
|
+
|
|
397
|
+
Args:
|
|
398
|
+
dialect: Name of the dialect (e.g., "simp", "bfv", "spu")
|
|
399
|
+
|
|
400
|
+
Returns:
|
|
401
|
+
The dialect state object, or None if not set.
|
|
402
|
+
|
|
403
|
+
Example:
|
|
404
|
+
simp_state = interpreter.get_dialect_state("simp")
|
|
405
|
+
if simp_state is not None:
|
|
406
|
+
simp_state.submit(rank, graph, inputs)
|
|
407
|
+
"""
|
|
408
|
+
state = self.get_state(f"dialect.{dialect}")
|
|
409
|
+
# Type assertion: dialect states are always DialectState or None
|
|
410
|
+
return cast(DialectState | None, state)
|
|
411
|
+
|
|
412
|
+
def set_dialect_state(self, dialect: str, state: DialectState) -> None:
|
|
413
|
+
"""Set the state object for a specific dialect.
|
|
414
|
+
|
|
415
|
+
This is a convenience wrapper around set_state("dialect.{dialect}", state).
|
|
416
|
+
|
|
417
|
+
Args:
|
|
418
|
+
dialect: Name of the dialect (e.g., "simp", "bfv", "spu")
|
|
419
|
+
state: The dialect state object (should implement DialectState protocol)
|
|
420
|
+
|
|
421
|
+
Example:
|
|
422
|
+
interpreter.set_dialect_state("simp", cluster.connect())
|
|
423
|
+
"""
|
|
424
|
+
self.set_state(f"dialect.{dialect}", state)
|
|
425
|
+
|
|
426
|
+
    def bind_primitive(
        self, primitive: Primitive, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> InterpObject | list[InterpObject] | Any:
        """Execute primitive by tracing and interpreting.

        Implements the unified trace → interpret flow:
        1. All InterpObject arguments already registered via lift()
        2. Create a Tracer and push it as context
        3. Call primitive.bind() to build Graph IR (uses obj id in value names)
        4. Execute the graph via evaluate_graph() (resolves inputs via registry)

        Args:
            primitive: The primitive to execute
            args: Positional arguments (already lifted by Primitive.bind)
            kwargs: Keyword arguments (already lifted by Primitive.bind)

        Returns:
            Execution result (InterpObject or list of InterpObject or mixed with immediates)
        """
        # Local imports avoid a circular dependency between the runtime and
        # the edsl/tracing layers at module-import time.
        from mplang.v1.utils.func_utils import var_demorph, var_morph
        from mplang.v2.edsl.tracer import Tracer

        # Create tracer and build graph
        # Note: primitive.bind() internally calls Tracer.lift() with is_param=False,
        # so all args become captures (not params). This is correct because we're
        # tracing a primitive execution, not a user function with explicit parameters.
        with Tracer() as ctx:
            # Finalize graph by setting outputs
            result_traced = primitive.bind(*args, **kwargs)

            # Separate outputs into variables (Objects) and immediates (constants)
            out_vars, out_imms, morph_struct = var_morph(
                result_traced, lambda x: isinstance(x, Object)
            )

            if out_vars:
                graph = ctx.finalize(out_vars)
            else:
                # All outputs are immediates, no graph outputs
                graph = ctx.graph
                graph.outputs = []

        # NOTE(review): execution below happens after the tracing context has
        # exited; this assumes ctx._captured_vars stays readable on a popped
        # Tracer — confirm against Tracer.__exit__.

        # Build inputs list for interpret
        # _captured_vars contains all inputs (no params in this context)
        inputs_list = [
            obj.runtime_obj if isinstance(obj, InterpObject) else obj
            for obj, _ in ctx._captured_vars.values()
        ]

        # Execute graph (may have 0 outputs if all were immediates)
        if graph.outputs:
            result_runtime_list = self.evaluate_graph(graph, inputs_list)
        else:
            result_runtime_list = []

        # Wrap runtime results as InterpObjects
        interp_results = [
            InterpObject(rt_val, tr_obj.type, self)
            for rt_val, tr_obj in zip(result_runtime_list, out_vars, strict=True)
        ]

        # Reconstruct the output tree: merge InterpObjects and immediates
        return var_demorph(interp_results, out_imms, morph_struct)
|
|
490
|
+
    def lift(self, obj: Any) -> InterpObject | Any:
        """Lift an object to the Interpreter's native representation.

        This is THE central method that manages the boundary between
        InterpObject and TraceObject:

        1. **InterpObject → TraceObject** (during nested tracing):
           - Register the InterpObject in self._objects for later resolution
           - The InterpObject must belong to this Interpreter
           - When the object flows into Tracer.lift() during bind_primitive,
             it will be captured as input with a clean SSA name like "%arg0"

        2. **TraceObject → InterpObject** (evaluate traced computation):
           - Extract the graph from the TraceObject's context (Tracer)
           - Execute the graph via evaluate_graph() to get runtime result
           - Wrap result as InterpObject and register it

        3. **Constants**: Pass through unchanged

        Args:
            obj: Object to lift (InterpObject, TraceObject, or constant)

        Returns:
            InterpObject (if Object input) or constant (pass-through)

        Example:
            >>> # InterpObject case
            >>> x = InterpObject(np.array([1, 2]), Tensor[f32, (2,)])
            >>> x_lifted = interp.lift(x)  # registers in _objects, returns x
            >>>
            >>> # TraceObject case
            >>> tracer = Tracer()
            >>> push_context(tracer)
            >>> z_trace = some_primitive.bind(x, y)  # TraceObject
            >>> pop_context()
            >>> interp = Interpreter()
            >>> z_interp = interp.lift(z_trace)  # evaluate graph → InterpObject
        """
        # Local import avoids a circular dependency at module-import time.
        from mplang.v2.edsl.tracer import TraceObject

        if isinstance(obj, InterpObject):
            # InterpObject must belong to this interpreter; an object from a
            # foreign interpreter would carry handles we cannot resolve.
            if obj._context is not None and obj._context is not self:
                raise ValueError(
                    f"InterpObject belongs to a different Interpreter. "
                    f"Object context: {obj._context}, Current interpreter: {self}"
                )
            return obj

        elif isinstance(obj, TraceObject):
            # Check execution cache
            # If this value was computed as part of a previous execution (e.g. sibling output)
            # we can return it immediately without re-execution.
            graph_value = obj._graph_value
            if graph_value in self._execution_cache:
                return self._execution_cache[graph_value]

            # First time seeing this Value.
            # We need to execute the graph to compute it.
            # MIMO Optimization:
            # Instead of just asking for this single value, we ask for ALL outputs
            # of the operation that produced this value. This ensures that if we
            # later ask for a sibling output, it will be in the cache.

            tracer = obj._context
            graph = tracer.graph
            defining_op = graph_value.defining_op

            if defining_op is None:
                # Value is likely a constant or input (no defining op in graph)
                # Just execute graph for this single value
                target_outputs = [graph_value]
            else:
                # Fetch all outputs of the defining op
                target_outputs = defining_op.outputs

            # Temporarily set graph outputs to the target outputs
            # We must save/restore original outputs to avoid side effects
            original_outputs = graph.outputs
            graph.outputs = target_outputs

            try:
                # Resolve inputs from Tracer's captured vars
                # _captured_vars preserves insertion order which matches graph.inputs order
                inputs_list = []
                for captured_obj, _ in tracer._captured_vars.values():
                    # Recursively lift captured objects to ensure they are ready
                    lifted = self.lift(captured_obj)
                    if isinstance(lifted, InterpObject):
                        inputs_list.append(lifted.runtime_obj)
                    else:
                        inputs_list.append(lifted)

                # Execute graph
                results_runtime = self.evaluate_graph(graph, inputs_list)

                # Cache all results
                for val, res in zip(target_outputs, results_runtime, strict=True):
                    # Wrap as InterpObject and cache
                    # Note: We use obj.type for the requested value, but for siblings
                    # we should ideally use their types. However, we don't have TraceObjects
                    # for siblings here, only GraphValues.
                    # InterpObject needs a type. GraphValue has a type.
                    self._execution_cache[val] = InterpObject(res, val.type, self)

            finally:
                # Restore original outputs even if execution raised, so the
                # shared graph is never left mutated.
                graph.outputs = original_outputs

            # Now the result for our requested object should be in the cache
            if graph_value not in self._execution_cache:
                raise RuntimeError(
                    f"Failed to compute value for {obj} even after graph execution"
                )

            return self._execution_cache[graph_value]

        else:
            # Constants: pass through unchanged
            return obj
|
|
611
|
+
def evaluate_graph(
    self, graph: Graph, inputs: list[Any], job_id: str | None = None
) -> list[Any]:
    """Execute a Graph IR with runtime data.

    Dispatches to the asynchronous DAG scheduler when an executor is
    configured, otherwise runs the plain sequential interpreter.
    Subclasses may override this to implement remote execution or
    compilation.

    Args:
        graph: Finalized Graph IR to execute.
        inputs: Runtime objects corresponding to graph.inputs (positional).
        job_id: Optional unique ID for this execution job (for profiling/tracing).

    Returns:
        List of runtime execution results corresponding to graph.outputs.
    """
    # Pick the execution strategy once, then delegate.
    run = self._evaluate_graph_async if self.executor else self._evaluate_graph_sync
    return run(graph, inputs, job_id)
|
|
630
|
+
|
|
631
|
+
def _evaluate_graph_sync(
|
|
632
|
+
self, graph: Graph, inputs: list[Any], job_id: str | None = None
|
|
633
|
+
) -> list[Any]:
|
|
634
|
+
"""Synchronous execution (Baseline)."""
|
|
635
|
+
# Local environment: Value -> Runtime Object
|
|
636
|
+
env = dict(zip(graph.inputs, inputs, strict=True))
|
|
637
|
+
|
|
638
|
+
for op in graph.operations:
|
|
639
|
+
# Resolve inputs
|
|
640
|
+
try:
|
|
641
|
+
args = [env[val] for val in op.inputs]
|
|
642
|
+
except KeyError as e:
|
|
643
|
+
missing_keys = [str(k) for k in op.inputs if k not in env]
|
|
644
|
+
# Limit available keys output to avoid flooding logs if env is huge
|
|
645
|
+
available_keys = [str(k) for k in list(env.keys())[:20]]
|
|
646
|
+
if len(env) > 20:
|
|
647
|
+
available_keys.append("...")
|
|
648
|
+
|
|
649
|
+
raise RuntimeError(
|
|
650
|
+
f"Failed to resolve inputs for op '{op.opcode}'.\n"
|
|
651
|
+
f"Missing values: {missing_keys}\n"
|
|
652
|
+
f"Available values (partial): {available_keys}"
|
|
653
|
+
) from e
|
|
654
|
+
|
|
655
|
+
# Dispatch
|
|
656
|
+
# 1. Check instance-level handlers
|
|
657
|
+
handler = self.handlers.get(op.opcode)
|
|
658
|
+
# 2. Check global registry
|
|
659
|
+
if not handler:
|
|
660
|
+
handler = get_impl(op.opcode)
|
|
661
|
+
|
|
662
|
+
if handler:
|
|
663
|
+
# Pass interpreter to support recursive execution (HOFs)
|
|
664
|
+
# Pass op to access attributes and regions
|
|
665
|
+
# Pass args as runtime values
|
|
666
|
+
results = handler(self, op, *args)
|
|
667
|
+
else:
|
|
668
|
+
raise NotImplementedError(
|
|
669
|
+
f"No implementation registered for opcode: {op.opcode}"
|
|
670
|
+
)
|
|
671
|
+
|
|
672
|
+
# Update environment with outputs
|
|
673
|
+
# Handler should return a single value or a tuple/list of values
|
|
674
|
+
if len(op.outputs) == 0:
|
|
675
|
+
pass # Void operation
|
|
676
|
+
elif len(op.outputs) == 1:
|
|
677
|
+
env[op.outputs[0]] = results
|
|
678
|
+
else:
|
|
679
|
+
if len(results) != len(op.outputs):
|
|
680
|
+
raise RuntimeError(
|
|
681
|
+
f"Op {op.opcode} returned {len(results)} values, expected {len(op.outputs)}"
|
|
682
|
+
)
|
|
683
|
+
for out_val, res in zip(op.outputs, results, strict=True):
|
|
684
|
+
env[out_val] = res
|
|
685
|
+
|
|
686
|
+
# Return outputs
|
|
687
|
+
if self.tracer and job_id:
|
|
688
|
+
self.tracer.save_trace(job_id=job_id, rank=self.trace_pid)
|
|
689
|
+
|
|
690
|
+
return [env[out] for out in graph.outputs]
|
|
691
|
+
|
|
692
|
+
def _evaluate_graph_async(
|
|
693
|
+
self, graph: Graph, inputs: list[Any], job_id: str | None = None
|
|
694
|
+
) -> list[Any]:
|
|
695
|
+
"""Asynchronous execution with non-blocking DAG scheduling."""
|
|
696
|
+
# Tracer setup (if not provided, use a disabled stub)
|
|
697
|
+
tracer: ExecutionTracer | _NullTracer
|
|
698
|
+
if self.tracer:
|
|
699
|
+
tracer = self.tracer
|
|
700
|
+
tracer.total_ops += len(graph.operations)
|
|
701
|
+
else:
|
|
702
|
+
# No tracer provided - use minimal stub (no trace_dir needed)
|
|
703
|
+
tracer = _NullTracer()
|
|
704
|
+
|
|
705
|
+
active_tasks = 0
|
|
706
|
+
|
|
707
|
+
# 1. Setup State
|
|
708
|
+
# Value -> Runtime Object (initially inputs)
|
|
709
|
+
env = dict(zip(graph.inputs, inputs, strict=True))
|
|
710
|
+
|
|
711
|
+
# Op -> Pending Input Count
|
|
712
|
+
pending_counts = {}
|
|
713
|
+
# Value -> list[Op] (Consumers)
|
|
714
|
+
value_to_consumers: dict[Any, list[Any]] = collections.defaultdict(list)
|
|
715
|
+
# Value -> Remaining Consumers Count (for GC)
|
|
716
|
+
remaining_consumers: dict[Any, int] = collections.defaultdict(int)
|
|
717
|
+
|
|
718
|
+
# 2. Build Dependency Graph
|
|
719
|
+
for op in graph.operations:
|
|
720
|
+
count = 0
|
|
721
|
+
for val in op.inputs:
|
|
722
|
+
if val not in env: # If not already resolved (input or constant)
|
|
723
|
+
value_to_consumers[val].append(op)
|
|
724
|
+
remaining_consumers[val] += 1
|
|
725
|
+
count += 1
|
|
726
|
+
pending_counts[op] = count
|
|
727
|
+
|
|
728
|
+
# Mark graph outputs as having an extra consumer (the user)
|
|
729
|
+
# so they are not GC'd before return
|
|
730
|
+
for out in graph.outputs:
|
|
731
|
+
remaining_consumers[out] += 1
|
|
732
|
+
|
|
733
|
+
# 3. Synchronization
|
|
734
|
+
lock = threading.Lock()
|
|
735
|
+
ready_queue: queue.Queue[Any] = queue.Queue()
|
|
736
|
+
remaining_ops = len(graph.operations)
|
|
737
|
+
|
|
738
|
+
# Error propagation
|
|
739
|
+
error_occurred = False
|
|
740
|
+
|
|
741
|
+
# 4. Execution Helper
|
|
742
|
+
def on_op_done(op: Any, result: Any, error: Exception | None = None) -> None:
|
|
743
|
+
nonlocal remaining_ops, error_occurred, active_tasks
|
|
744
|
+
|
|
745
|
+
if error:
|
|
746
|
+
with lock:
|
|
747
|
+
if not error_occurred:
|
|
748
|
+
error_occurred = True
|
|
749
|
+
ready_queue.put(error)
|
|
750
|
+
return
|
|
751
|
+
|
|
752
|
+
with lock:
|
|
753
|
+
if op.opcode in self.async_ops and self.executor:
|
|
754
|
+
active_tasks -= 1
|
|
755
|
+
# profiler.sample(active_tasks, ready_queue.qsize())
|
|
756
|
+
|
|
757
|
+
if error_occurred:
|
|
758
|
+
return
|
|
759
|
+
|
|
760
|
+
# Store results
|
|
761
|
+
if len(op.outputs) == 1:
|
|
762
|
+
env[op.outputs[0]] = result
|
|
763
|
+
else:
|
|
764
|
+
for out_val, res in zip(op.outputs, result, strict=True):
|
|
765
|
+
env[out_val] = res
|
|
766
|
+
|
|
767
|
+
# Trigger consumers
|
|
768
|
+
for out_val in op.outputs:
|
|
769
|
+
if out_val in value_to_consumers:
|
|
770
|
+
for consumer_op in value_to_consumers[out_val]:
|
|
771
|
+
pending_counts[consumer_op] -= 1
|
|
772
|
+
if pending_counts[consumer_op] == 0:
|
|
773
|
+
tracer.log_schedule(
|
|
774
|
+
consumer_op, namespace=self.trace_pid
|
|
775
|
+
)
|
|
776
|
+
ready_queue.put(consumer_op)
|
|
777
|
+
|
|
778
|
+
# GC Inputs
|
|
779
|
+
for val in op.inputs:
|
|
780
|
+
if val in remaining_consumers:
|
|
781
|
+
remaining_consumers[val] -= 1
|
|
782
|
+
if remaining_consumers[val] == 0:
|
|
783
|
+
env.pop(val, None)
|
|
784
|
+
|
|
785
|
+
remaining_ops -= 1
|
|
786
|
+
if remaining_ops == 0:
|
|
787
|
+
ready_queue.put(None) # Sentinel
|
|
788
|
+
|
|
789
|
+
def execute_op(op: Any) -> None:
|
|
790
|
+
nonlocal active_tasks
|
|
791
|
+
# Extract args from env (must be ready)
|
|
792
|
+
args = [env[val] for val in op.inputs]
|
|
793
|
+
|
|
794
|
+
handler = self.handlers.get(op.opcode)
|
|
795
|
+
if not handler:
|
|
796
|
+
handler = get_impl(op.opcode)
|
|
797
|
+
|
|
798
|
+
if not handler:
|
|
799
|
+
raise NotImplementedError(
|
|
800
|
+
f"No implementation registered for opcode: {op.opcode}"
|
|
801
|
+
)
|
|
802
|
+
|
|
803
|
+
if op.opcode in self.async_ops and self.executor:
|
|
804
|
+
with lock:
|
|
805
|
+
active_tasks += 1
|
|
806
|
+
# profiler.sample(active_tasks, ready_queue.qsize())
|
|
807
|
+
|
|
808
|
+
# Submit to executor
|
|
809
|
+
def task() -> Any:
|
|
810
|
+
start_ts = tracer.log_start(
|
|
811
|
+
op, pid=self.trace_pid, namespace=self.trace_pid
|
|
812
|
+
)
|
|
813
|
+
res = handler(self, op, *args)
|
|
814
|
+
tracer.log_end(op, start_ts, pid=self.trace_pid)
|
|
815
|
+
return res
|
|
816
|
+
|
|
817
|
+
def callback(fut: Any) -> None:
|
|
818
|
+
try:
|
|
819
|
+
res = fut.result()
|
|
820
|
+
on_op_done(op, res)
|
|
821
|
+
except Exception as e:
|
|
822
|
+
on_op_done(op, None, error=e)
|
|
823
|
+
|
|
824
|
+
fut = self.executor.submit(task)
|
|
825
|
+
fut.add_done_callback(callback)
|
|
826
|
+
else:
|
|
827
|
+
# Sync execution (run immediately)
|
|
828
|
+
try:
|
|
829
|
+
start_ts = tracer.log_start(
|
|
830
|
+
op, pid=self.trace_pid, namespace=self.trace_pid
|
|
831
|
+
)
|
|
832
|
+
res = handler(self, op, *args)
|
|
833
|
+
tracer.log_end(op, start_ts, pid=self.trace_pid)
|
|
834
|
+
on_op_done(op, res)
|
|
835
|
+
except Exception as e:
|
|
836
|
+
on_op_done(op, None, error=e)
|
|
837
|
+
|
|
838
|
+
# 5. Initial Submission
|
|
839
|
+
# Submit all ops with 0 pending inputs
|
|
840
|
+
initial_ops = [op for op, count in pending_counts.items() if count == 0]
|
|
841
|
+
if not initial_ops and remaining_ops > 0:
|
|
842
|
+
# Cycle detected or empty graph?
|
|
843
|
+
pass
|
|
844
|
+
|
|
845
|
+
for op in initial_ops:
|
|
846
|
+
tracer.log_schedule(op, namespace=self.trace_pid)
|
|
847
|
+
ready_queue.put(op)
|
|
848
|
+
|
|
849
|
+
# Handle empty graph case
|
|
850
|
+
if remaining_ops == 0:
|
|
851
|
+
ready_queue.put(None)
|
|
852
|
+
|
|
853
|
+
# 6. Main Loop
|
|
854
|
+
while True:
|
|
855
|
+
item = ready_queue.get()
|
|
856
|
+
if item is None:
|
|
857
|
+
break
|
|
858
|
+
if isinstance(item, Exception):
|
|
859
|
+
raise item
|
|
860
|
+
|
|
861
|
+
# It's an op
|
|
862
|
+
execute_op(item)
|
|
863
|
+
|
|
864
|
+
# 7. Return outputs
|
|
865
|
+
if not self.tracer:
|
|
866
|
+
tracer.stop()
|
|
867
|
+
|
|
868
|
+
if self.tracer and job_id:
|
|
869
|
+
self.tracer.save_trace(job_id=job_id, rank=self.trace_pid)
|
|
870
|
+
|
|
871
|
+
return [env[out] for out in graph.outputs]
|