mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,871 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Interpreter: Execute Graph IR and Eager Operations.
16
+
17
+ Interpreter is a Context that executes operations immediately.
18
+ It can execute both:
19
+ 1. Graph IR (via GraphInterpreter)
20
+ 2. Eager operations on InterpObject (via backend executors)
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import collections
26
+ import concurrent.futures
27
+ import json
28
+ import os
29
+ import pathlib
30
+ import queue
31
+ import threading
32
+ import time
33
+ from collections.abc import Callable
34
+ from typing import TYPE_CHECKING, Any, cast
35
+
36
+ from mplang.v2.edsl.context import AbstractInterpreter
37
+ from mplang.v2.edsl.graph import Graph
38
+ from mplang.v2.edsl.object import Object
39
+ from mplang.v2.edsl.registry import get_impl
40
+ from mplang.v2.edsl.typing import BaseType
41
+ from mplang.v2.runtime.dialect_state import DialectState
42
+ from mplang.v2.runtime.object_store import ObjectStore
43
+
44
+ if TYPE_CHECKING:
45
+ from mplang.v2.edsl.primitive import Primitive
46
+
47
+
48
class ExecutionTracer:
    """Collects DAG execution events in Chrome Tracing format.

    Events accumulate in memory and are flushed to a JSON file under
    ``trace_dir`` via :meth:`save_trace`.  When ``enabled`` is False every
    logging method is a cheap no-op.
    """

    def __init__(self, enabled: bool = False, *, trace_dir: str | pathlib.Path):
        # Master switch; all log_* methods return early when False.
        self.enabled = enabled
        self.start_time = 0.0
        self.end_time = 0.0
        # (elapsed-seconds, count) pairs recorded by sample().
        self.active_tasks_samples: list[tuple[float, int]] = []
        self.queue_size_samples: list[tuple[float, int]] = []
        self.completed_ops = 0
        self.total_ops = 0
        self.trace_dir = pathlib.Path(trace_dir)

        # Chrome-tracing event dicts, appended as ops run.
        self.trace_events: list[dict[str, Any]] = []
        # (id(op), namespace) -> scheduling timestamp in microseconds.
        self.op_schedule_times: dict[tuple[int, Any], float] = {}
        self.pid = os.getpid()

    def start(self) -> None:
        """Mark the beginning of a run."""
        self.start_time = time.time()

    def stop(self, filename_prefix: str = "dag_trace") -> None:
        """Mark the end of a run and persist the collected trace."""
        self.end_time = time.time()
        self.save_trace(filename_prefix)

    def sample(self, active_tasks: int, queue_size: int) -> None:
        """Record one scheduler snapshot (active task count + queue depth)."""
        elapsed = time.time() - self.start_time
        self.active_tasks_samples.append((elapsed, active_tasks))
        self.queue_size_samples.append((elapsed, queue_size))

    def log_schedule(self, op: Any, namespace: Any = None) -> None:
        """Remember when *op* was enqueued so queue latency can be derived."""
        if not self.enabled:
            return
        self.op_schedule_times[(id(op), namespace)] = time.time() * 1e6

    def log_start(
        self, op: Any, pid: int | None = None, namespace: Any = None
    ) -> float:
        """Record that *op* began executing; returns the start timestamp (us)."""
        if not self.enabled:
            return 0.0
        began = time.time() * 1e6
        if pid is None:
            pid = self.pid

        # If the op went through log_schedule, emit a "Queue" span covering
        # the time it spent waiting in the scheduler.
        queued_at = self.op_schedule_times.pop((id(op), namespace), None)
        if queued_at is not None:
            self.trace_events.append({
                "name": f"Queue: {op.opcode}",
                "cat": "scheduler",
                "ph": "X",
                "ts": queued_at,
                "dur": began - queued_at,
                "pid": pid,
                "tid": "SchedulerQueue",
            })
        return began

    def log_end(self, op: Any, start_ts: float, pid: int | None = None) -> None:
        """Record that *op* finished; emits a complete ("X") event."""
        if not self.enabled:
            return
        finished = time.time() * 1e6
        self.trace_events.append({
            "name": op.opcode,
            "cat": "op",
            "ph": "X",
            "ts": start_ts,
            "dur": finished - start_ts,
            "pid": self.pid if pid is None else pid,
            "tid": threading.get_ident(),
            "args": {
                "opcode": op.opcode,
            },
        })

    def log_custom_event(
        self,
        name: str,
        start_ts: float,
        end_ts: float,
        cat: str = "custom",
        args: dict[str, Any] | None = None,
    ) -> None:
        """Log a custom event with explicit start/end timestamps (in seconds)."""
        if not self.enabled:
            return
        # Chrome tracing expects microseconds.
        self.trace_events.append({
            "name": name,
            "cat": cat,
            "ph": "X",
            "ts": start_ts * 1e6,
            "dur": (end_ts - start_ts) * 1e6,
            "pid": self.pid,
            "tid": threading.get_ident(),
            "args": args or {},
        })

    def save_trace(
        self,
        filename_prefix: str = "dag_trace",
        job_id: str | None = None,
        rank: int | None = None,
    ) -> None:
        """Write accumulated events to a JSON file in ``trace_dir``.

        Traces with fewer than 100 events are silently dropped to avoid
        littering the trace directory with uninteresting graphs.
        """
        if not self.enabled or not self.trace_events:
            return
        try:
            if len(self.trace_events) < 100:
                return  # Skip small graphs

            # Pick a unique filename so concurrent runs don't clobber each other.
            if job_id:
                # Format: trace_<job_id>_rank_<rank>.json
                rank_str = f"_rank_{rank}" if rank is not None else ""
                filename = f"trace_{job_id}{rank_str}.json"
            else:
                stamp = int(time.time() * 1000)
                filename = f"{filename_prefix}_{stamp}_{threading.get_ident()}.json"

            self.trace_dir.mkdir(parents=True, exist_ok=True)
            filepath = self.trace_dir / filename

            with open(filepath, "w") as f:
                json.dump({"traceEvents": self.trace_events}, f)
            print(f"\n[Tracer] Trace saved to {filepath.absolute()}")
        except Exception as e:
            # Best-effort: tracing must never crash the run.
            print(f"[Tracer] Failed to save trace: {e}")

    def print_summary(self) -> None:
        """Print aggregate throughput/concurrency statistics to stdout."""
        duration = self.end_time - self.start_time
        if duration <= 0:
            return

        def mean_count(samples: list[tuple[float, int]]) -> float:
            # Average of the count component; 0 when no samples were taken.
            return sum(c for _, c in samples) / len(samples) if samples else 0

        avg_active = mean_count(self.active_tasks_samples)
        max_active = (
            max(c for _, c in self.active_tasks_samples)
            if self.active_tasks_samples
            else 0
        )
        avg_queue = mean_count(self.queue_size_samples)

        print("\n" + "=" * 80)
        print("DAG EXECUTION PROFILER")
        print("=" * 80)
        print(f"Total Duration: {duration:.3f}s")
        print(f"Total Ops: {self.total_ops}")
        print(f"Throughput: {self.total_ops / duration:.1f} ops/s")
        print("-" * 80)
        print(f"Active Tasks: Avg={avg_active:.1f}, Max={max_active}")
        print(f"Ready Queue: Avg={avg_queue:.1f}")
        print("=" * 80 + "\n")
223
+
224
+
225
+ class _NullTracer:
226
+ """No-op tracer stub for when tracing is disabled."""
227
+
228
+ enabled = False
229
+ total_ops = 0
230
+
231
+ def log_schedule(self, op: Any, namespace: Any = None) -> None:
232
+ pass
233
+
234
+ def log_start(
235
+ self, op: Any, pid: int | None = None, namespace: Any = None
236
+ ) -> float:
237
+ return 0.0
238
+
239
+ def log_end(self, op: Any, start_ts: float, pid: int | None = None) -> None:
240
+ pass
241
+
242
+ def stop(self) -> None:
243
+ pass
244
+
245
+ def save_trace(self, **kwargs: Any) -> None:
246
+ pass
247
+
248
+
249
class InterpObject(Object):
    """Value produced by eager (interp-time) execution.

    Wraps a backend-owned runtime object (the actual data or handle)
    together with its EDSL type and a reference to the owning
    :class:`Interpreter` context.  Operator overloads inherited from
    ``Object`` dispatch to primitives, which execute immediately.

    Depending on the backend, the wrapped runtime object can be:
      - FHE backend: local TenSEAL/SEAL ciphertext
      - JAX backend: local ``jax.Array``
      - MP backend: backend handle (pointer to party-side data)
      - SQL backend: DatabaseHandle
      - etc.

    Example:
        >>> # FHE backend (local execution)
        >>> x = fhe.encrypt([1, 2, 3])  # InterpObject with local ciphertext
        >>> y = fhe.encrypt([4, 5, 6])
        >>> z = x + y  # InterpObject.__add__ -> add_p.bind(x, y)

        >>> # MP backend (distributed execution)
        >>> x = mp.random.uniform(shape=(10,))  # InterpObject with handle
        >>> y = mp.random.uniform(shape=(10,))
        >>> z = x + y  # InterpObject.__add__ -> add_p.bind(x, y)
    """

    def __init__(
        self,
        runtime_obj: Any,
        obj_type: BaseType,
        interpreter: Interpreter | None = None,
    ):
        """Wrap *runtime_obj* of type *obj_type* under *interpreter*.

        Args:
            runtime_obj: Backend-specific runtime value or handle.
            obj_type: EDSL type describing the value (BaseType).
            interpreter: Owning context; None means the default interpreter.
        """
        self._runtime_obj = runtime_obj
        self._type = obj_type
        # Owning Interpreter; Interpreter.lift() validates this.
        self._context = interpreter

    @property
    def type(self) -> BaseType:
        return self._type

    @property
    def runtime_obj(self) -> Any:
        """The underlying backend-specific runtime object."""
        return self._runtime_obj

    def __repr__(self) -> str:
        shown = repr(self._runtime_obj)
        # Keep reprs short so logs stay readable.
        if len(shown) > 50:
            shown = shown[:47] + "..."
        return f"InterpObject({shown}, type={self.type})"
307
+
308
+
309
+ class Interpreter(AbstractInterpreter):
310
+ """Execution context for eager execution.
311
+
312
+ Inherits from Context and implements bind_primitive() by executing immediately.
313
+
314
+ Responsibilities:
315
+ 1. Execute primitives on InterpObject immediately
316
+ 2. Delegate to backend-specific executors
317
+ 3. Execute Graph IR (via GraphInterpreter)
318
+
319
+ Example:
320
+ >>> interp = Interpreter()
321
+ >>> x = InterpObject(np.array([1, 2, 3]), Tensor[f32, (3,)])
322
+ >>> y = InterpObject(np.array([4, 5, 6]), Tensor[f32, (3,)])
323
+ >>> z = x + y # InterpObject.__add__ → add_p.bind(x, y)
324
+ """
325
+
326
+ def __init__(
327
+ self,
328
+ executor: concurrent.futures.Executor | None = None,
329
+ name: str = "Interpreter",
330
+ tracer: ExecutionTracer | None = None,
331
+ trace_pid: int | None = None,
332
+ store: ObjectStore | None = None,
333
+ root_dir: str | pathlib.Path | None = None,
334
+ handlers: dict[str, Callable[..., Any]] | None = None,
335
+ ) -> None:
336
+ # Persistence Root
337
+ self.root_dir = (
338
+ pathlib.Path(root_dir)
339
+ if root_dir
340
+ else pathlib.Path(os.environ.get("MPLANG_DATA_ROOT", ".mpl"))
341
+ )
342
+
343
+ # Initialize Context base class (for state management)
344
+ super().__init__()
345
+
346
+ # Instance-level handler registry (overrides global registry)
347
+ self.handlers: dict[str, Callable] = handlers or {}
348
+ self.tracer = tracer
349
+
350
+ # GraphValue -> InterpObject cache
351
+ # Maps a GraphValue (IR node) to its computed InterpObject (Runtime result).
352
+ # This serves two purposes:
353
+ # 1. Caching: Avoid re-evaluating the same graph node multiple times.
354
+ # 2. MIMO Optimization: When one output of a multi-output op is computed,
355
+ # all sibling outputs are cached here to avoid re-execution.
356
+ self._execution_cache: dict[Any, InterpObject] = {}
357
+ self.executor = executor
358
+ self.async_ops: set[str] = set()
359
+ self.name = name
360
+ self.trace_pid = trace_pid
361
+ self.store: ObjectStore | None = store
362
+
363
+ def shutdown(self) -> None:
364
+ """Shutdown the interpreter and release resources.
365
+
366
+ This method is idempotent and safe to call multiple times.
367
+ It performs the following cleanup:
368
+ 1. Shuts down the internal executor (if any).
369
+ 2. Stops the execution tracer (if any).
370
+ 3. Shuts down any attached dialect states (e.g., stopping drivers).
371
+ """
372
+ # 1. Shutdown Executor
373
+ if self.executor:
374
+ self.executor.shutdown(wait=True)
375
+ self.executor = None
376
+
377
+ # 2. Stop Tracer
378
+ if self.tracer:
379
+ self.tracer.stop()
380
+ # Don't clear self.tracer, as we might want to read stats later
381
+
382
+ # 3. Shutdown Dialect States
383
+ # Iterate over all attached states (e.g., drivers, cluster managers)
384
+ # and shut them down if they support it.
385
+ for state in self._states.values():
386
+ if hasattr(state, "shutdown") and callable(state.shutdown):
387
+ state.shutdown()
388
+
389
+ # =========================================================================
390
+ # Dialect State Management
391
+ # =========================================================================
392
+ def get_dialect_state(self, dialect: str) -> DialectState | None:
393
+ """Get the state object for a specific dialect.
394
+
395
+ This is a convenience wrapper around get_state("dialect.{dialect}").
396
+
397
+ Args:
398
+ dialect: Name of the dialect (e.g., "simp", "bfv", "spu")
399
+
400
+ Returns:
401
+ The dialect state object, or None if not set.
402
+
403
+ Example:
404
+ simp_state = interpreter.get_dialect_state("simp")
405
+ if simp_state is not None:
406
+ simp_state.submit(rank, graph, inputs)
407
+ """
408
+ state = self.get_state(f"dialect.{dialect}")
409
+ # Type assertion: dialect states are always DialectState or None
410
+ return cast(DialectState | None, state)
411
+
412
+ def set_dialect_state(self, dialect: str, state: DialectState) -> None:
413
+ """Set the state object for a specific dialect.
414
+
415
+ This is a convenience wrapper around set_state("dialect.{dialect}", state).
416
+
417
+ Args:
418
+ dialect: Name of the dialect (e.g., "simp", "bfv", "spu")
419
+ state: The dialect state object (should implement DialectState protocol)
420
+
421
+ Example:
422
+ interpreter.set_dialect_state("simp", cluster.connect())
423
+ """
424
+ self.set_state(f"dialect.{dialect}", state)
425
+
426
+ def bind_primitive(
427
+ self, primitive: Primitive, args: tuple[Any, ...], kwargs: dict[str, Any]
428
+ ) -> InterpObject | list[InterpObject] | Any:
429
+ """Execute primitive by tracing and interpreting.
430
+
431
+ Implements the unified trace → interpret flow:
432
+ 1. All InterpObject arguments already registered via lift()
433
+ 2. Create a Tracer and push it as context
434
+ 3. Call primitive.bind() to build Graph IR (uses obj id in value names)
435
+ 4. Execute the graph via evaluate_graph() (resolves inputs via registry)
436
+
437
+ Args:
438
+ primitive: The primitive to execute
439
+ args: Positional arguments (already lifted by Primitive.bind)
440
+ kwargs: Keyword arguments (already lifted by Primitive.bind)
441
+
442
+ Returns:
443
+ Execution result (InterpObject or list of InterpObject or mixed with immediates)
444
+ """
445
+ from mplang.v1.utils.func_utils import var_demorph, var_morph
446
+ from mplang.v2.edsl.tracer import Tracer
447
+
448
+ # Create tracer and build graph
449
+ # Note: primitive.bind() internally calls Tracer.lift() with is_param=False,
450
+ # so all args become captures (not params). This is correct because we're
451
+ # tracing a primitive execution, not a user function with explicit parameters.
452
+ with Tracer() as ctx:
453
+ # Finalize graph by setting outputs
454
+ result_traced = primitive.bind(*args, **kwargs)
455
+
456
+ # Separate outputs into variables (Objects) and immediates (constants)
457
+ out_vars, out_imms, morph_struct = var_morph(
458
+ result_traced, lambda x: isinstance(x, Object)
459
+ )
460
+
461
+ if out_vars:
462
+ graph = ctx.finalize(out_vars)
463
+ else:
464
+ # All outputs are immediates, no graph outputs
465
+ graph = ctx.graph
466
+ graph.outputs = []
467
+
468
+ # Build inputs list for interpret
469
+ # _captured_vars contains all inputs (no params in this context)
470
+ inputs_list = [
471
+ obj.runtime_obj if isinstance(obj, InterpObject) else obj
472
+ for obj, _ in ctx._captured_vars.values()
473
+ ]
474
+
475
+ # Execute graph (may have 0 outputs if all were immediates)
476
+ if graph.outputs:
477
+ result_runtime_list = self.evaluate_graph(graph, inputs_list)
478
+ else:
479
+ result_runtime_list = []
480
+
481
+ # Wrap runtime results as InterpObjects
482
+ interp_results = [
483
+ InterpObject(rt_val, tr_obj.type, self)
484
+ for rt_val, tr_obj in zip(result_runtime_list, out_vars, strict=True)
485
+ ]
486
+
487
+ # Reconstruct the output tree: merge InterpObjects and immediates
488
+ return var_demorph(interp_results, out_imms, morph_struct)
489
+
490
+ def lift(self, obj: Any) -> InterpObject | Any:
491
+ """Lift an object to the Interpreter's native representation.
492
+
493
+ This is THE central method that manages the boundary between
494
+ InterpObject and TraceObject:
495
+
496
+ 1. **InterpObject → TraceObject** (during nested tracing):
497
+ - Register the InterpObject in self._objects for later resolution
498
+ - The InterpObject must belong to this Interpreter
499
+ - When the object flows into Tracer.lift() during bind_primitive,
500
+ it will be captured as input with a clean SSA name like "%arg0"
501
+
502
+ 2. **TraceObject → InterpObject** (evaluate traced computation):
503
+ - Extract the graph from the TraceObject's context (Tracer)
504
+ - Execute the graph via evaluate_graph() to get runtime result
505
+ - Wrap result as InterpObject and register it
506
+
507
+ 3. **Constants**: Pass through unchanged
508
+
509
+ Args:
510
+ obj: Object to lift (InterpObject, TraceObject, or constant)
511
+
512
+ Returns:
513
+ InterpObject (if Object input) or constant (pass-through)
514
+
515
+ Example:
516
+ >>> # InterpObject case
517
+ >>> x = InterpObject(np.array([1, 2]), Tensor[f32, (2,)])
518
+ >>> x_lifted = interp.lift(x) # registers in _objects, returns x
519
+ >>>
520
+ >>> # TraceObject case
521
+ >>> tracer = Tracer()
522
+ >>> push_context(tracer)
523
+ >>> z_trace = some_primitive.bind(x, y) # TraceObject
524
+ >>> pop_context()
525
+ >>> interp = Interpreter()
526
+ >>> z_interp = interp.lift(z_trace) # evaluate graph → InterpObject
527
+ """
528
+ from mplang.v2.edsl.tracer import TraceObject
529
+
530
+ if isinstance(obj, InterpObject):
531
+ # InterpObject must belong to this interpreter
532
+ if obj._context is not None and obj._context is not self:
533
+ raise ValueError(
534
+ f"InterpObject belongs to a different Interpreter. "
535
+ f"Object context: {obj._context}, Current interpreter: {self}"
536
+ )
537
+ return obj
538
+
539
+ elif isinstance(obj, TraceObject):
540
+ # Check execution cache
541
+ # If this value was computed as part of a previous execution (e.g. sibling output)
542
+ # we can return it immediately without re-execution.
543
+ graph_value = obj._graph_value
544
+ if graph_value in self._execution_cache:
545
+ return self._execution_cache[graph_value]
546
+
547
+ # First time seeing this Value.
548
+ # We need to execute the graph to compute it.
549
+ # MIMO Optimization:
550
+ # Instead of just asking for this single value, we ask for ALL outputs
551
+ # of the operation that produced this value. This ensures that if we
552
+ # later ask for a sibling output, it will be in the cache.
553
+
554
+ tracer = obj._context
555
+ graph = tracer.graph
556
+ defining_op = graph_value.defining_op
557
+
558
+ if defining_op is None:
559
+ # Value is likely a constant or input (no defining op in graph)
560
+ # Just execute graph for this single value
561
+ target_outputs = [graph_value]
562
+ else:
563
+ # Fetch all outputs of the defining op
564
+ target_outputs = defining_op.outputs
565
+
566
+ # Temporarily set graph outputs to the target outputs
567
+ # We must save/restore original outputs to avoid side effects
568
+ original_outputs = graph.outputs
569
+ graph.outputs = target_outputs
570
+
571
+ try:
572
+ # Resolve inputs from Tracer's captured vars
573
+ # _captured_vars preserves insertion order which matches graph.inputs order
574
+ inputs_list = []
575
+ for captured_obj, _ in tracer._captured_vars.values():
576
+ # Recursively lift captured objects to ensure they are ready
577
+ lifted = self.lift(captured_obj)
578
+ if isinstance(lifted, InterpObject):
579
+ inputs_list.append(lifted.runtime_obj)
580
+ else:
581
+ inputs_list.append(lifted)
582
+
583
+ # Execute graph
584
+ results_runtime = self.evaluate_graph(graph, inputs_list)
585
+
586
+ # Cache all results
587
+ for val, res in zip(target_outputs, results_runtime, strict=True):
588
+ # Wrap as InterpObject and cache
589
+ # Note: We use obj.type for the requested value, but for siblings
590
+ # we should ideally use their types. However, we don't have TraceObjects
591
+ # for siblings here, only GraphValues.
592
+ # InterpObject needs a type. GraphValue has a type.
593
+ self._execution_cache[val] = InterpObject(res, val.type, self)
594
+
595
+ finally:
596
+ # Restore original outputs
597
+ graph.outputs = original_outputs
598
+
599
+ # Now the result for our requested object should be in the cache
600
+ if graph_value not in self._execution_cache:
601
+ raise RuntimeError(
602
+ f"Failed to compute value for {obj} even after graph execution"
603
+ )
604
+
605
+ return self._execution_cache[graph_value]
606
+
607
+ else:
608
+ # Constants: pass through unchanged
609
+ return obj
610
+
611
+ def evaluate_graph(
612
+ self, graph: Graph, inputs: list[Any], job_id: str | None = None
613
+ ) -> list[Any]:
614
+ """Execute a Graph IR with runtime data.
615
+
616
+ Can be overridden by subclasses to implement remote execution or compilation.
617
+
618
+ Args:
619
+ graph: Finalized Graph IR to execute
620
+ inputs: Runtime objects corresponding to graph.inputs (positional)
621
+ job_id: Optional unique ID for this execution job (for profiling/tracing).
622
+
623
+ Returns:
624
+ List of runtime execution results corresponding to graph.outputs.
625
+ """
626
+ if self.executor:
627
+ return self._evaluate_graph_async(graph, inputs, job_id)
628
+ else:
629
+ return self._evaluate_graph_sync(graph, inputs, job_id)
630
+
631
+ def _evaluate_graph_sync(
632
+ self, graph: Graph, inputs: list[Any], job_id: str | None = None
633
+ ) -> list[Any]:
634
+ """Synchronous execution (Baseline)."""
635
+ # Local environment: Value -> Runtime Object
636
+ env = dict(zip(graph.inputs, inputs, strict=True))
637
+
638
+ for op in graph.operations:
639
+ # Resolve inputs
640
+ try:
641
+ args = [env[val] for val in op.inputs]
642
+ except KeyError as e:
643
+ missing_keys = [str(k) for k in op.inputs if k not in env]
644
+ # Limit available keys output to avoid flooding logs if env is huge
645
+ available_keys = [str(k) for k in list(env.keys())[:20]]
646
+ if len(env) > 20:
647
+ available_keys.append("...")
648
+
649
+ raise RuntimeError(
650
+ f"Failed to resolve inputs for op '{op.opcode}'.\n"
651
+ f"Missing values: {missing_keys}\n"
652
+ f"Available values (partial): {available_keys}"
653
+ ) from e
654
+
655
+ # Dispatch
656
+ # 1. Check instance-level handlers
657
+ handler = self.handlers.get(op.opcode)
658
+ # 2. Check global registry
659
+ if not handler:
660
+ handler = get_impl(op.opcode)
661
+
662
+ if handler:
663
+ # Pass interpreter to support recursive execution (HOFs)
664
+ # Pass op to access attributes and regions
665
+ # Pass args as runtime values
666
+ results = handler(self, op, *args)
667
+ else:
668
+ raise NotImplementedError(
669
+ f"No implementation registered for opcode: {op.opcode}"
670
+ )
671
+
672
+ # Update environment with outputs
673
+ # Handler should return a single value or a tuple/list of values
674
+ if len(op.outputs) == 0:
675
+ pass # Void operation
676
+ elif len(op.outputs) == 1:
677
+ env[op.outputs[0]] = results
678
+ else:
679
+ if len(results) != len(op.outputs):
680
+ raise RuntimeError(
681
+ f"Op {op.opcode} returned {len(results)} values, expected {len(op.outputs)}"
682
+ )
683
+ for out_val, res in zip(op.outputs, results, strict=True):
684
+ env[out_val] = res
685
+
686
+ # Return outputs
687
+ if self.tracer and job_id:
688
+ self.tracer.save_trace(job_id=job_id, rank=self.trace_pid)
689
+
690
+ return [env[out] for out in graph.outputs]
691
+
692
+ def _evaluate_graph_async(
693
+ self, graph: Graph, inputs: list[Any], job_id: str | None = None
694
+ ) -> list[Any]:
695
+ """Asynchronous execution with non-blocking DAG scheduling."""
696
+ # Tracer setup (if not provided, use a disabled stub)
697
+ tracer: ExecutionTracer | _NullTracer
698
+ if self.tracer:
699
+ tracer = self.tracer
700
+ tracer.total_ops += len(graph.operations)
701
+ else:
702
+ # No tracer provided - use minimal stub (no trace_dir needed)
703
+ tracer = _NullTracer()
704
+
705
+ active_tasks = 0
706
+
707
+ # 1. Setup State
708
+ # Value -> Runtime Object (initially inputs)
709
+ env = dict(zip(graph.inputs, inputs, strict=True))
710
+
711
+ # Op -> Pending Input Count
712
+ pending_counts = {}
713
+ # Value -> list[Op] (Consumers)
714
+ value_to_consumers: dict[Any, list[Any]] = collections.defaultdict(list)
715
+ # Value -> Remaining Consumers Count (for GC)
716
+ remaining_consumers: dict[Any, int] = collections.defaultdict(int)
717
+
718
+ # 2. Build Dependency Graph
719
+ for op in graph.operations:
720
+ count = 0
721
+ for val in op.inputs:
722
+ if val not in env: # If not already resolved (input or constant)
723
+ value_to_consumers[val].append(op)
724
+ remaining_consumers[val] += 1
725
+ count += 1
726
+ pending_counts[op] = count
727
+
728
+ # Mark graph outputs as having an extra consumer (the user)
729
+ # so they are not GC'd before return
730
+ for out in graph.outputs:
731
+ remaining_consumers[out] += 1
732
+
733
+ # 3. Synchronization
734
+ lock = threading.Lock()
735
+ ready_queue: queue.Queue[Any] = queue.Queue()
736
+ remaining_ops = len(graph.operations)
737
+
738
+ # Error propagation
739
+ error_occurred = False
740
+
741
+ # 4. Execution Helper
742
+ def on_op_done(op: Any, result: Any, error: Exception | None = None) -> None:
743
+ nonlocal remaining_ops, error_occurred, active_tasks
744
+
745
+ if error:
746
+ with lock:
747
+ if not error_occurred:
748
+ error_occurred = True
749
+ ready_queue.put(error)
750
+ return
751
+
752
+ with lock:
753
+ if op.opcode in self.async_ops and self.executor:
754
+ active_tasks -= 1
755
+ # profiler.sample(active_tasks, ready_queue.qsize())
756
+
757
+ if error_occurred:
758
+ return
759
+
760
+ # Store results
761
+ if len(op.outputs) == 1:
762
+ env[op.outputs[0]] = result
763
+ else:
764
+ for out_val, res in zip(op.outputs, result, strict=True):
765
+ env[out_val] = res
766
+
767
+ # Trigger consumers
768
+ for out_val in op.outputs:
769
+ if out_val in value_to_consumers:
770
+ for consumer_op in value_to_consumers[out_val]:
771
+ pending_counts[consumer_op] -= 1
772
+ if pending_counts[consumer_op] == 0:
773
+ tracer.log_schedule(
774
+ consumer_op, namespace=self.trace_pid
775
+ )
776
+ ready_queue.put(consumer_op)
777
+
778
+ # GC Inputs
779
+ for val in op.inputs:
780
+ if val in remaining_consumers:
781
+ remaining_consumers[val] -= 1
782
+ if remaining_consumers[val] == 0:
783
+ env.pop(val, None)
784
+
785
+ remaining_ops -= 1
786
+ if remaining_ops == 0:
787
+ ready_queue.put(None) # Sentinel
788
+
789
+ def execute_op(op: Any) -> None:
790
+ nonlocal active_tasks
791
+ # Extract args from env (must be ready)
792
+ args = [env[val] for val in op.inputs]
793
+
794
+ handler = self.handlers.get(op.opcode)
795
+ if not handler:
796
+ handler = get_impl(op.opcode)
797
+
798
+ if not handler:
799
+ raise NotImplementedError(
800
+ f"No implementation registered for opcode: {op.opcode}"
801
+ )
802
+
803
+ if op.opcode in self.async_ops and self.executor:
804
+ with lock:
805
+ active_tasks += 1
806
+ # profiler.sample(active_tasks, ready_queue.qsize())
807
+
808
+ # Submit to executor
809
+ def task() -> Any:
810
+ start_ts = tracer.log_start(
811
+ op, pid=self.trace_pid, namespace=self.trace_pid
812
+ )
813
+ res = handler(self, op, *args)
814
+ tracer.log_end(op, start_ts, pid=self.trace_pid)
815
+ return res
816
+
817
+ def callback(fut: Any) -> None:
818
+ try:
819
+ res = fut.result()
820
+ on_op_done(op, res)
821
+ except Exception as e:
822
+ on_op_done(op, None, error=e)
823
+
824
+ fut = self.executor.submit(task)
825
+ fut.add_done_callback(callback)
826
+ else:
827
+ # Sync execution (run immediately)
828
+ try:
829
+ start_ts = tracer.log_start(
830
+ op, pid=self.trace_pid, namespace=self.trace_pid
831
+ )
832
+ res = handler(self, op, *args)
833
+ tracer.log_end(op, start_ts, pid=self.trace_pid)
834
+ on_op_done(op, res)
835
+ except Exception as e:
836
+ on_op_done(op, None, error=e)
837
+
838
+ # 5. Initial Submission
839
+ # Submit all ops with 0 pending inputs
840
+ initial_ops = [op for op, count in pending_counts.items() if count == 0]
841
+ if not initial_ops and remaining_ops > 0:
842
+ # Cycle detected or empty graph?
843
+ pass
844
+
845
+ for op in initial_ops:
846
+ tracer.log_schedule(op, namespace=self.trace_pid)
847
+ ready_queue.put(op)
848
+
849
+ # Handle empty graph case
850
+ if remaining_ops == 0:
851
+ ready_queue.put(None)
852
+
853
+ # 6. Main Loop
854
+ while True:
855
+ item = ready_queue.get()
856
+ if item is None:
857
+ break
858
+ if isinstance(item, Exception):
859
+ raise item
860
+
861
+ # It's an op
862
+ execute_op(item)
863
+
864
+ # 7. Return outputs
865
+ if not self.tracer:
866
+ tracer.stop()
867
+
868
+ if self.tracer and job_id:
869
+ self.tracer.save_trace(job_id=job_id, rank=self.trace_pid)
870
+
871
+ return [env[out] for out in graph.outputs]