mplang-nightly 0.1.dev330__py3-none-any.whl → 0.1.dev332__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/backends/simp_driver/mem.py +71 -65
- mplang/backends/simp_worker/__init__.py +5 -0
- mplang/backends/simp_worker/comm_context.py +16 -6
- mplang/backends/simp_worker/http.py +35 -35
- mplang/backends/simp_worker/infra.py +108 -0
- mplang/backends/simp_worker/request.py +80 -0
- mplang/backends/spu_state.py +70 -20
- mplang/runtime/interpreter.py +63 -178
- {mplang_nightly-0.1.dev330.dist-info → mplang_nightly-0.1.dev332.dist-info}/METADATA +1 -1
- {mplang_nightly-0.1.dev330.dist-info → mplang_nightly-0.1.dev332.dist-info}/RECORD +13 -11
- {mplang_nightly-0.1.dev330.dist-info → mplang_nightly-0.1.dev332.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev330.dist-info → mplang_nightly-0.1.dev332.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev330.dist-info → mplang_nightly-0.1.dev332.dist-info}/licenses/LICENSE +0 -0
|
@@ -23,9 +23,10 @@ from collections.abc import Callable
|
|
|
23
23
|
from typing import TYPE_CHECKING, Any, cast
|
|
24
24
|
|
|
25
25
|
from mplang.backends.simp_driver.state import SimpDriver
|
|
26
|
-
from mplang.backends.simp_worker import WORKER_HANDLERS
|
|
27
|
-
from mplang.backends.simp_worker.
|
|
26
|
+
from mplang.backends.simp_worker import WORKER_HANDLERS
|
|
27
|
+
from mplang.backends.simp_worker.infra import DEFAULT_ASYNC_OPS, WorkerInfra
|
|
28
28
|
from mplang.backends.simp_worker.mem import LocalMesh
|
|
29
|
+
from mplang.backends.simp_worker.request import create_request_interpreter
|
|
29
30
|
from mplang.runtime.interpreter import ExecutionTracer, Interpreter
|
|
30
31
|
from mplang.runtime.object_store import FileSystemBackend, ObjectStore
|
|
31
32
|
|
|
@@ -37,10 +38,14 @@ if TYPE_CHECKING:
|
|
|
37
38
|
|
|
38
39
|
|
|
39
40
|
class MemCluster:
|
|
40
|
-
"""Orchestrator that creates and manages local worker
|
|
41
|
+
"""Orchestrator that creates and manages local worker infrastructure.
|
|
41
42
|
|
|
42
43
|
This class handles worker lifecycle management. It does NOT attach to
|
|
43
44
|
an Interpreter - instead, it creates a SimpMemDriver that can be attached.
|
|
45
|
+
|
|
46
|
+
Per-request Interpreters are created on-the-fly by ``create_request_interpreter``
|
|
47
|
+
to isolate mutable state (CommContext, SimpWorker, SPUState) across concurrent
|
|
48
|
+
requests.
|
|
44
49
|
"""
|
|
45
50
|
|
|
46
51
|
def __init__(
|
|
@@ -75,8 +80,9 @@ class MemCluster:
|
|
|
75
80
|
)
|
|
76
81
|
self.tracer.start()
|
|
77
82
|
|
|
78
|
-
# Create
|
|
79
|
-
self.
|
|
83
|
+
# Create shared WorkerInfra per rank (replaces per-rank Interpreters)
|
|
84
|
+
self._infras: list[WorkerInfra] = []
|
|
85
|
+
self._stores: list[ObjectStore] = []
|
|
80
86
|
for rank in range(world_size):
|
|
81
87
|
worker_root = cluster_root / f"node{rank}"
|
|
82
88
|
store = ObjectStore(
|
|
@@ -84,60 +90,50 @@ class MemCluster:
|
|
|
84
90
|
root_path=str(worker_root / "store"),
|
|
85
91
|
)
|
|
86
92
|
)
|
|
93
|
+
self._stores.append(store)
|
|
87
94
|
|
|
88
|
-
|
|
95
|
+
w_handlers: dict[str, Callable[..., Any]] = {**WORKER_HANDLERS} # type: ignore[dict-item]
|
|
96
|
+
infra = WorkerInfra(
|
|
89
97
|
rank=rank,
|
|
90
98
|
world_size=world_size,
|
|
91
99
|
communicator=self._mesh.comms[rank],
|
|
92
100
|
store=store,
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
w_handlers: dict[str, Callable[..., Any]] = {**WORKER_HANDLERS} # type: ignore[dict-item]
|
|
96
|
-
comm_ctx = CommContext(
|
|
97
|
-
self._mesh.comms[rank],
|
|
98
|
-
context_id="ctx",
|
|
99
|
-
my_rank=rank,
|
|
100
|
-
)
|
|
101
|
-
w_interp = Interpreter(
|
|
102
|
-
name=f"Worker-{rank}",
|
|
101
|
+
handlers=w_handlers,
|
|
103
102
|
tracer=self.tracer,
|
|
104
103
|
trace_pid=rank,
|
|
105
|
-
store=store,
|
|
106
104
|
root_dir=worker_root,
|
|
107
|
-
|
|
108
|
-
|
|
105
|
+
# async_ops has no effect when executor is None (MemCluster
|
|
106
|
+
# runs each request synchronously on the mesh executor thread).
|
|
107
|
+
# We still set it so that WorkerInfra carries the canonical set
|
|
108
|
+
# for introspection and consistency with the HTTP Worker path.
|
|
109
|
+
async_ops=DEFAULT_ASYNC_OPS,
|
|
109
110
|
)
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
w_interp.async_ops = {
|
|
113
|
-
"bfv.add",
|
|
114
|
-
"bfv.mul",
|
|
115
|
-
"bfv.rotate",
|
|
116
|
-
"bfv.batch_encode",
|
|
117
|
-
"bfv.relinearize",
|
|
118
|
-
"bfv.encrypt",
|
|
119
|
-
"bfv.decrypt",
|
|
120
|
-
"field.solve_okvs",
|
|
121
|
-
"field.decode_okvs",
|
|
122
|
-
"field.aes_expand",
|
|
123
|
-
"field.mul",
|
|
124
|
-
"simp.shuffle",
|
|
125
|
-
}
|
|
126
|
-
self._workers.append(w_interp)
|
|
111
|
+
self._infras.append(infra)
|
|
127
112
|
|
|
128
113
|
@property
|
|
129
114
|
def world_size(self) -> int:
|
|
130
115
|
return self._world_size
|
|
131
116
|
|
|
132
117
|
@property
|
|
133
|
-
def
|
|
134
|
-
return self.
|
|
118
|
+
def infras(self) -> list[WorkerInfra]:
|
|
119
|
+
return self._infras
|
|
120
|
+
|
|
121
|
+
@property
|
|
122
|
+
def workers(self) -> list[WorkerInfra]:
|
|
123
|
+
"""Backward-compatible alias for ``infras``.
|
|
124
|
+
|
|
125
|
+
.. note:: Returns ``WorkerInfra`` objects (not ``Interpreter``). Only
|
|
126
|
+
the ``.store`` attribute is guaranteed by this interface. Callers
|
|
127
|
+
needing full Interpreter access should use
|
|
128
|
+
``create_request_interpreter(infra, job_id)`` instead.
|
|
129
|
+
"""
|
|
130
|
+
return self._infras
|
|
135
131
|
|
|
136
132
|
def create_state(self) -> SimpMemDriver:
|
|
137
133
|
"""Create a SimpMemDriver that can be attached to a Driver Interpreter."""
|
|
138
134
|
return SimpMemDriver(
|
|
139
135
|
world_size=self._world_size,
|
|
140
|
-
|
|
136
|
+
infras=self._infras,
|
|
141
137
|
mesh=self._mesh,
|
|
142
138
|
)
|
|
143
139
|
|
|
@@ -158,11 +154,11 @@ class SimpMemDriver(SimpDriver):
|
|
|
158
154
|
def __init__(
|
|
159
155
|
self,
|
|
160
156
|
world_size: int,
|
|
161
|
-
|
|
157
|
+
infras: list[WorkerInfra],
|
|
162
158
|
mesh: Any, # LocalMesh from simp_worker.mem
|
|
163
159
|
) -> None:
|
|
164
160
|
self._world_size = world_size
|
|
165
|
-
self.
|
|
161
|
+
self._infras = infras
|
|
166
162
|
self._mesh = mesh
|
|
167
163
|
|
|
168
164
|
def shutdown(self) -> None:
|
|
@@ -174,9 +170,18 @@ class SimpMemDriver(SimpDriver):
|
|
|
174
170
|
return self._world_size
|
|
175
171
|
|
|
176
172
|
@property
|
|
177
|
-
def
|
|
178
|
-
"""Worker
|
|
179
|
-
return self.
|
|
173
|
+
def infras(self) -> list[WorkerInfra]:
|
|
174
|
+
"""Worker infrastructure (shared, immutable)."""
|
|
175
|
+
return self._infras
|
|
176
|
+
|
|
177
|
+
@property
|
|
178
|
+
def workers(self) -> list[WorkerInfra]:
|
|
179
|
+
"""Backward-compatible alias for ``infras``.
|
|
180
|
+
|
|
181
|
+
.. note:: Returns ``WorkerInfra`` objects (not ``Interpreter``). Only
|
|
182
|
+
the ``.store`` attribute is guaranteed by this interface.
|
|
183
|
+
"""
|
|
184
|
+
return self._infras
|
|
180
185
|
|
|
181
186
|
def submit(
|
|
182
187
|
self, rank: int, graph: Graph, inputs: list[Any], job_id: str | None = None
|
|
@@ -205,31 +210,32 @@ class SimpMemDriver(SimpDriver):
|
|
|
205
210
|
|
|
206
211
|
def fetch(self, rank: int, uri: str) -> Future[Any]:
|
|
207
212
|
"""Fetch directly from worker store."""
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
return self._mesh.executor.submit(lambda: worker_ctx.store.get(uri)) # type: ignore[no-any-return]
|
|
213
|
+
infra = self._infras[rank]
|
|
214
|
+
return self._mesh.executor.submit(lambda: infra.store.get(uri)) # type: ignore[no-any-return]
|
|
211
215
|
|
|
212
216
|
def _run_worker(
|
|
213
217
|
self, rank: int, graph: Graph, inputs: list[Any], job_id: str | None = None
|
|
214
218
|
) -> Any:
|
|
215
|
-
"""Execute on
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
219
|
+
"""Execute on a per-request Interpreter."""
|
|
220
|
+
infra = self._infras[rank]
|
|
221
|
+
request_interp = create_request_interpreter(infra, job_id or "anonymous")
|
|
222
|
+
try:
|
|
223
|
+
# Resolve URI inputs (None means rank has no data)
|
|
224
|
+
resolved_inputs = [
|
|
225
|
+
infra.store.get(inp) if inp is not None else None for inp in inputs
|
|
226
|
+
]
|
|
227
|
+
|
|
228
|
+
# Execute
|
|
229
|
+
results = request_interp.evaluate_graph(graph, resolved_inputs, job_id)
|
|
230
|
+
|
|
231
|
+
# Store results (results is always a list)
|
|
232
|
+
if not graph.outputs:
|
|
233
|
+
return None
|
|
234
|
+
return [
|
|
235
|
+
infra.store.put(res) if res is not None else None for res in results
|
|
236
|
+
]
|
|
237
|
+
finally:
|
|
238
|
+
request_interp.shutdown()
|
|
233
239
|
|
|
234
240
|
|
|
235
241
|
def make_simulator(
|
|
@@ -38,11 +38,14 @@ from mplang.backends.simp_worker.http import (
|
|
|
38
38
|
RecvTimeoutError,
|
|
39
39
|
SendTimeoutError,
|
|
40
40
|
)
|
|
41
|
+
from mplang.backends.simp_worker.infra import DEFAULT_ASYNC_OPS, WorkerInfra
|
|
41
42
|
from mplang.backends.simp_worker.mem import LocalMesh, ThreadCommunicator
|
|
42
43
|
from mplang.backends.simp_worker.ops import WORKER_HANDLERS
|
|
44
|
+
from mplang.backends.simp_worker.request import create_request_interpreter
|
|
43
45
|
from mplang.backends.simp_worker.state import SimpWorker
|
|
44
46
|
|
|
45
47
|
__all__ = [
|
|
48
|
+
"DEFAULT_ASYNC_OPS",
|
|
46
49
|
"WORKER_HANDLERS",
|
|
47
50
|
"CommConfig",
|
|
48
51
|
"CommStats",
|
|
@@ -56,6 +59,8 @@ __all__ = [
|
|
|
56
59
|
"SendTimeoutError",
|
|
57
60
|
"SimpWorker",
|
|
58
61
|
"ThreadCommunicator",
|
|
62
|
+
"WorkerInfra",
|
|
63
|
+
"create_request_interpreter",
|
|
59
64
|
"testall",
|
|
60
65
|
"testany",
|
|
61
66
|
"wait_all",
|
|
@@ -62,15 +62,25 @@ class CommContext:
|
|
|
62
62
|
def world_size(self) -> int:
|
|
63
63
|
return self._comm.world_size
|
|
64
64
|
|
|
65
|
-
def spawn(self) -> CommContext:
|
|
65
|
+
def spawn(self, suffix: str | None = None) -> CommContext:
|
|
66
66
|
"""Create a child context with independent counter namespace.
|
|
67
67
|
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
68
|
+
Args:
|
|
69
|
+
suffix: Explicit suffix for the child context ID. When provided,
|
|
70
|
+
``context_id = f"{parent_id}.{suffix}"``. This is used when
|
|
71
|
+
the caller needs a deterministic, content-based ID that is
|
|
72
|
+
stable across ranks regardless of call order (e.g. the async
|
|
73
|
+
DAG scheduler uses ``f"{graph_exec_key}.{op_idx}"``).
|
|
74
|
+
|
|
75
|
+
When *None* (default), an auto-incrementing counter is used:
|
|
76
|
+
``context_id = f"{parent_id}.{spawn_seq}"``. This is suitable
|
|
77
|
+
for SPMD code where all ranks call ``spawn()`` in the same
|
|
78
|
+
program-order position.
|
|
72
79
|
"""
|
|
73
|
-
|
|
80
|
+
if suffix is None:
|
|
81
|
+
child_id = f"{self._id}.{self._spawn_counter}"
|
|
82
|
+
else:
|
|
83
|
+
child_id = f"{self._id}.{self._spawn_counter}.{suffix}"
|
|
74
84
|
self._spawn_counter += 1
|
|
75
85
|
return CommContext(self._comm, child_id, self._rank)
|
|
76
86
|
|
|
@@ -61,10 +61,10 @@ from mplang.backends.simp_worker.base import (
|
|
|
61
61
|
wait_all,
|
|
62
62
|
wait_any,
|
|
63
63
|
)
|
|
64
|
-
from mplang.backends.simp_worker.
|
|
64
|
+
from mplang.backends.simp_worker.infra import WorkerInfra
|
|
65
65
|
from mplang.edsl import serde
|
|
66
66
|
from mplang.edsl.graph import Graph
|
|
67
|
-
from mplang.runtime.interpreter import ExecutionTracer
|
|
67
|
+
from mplang.runtime.interpreter import ExecutionTracer
|
|
68
68
|
from mplang.runtime.object_store import FileSystemBackend, ObjectStore
|
|
69
69
|
from mplang.utils.logging import get_logger
|
|
70
70
|
|
|
@@ -636,7 +636,7 @@ def register_routes(
|
|
|
636
636
|
*,
|
|
637
637
|
rank: int,
|
|
638
638
|
world_size: int,
|
|
639
|
-
|
|
639
|
+
infra: WorkerInfra,
|
|
640
640
|
comm: HttpCommunicator,
|
|
641
641
|
store: ObjectStore,
|
|
642
642
|
exec_pool: concurrent.futures.ThreadPoolExecutor,
|
|
@@ -648,36 +648,38 @@ def register_routes(
|
|
|
648
648
|
enabling reuse in custom server setups.
|
|
649
649
|
|
|
650
650
|
Note:
|
|
651
|
-
The *
|
|
652
|
-
|
|
653
|
-
any request is served, as the ``/fetch`` and ``/objects`` endpoints
|
|
654
|
-
rely on it.
|
|
651
|
+
The *infra* ``WorkerInfra`` contains all shared infrastructure needed
|
|
652
|
+
to create per-request Interpreters.
|
|
655
653
|
|
|
656
654
|
Args:
|
|
657
655
|
app: The FastAPI application to register routes on.
|
|
658
656
|
rank: This worker's rank.
|
|
659
657
|
world_size: Total number of workers.
|
|
660
|
-
|
|
658
|
+
infra: The WorkerInfra for creating per-request Interpreters.
|
|
661
659
|
comm: The HttpCommunicator for inter-worker communication.
|
|
662
660
|
store: The ObjectStore for data persistence.
|
|
663
661
|
exec_pool: Thread pool for executing graphs.
|
|
664
662
|
"""
|
|
665
|
-
from
|
|
663
|
+
from mplang.backends.simp_worker.request import create_request_interpreter
|
|
666
664
|
|
|
667
665
|
def _do_execute(graph: Graph, inputs: list[Any], job_id: str | None = None) -> Any:
|
|
668
|
-
"""Execute graph in worker thread."""
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
666
|
+
"""Execute graph in worker thread with per-request isolation."""
|
|
667
|
+
request_interp = create_request_interpreter(infra, job_id or "anonymous")
|
|
668
|
+
try:
|
|
669
|
+
# Resolve URI inputs (None means rank has no data)
|
|
670
|
+
resolved_inputs = [
|
|
671
|
+
store.get(inp) if inp is not None else None for inp in inputs
|
|
672
|
+
]
|
|
673
673
|
|
|
674
|
-
|
|
675
|
-
|
|
674
|
+
result = request_interp.evaluate_graph(graph, resolved_inputs, job_id)
|
|
675
|
+
comm.wait_pending_sends()
|
|
676
676
|
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
677
|
+
# Store results and return URIs (result is always a list)
|
|
678
|
+
if not graph.outputs:
|
|
679
|
+
return None
|
|
680
|
+
return [store.put(res) if res is not None else None for res in result]
|
|
681
|
+
finally:
|
|
682
|
+
request_interp.shutdown()
|
|
681
683
|
|
|
682
684
|
# Mutable closure state shared between endpoint handlers and exec_pool threads.
|
|
683
685
|
async_tasks: dict[str, AsyncTaskState] = {}
|
|
@@ -845,8 +847,7 @@ def register_routes(
|
|
|
845
847
|
"""Fetch data by URI (e.g. ``mem://abc123``, ``fs://ckpt/s100``)."""
|
|
846
848
|
logger.debug(f"Worker {rank} received fetch request for {req.uri}")
|
|
847
849
|
try:
|
|
848
|
-
|
|
849
|
-
val = state.store.get(req.uri)
|
|
850
|
+
val = store.get(req.uri)
|
|
850
851
|
return {"result": serde.dumps_b64(val)}
|
|
851
852
|
except Exception as e:
|
|
852
853
|
logger.error(f"Worker {rank} fetch failed: {e}")
|
|
@@ -856,8 +857,7 @@ def register_routes(
|
|
|
856
857
|
def list_objects() -> dict[str, list[str]]:
|
|
857
858
|
"""List all objects in the worker's store (transient + persistent)."""
|
|
858
859
|
try:
|
|
859
|
-
|
|
860
|
-
return {"objects": state.store.list_keys()}
|
|
860
|
+
return {"objects": store.list_keys()}
|
|
861
861
|
except Exception as e:
|
|
862
862
|
logger.error(f"Worker {rank} list_objects failed: {e}")
|
|
863
863
|
raise HTTPException(status_code=500, detail=str(e)) from e
|
|
@@ -917,22 +917,22 @@ def create_worker_app(
|
|
|
917
917
|
comm = HttpCommunicator(rank, world_size, endpoints, tracer=tracer)
|
|
918
918
|
if store is None:
|
|
919
919
|
store = ObjectStore(persistent=FileSystemBackend(root_path=str(root_dir)))
|
|
920
|
-
ctx = SimpWorker(rank, world_size, comm, store, spu_endpoints)
|
|
921
920
|
|
|
922
921
|
handlers: dict[str, Callable[..., Any]] = {**WORKER_HANDLERS} # type: ignore[dict-item]
|
|
923
922
|
|
|
924
|
-
from mplang.backends.simp_worker.
|
|
923
|
+
from mplang.backends.simp_worker.infra import DEFAULT_ASYNC_OPS, WorkerInfra
|
|
925
924
|
|
|
926
|
-
|
|
927
|
-
|
|
925
|
+
infra = WorkerInfra(
|
|
926
|
+
rank=rank,
|
|
927
|
+
world_size=world_size,
|
|
928
|
+
communicator=comm,
|
|
929
|
+
store=store,
|
|
930
|
+
handlers=handlers,
|
|
931
|
+
spu_endpoints=spu_endpoints,
|
|
928
932
|
tracer=tracer,
|
|
929
933
|
root_dir=root_dir,
|
|
930
|
-
|
|
931
|
-
store=store,
|
|
932
|
-
comm_ctx=comm_ctx,
|
|
934
|
+
async_ops=DEFAULT_ASYNC_OPS,
|
|
933
935
|
)
|
|
934
|
-
# Register SimpWorker context as 'simp' dialect state
|
|
935
|
-
worker.set_dialect_state("simp", ctx)
|
|
936
936
|
|
|
937
937
|
exec_pool = concurrent.futures.ThreadPoolExecutor(
|
|
938
938
|
max_workers=2, thread_name_prefix=f"exec_{rank}"
|
|
@@ -941,7 +941,7 @@ def create_worker_app(
|
|
|
941
941
|
@asynccontextmanager
|
|
942
942
|
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
|
|
943
943
|
yield
|
|
944
|
-
await asyncio.to_thread(
|
|
944
|
+
await asyncio.to_thread(infra.shutdown)
|
|
945
945
|
await asyncio.to_thread(comm.shutdown)
|
|
946
946
|
await asyncio.to_thread(exec_pool.shutdown, wait=True)
|
|
947
947
|
|
|
@@ -951,7 +951,7 @@ def create_worker_app(
|
|
|
951
951
|
app,
|
|
952
952
|
rank=rank,
|
|
953
953
|
world_size=world_size,
|
|
954
|
-
|
|
954
|
+
infra=infra,
|
|
955
955
|
comm=comm,
|
|
956
956
|
store=store,
|
|
957
957
|
exec_pool=exec_pool,
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
# Copyright 2026 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""WorkerInfra: shared infrastructure container for per-request Interpreter creation.
|
|
16
|
+
|
|
17
|
+
Created once at Worker startup. Passed to ``create_request_interpreter()`` for
|
|
18
|
+
each incoming request. All fields are either immutable or thread-safe.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
import concurrent.futures
|
|
24
|
+
import pathlib
|
|
25
|
+
import threading
|
|
26
|
+
from collections.abc import Callable
|
|
27
|
+
from dataclasses import dataclass, field
|
|
28
|
+
from typing import Any
|
|
29
|
+
|
|
30
|
+
from mplang.backends.simp_worker.base import CommunicatorProtocol
|
|
31
|
+
from mplang.runtime.interpreter import ExecutionTracer
|
|
32
|
+
from mplang.runtime.object_store import ObjectStore
|
|
33
|
+
|
|
34
|
+
# Opcodes eligible for async DAG scheduling when an Executor is present.
|
|
35
|
+
# Kept as a module-level constant so that both MemCluster and HTTP Worker
|
|
36
|
+
# share the same set without duplication.
|
|
37
|
+
DEFAULT_ASYNC_OPS: frozenset[str] = frozenset({
|
|
38
|
+
"bfv.add",
|
|
39
|
+
"bfv.mul",
|
|
40
|
+
"bfv.rotate",
|
|
41
|
+
"bfv.batch_encode",
|
|
42
|
+
"bfv.relinearize",
|
|
43
|
+
"bfv.encrypt",
|
|
44
|
+
"bfv.decrypt",
|
|
45
|
+
"field.solve_okvs",
|
|
46
|
+
"field.decode_okvs",
|
|
47
|
+
"field.aes_expand",
|
|
48
|
+
"field.mul",
|
|
49
|
+
"simp.shuffle",
|
|
50
|
+
})
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass
|
|
54
|
+
class WorkerInfra:
|
|
55
|
+
"""Shared infrastructure for a Worker process.
|
|
56
|
+
|
|
57
|
+
Created once at Worker startup. Passed to each per-request Interpreter
|
|
58
|
+
factory. All fields are either immutable or thread-safe.
|
|
59
|
+
"""
|
|
60
|
+
|
|
61
|
+
rank: int
|
|
62
|
+
world_size: int
|
|
63
|
+
communicator: CommunicatorProtocol
|
|
64
|
+
store: ObjectStore
|
|
65
|
+
handlers: dict[str, Callable[..., Any]]
|
|
66
|
+
spu_endpoints: dict[int, str] | None = None
|
|
67
|
+
tracer: ExecutionTracer | None = None
|
|
68
|
+
trace_pid: int | None = None
|
|
69
|
+
root_dir: pathlib.Path | None = None
|
|
70
|
+
executor: concurrent.futures.Executor | None = None
|
|
71
|
+
async_ops: frozenset[str] = field(default_factory=frozenset)
|
|
72
|
+
|
|
73
|
+
# SPU template links (lazily populated, protected by lock).
|
|
74
|
+
# Cache keys are (local_rank, spu_world_size, protocol, field, link_mode).
|
|
75
|
+
# The number of distinct keys is bounded by the Cartesian product of SPU
|
|
76
|
+
# configurations actually used at runtime -- typically 1-3 entries per
|
|
77
|
+
# worker process (one per distinct SPU device declaration).
|
|
78
|
+
_spu_lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
|
|
79
|
+
_spu_template_links: dict[tuple, Any] = field(default_factory=dict, repr=False)
|
|
80
|
+
|
|
81
|
+
def get_or_create_spu_link(
|
|
82
|
+
self,
|
|
83
|
+
cache_key: tuple,
|
|
84
|
+
create_fn: Callable[[], Any],
|
|
85
|
+
) -> Any:
|
|
86
|
+
"""Thread-safe lazy creation of template SPU links.
|
|
87
|
+
|
|
88
|
+
Args:
|
|
89
|
+
cache_key: Tuple identifying the SPU configuration.
|
|
90
|
+
create_fn: Factory function to create a new link context.
|
|
91
|
+
|
|
92
|
+
Returns:
|
|
93
|
+
A libspu.link.Context (template link) for the given configuration.
|
|
94
|
+
"""
|
|
95
|
+
with self._spu_lock:
|
|
96
|
+
if cache_key not in self._spu_template_links:
|
|
97
|
+
self._spu_template_links[cache_key] = create_fn()
|
|
98
|
+
return self._spu_template_links[cache_key]
|
|
99
|
+
|
|
100
|
+
def shutdown(self) -> None:
|
|
101
|
+
"""Release cached SPU template links.
|
|
102
|
+
|
|
103
|
+
Safe to call multiple times. Should be called during process
|
|
104
|
+
shutdown to eagerly close BRPC connections rather than waiting
|
|
105
|
+
for GC / process exit.
|
|
106
|
+
"""
|
|
107
|
+
with self._spu_lock:
|
|
108
|
+
self._spu_template_links.clear()
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
# Copyright 2026 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Per-request Interpreter factory.
|
|
16
|
+
|
|
17
|
+
Creates lightweight Interpreter instances for each incoming request,
|
|
18
|
+
providing full isolation of mutable state (CommContext, SimpWorker,
|
|
19
|
+
SPUState) while sharing immutable infrastructure (Communicator, Store,
|
|
20
|
+
handlers).
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
from mplang.backends.simp_worker.comm_context import CommContext
|
|
26
|
+
from mplang.backends.simp_worker.infra import WorkerInfra
|
|
27
|
+
from mplang.backends.simp_worker.state import SimpWorker
|
|
28
|
+
from mplang.backends.spu_state import SPUState
|
|
29
|
+
from mplang.runtime.interpreter import Interpreter
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def create_request_interpreter(
|
|
33
|
+
infra: WorkerInfra,
|
|
34
|
+
job_id: str,
|
|
35
|
+
) -> Interpreter:
|
|
36
|
+
"""Create a lightweight Interpreter for a single request.
|
|
37
|
+
|
|
38
|
+
Cost: ~2μs (dict/TLS allocation only).
|
|
39
|
+
SPU Runtime created on-demand via link.spawn() (~120μs).
|
|
40
|
+
|
|
41
|
+
Args:
|
|
42
|
+
infra: Shared WorkerInfra (process-lifetime).
|
|
43
|
+
job_id: Unique request identifier (used as CommContext context_id).
|
|
44
|
+
|
|
45
|
+
Returns:
|
|
46
|
+
Per-request Interpreter with isolated state.
|
|
47
|
+
"""
|
|
48
|
+
# Per-request CommContext with unique context_id
|
|
49
|
+
comm_ctx = CommContext(infra.communicator, context_id=job_id, my_rank=infra.rank)
|
|
50
|
+
|
|
51
|
+
# Per-request Interpreter (does not own shared executor/tracer)
|
|
52
|
+
interp = Interpreter(
|
|
53
|
+
name=f"Worker-{infra.rank}-{job_id}",
|
|
54
|
+
tracer=infra.tracer,
|
|
55
|
+
trace_pid=infra.trace_pid,
|
|
56
|
+
store=infra.store,
|
|
57
|
+
root_dir=infra.root_dir,
|
|
58
|
+
handlers=infra.handlers,
|
|
59
|
+
executor=infra.executor,
|
|
60
|
+
comm_ctx=comm_ctx,
|
|
61
|
+
owns_executor=False,
|
|
62
|
+
owns_tracer=False,
|
|
63
|
+
)
|
|
64
|
+
interp.async_ops = set(infra.async_ops)
|
|
65
|
+
|
|
66
|
+
# Per-request SimpWorker: isolates current_parties
|
|
67
|
+
worker_state = SimpWorker(
|
|
68
|
+
rank=infra.rank,
|
|
69
|
+
world_size=infra.world_size,
|
|
70
|
+
communicator=infra.communicator, # raw comm kept for SPU BaseChannel
|
|
71
|
+
store=infra.store,
|
|
72
|
+
spu_endpoints=infra.spu_endpoints,
|
|
73
|
+
)
|
|
74
|
+
interp.set_dialect_state("simp", worker_state)
|
|
75
|
+
|
|
76
|
+
# Per-request SPUState: will use link.spawn() for Runtime isolation
|
|
77
|
+
spu_state = SPUState(infra=infra)
|
|
78
|
+
interp.set_dialect_state("spu", spu_state)
|
|
79
|
+
|
|
80
|
+
return interp
|
mplang/backends/spu_state.py
CHANGED
|
@@ -28,6 +28,7 @@ import spu.libspu as libspu
|
|
|
28
28
|
from mplang.runtime.dialect_state import DialectState
|
|
29
29
|
|
|
30
30
|
if TYPE_CHECKING:
|
|
31
|
+
from mplang.backends.simp_worker.infra import WorkerInfra
|
|
31
32
|
from mplang.dialects import spu
|
|
32
33
|
|
|
33
34
|
|
|
@@ -37,18 +38,61 @@ class SPUState(DialectState):
|
|
|
37
38
|
Caches SPU Runtime and Io objects per (local_rank, world_size, config, link_mode)
|
|
38
39
|
to enable reuse across multiple SPU kernel executions.
|
|
39
40
|
|
|
41
|
+
When created with a ``WorkerInfra`` reference (per-request mode), template
|
|
42
|
+
links are obtained from the shared infra (thread-safe) and then spawned
|
|
43
|
+
via ``link.spawn()`` for per-request isolation.
|
|
44
|
+
|
|
40
45
|
This replaces the previous global `_SPU_RUNTIMES` cache with a properly
|
|
41
46
|
lifecycle-managed dialect state.
|
|
42
47
|
"""
|
|
43
48
|
|
|
44
49
|
dialect_name: str = "spu"
|
|
45
50
|
|
|
46
|
-
def __init__(self) -> None:
|
|
47
|
-
#
|
|
51
|
+
def __init__(self, infra: WorkerInfra | None = None) -> None:
|
|
52
|
+
# Optional shared infrastructure (for per-request isolation via link.spawn)
|
|
53
|
+
self._infra = infra
|
|
54
|
+
# Key: (local_rank, world_size, protocol, field, link_mode, spu_endpoints)
|
|
48
55
|
# Value: (Runtime, Io)
|
|
49
56
|
self._runtimes: dict[
|
|
50
|
-
tuple[int, int, str, str, str
|
|
57
|
+
tuple[int, int, str, str, str, tuple[str, ...] | None],
|
|
58
|
+
tuple[spu_api.Runtime, spu_api.Io],
|
|
51
59
|
] = {}
|
|
60
|
+
# Local template link cache (used when no WorkerInfra is provided)
|
|
61
|
+
self._template_links: dict[tuple, libspu.link.Context] = {}
|
|
62
|
+
|
|
63
|
+
def _get_template_link(
|
|
64
|
+
self,
|
|
65
|
+
cache_key: tuple,
|
|
66
|
+
local_rank: int,
|
|
67
|
+
spu_world_size: int,
|
|
68
|
+
communicator: object | None,
|
|
69
|
+
parties: list[int] | None,
|
|
70
|
+
spu_endpoints: list[str] | None,
|
|
71
|
+
) -> libspu.link.Context:
|
|
72
|
+
"""Get or create a template link for the given configuration.
|
|
73
|
+
|
|
74
|
+
With ``WorkerInfra``: uses infra's thread-safe shared cache.
|
|
75
|
+
Without: uses a local (non-thread-safe) cache on this SPUState.
|
|
76
|
+
"""
|
|
77
|
+
|
|
78
|
+
def _create() -> libspu.link.Context:
|
|
79
|
+
if spu_endpoints:
|
|
80
|
+
return self._create_brpc_link(local_rank, spu_endpoints)
|
|
81
|
+
elif communicator is not None:
|
|
82
|
+
if parties is None:
|
|
83
|
+
raise ValueError("parties required when using communicator")
|
|
84
|
+
return self._create_channels_link(
|
|
85
|
+
local_rank, spu_world_size, communicator, parties
|
|
86
|
+
)
|
|
87
|
+
else:
|
|
88
|
+
return self._create_mem_link(local_rank, spu_world_size)
|
|
89
|
+
|
|
90
|
+
if self._infra is not None:
|
|
91
|
+
return self._infra.get_or_create_spu_link(cache_key, _create)
|
|
92
|
+
|
|
93
|
+
if cache_key not in self._template_links:
|
|
94
|
+
self._template_links[cache_key] = _create()
|
|
95
|
+
return self._template_links[cache_key]
|
|
52
96
|
|
|
53
97
|
def get_or_create(
|
|
54
98
|
self,
|
|
@@ -61,13 +105,18 @@ class SPUState(DialectState):
|
|
|
61
105
|
) -> tuple[spu_api.Runtime, spu_api.Io]:
|
|
62
106
|
"""Get or create SPU Runtime and Io for the given configuration.
|
|
63
107
|
|
|
108
|
+
Link mode priority: spu_endpoints (BRPC) > communicator (Channels) > mem.
|
|
109
|
+
When ``spu_endpoints`` is provided it always takes precedence, even if
|
|
110
|
+
a ``communicator`` is also supplied.
|
|
111
|
+
|
|
64
112
|
Args:
|
|
65
113
|
local_rank: The local rank within the SPU device (0-indexed).
|
|
66
114
|
spu_world_size: The number of parties in the SPU device.
|
|
67
115
|
config: SPU configuration including protocol settings.
|
|
68
|
-
spu_endpoints: Optional list of BRPC endpoints.
|
|
116
|
+
spu_endpoints: Optional list of BRPC endpoints. Takes highest
|
|
117
|
+
priority when provided.
|
|
69
118
|
communicator: Optional v2 communicator (ThreadCommunicator/HttpCommunicator).
|
|
70
|
-
|
|
119
|
+
Used only when ``spu_endpoints`` is not provided.
|
|
71
120
|
parties: Optional list of global ranks for SPU parties.
|
|
72
121
|
Required when communicator is provided.
|
|
73
122
|
|
|
@@ -77,10 +126,10 @@ class SPUState(DialectState):
|
|
|
77
126
|
from mplang.backends.spu_impl import to_runtime_config
|
|
78
127
|
|
|
79
128
|
# Determine link mode
|
|
80
|
-
if
|
|
81
|
-
link_mode = "channels"
|
|
82
|
-
elif spu_endpoints:
|
|
129
|
+
if spu_endpoints:
|
|
83
130
|
link_mode = "brpc"
|
|
131
|
+
elif communicator is not None:
|
|
132
|
+
link_mode = "channels"
|
|
84
133
|
else:
|
|
85
134
|
link_mode = "mem"
|
|
86
135
|
|
|
@@ -90,22 +139,22 @@ class SPUState(DialectState):
|
|
|
90
139
|
config.protocol,
|
|
91
140
|
config.field,
|
|
92
141
|
link_mode,
|
|
142
|
+
tuple(spu_endpoints) if spu_endpoints else None,
|
|
93
143
|
)
|
|
94
144
|
|
|
95
145
|
if cache_key in self._runtimes:
|
|
96
146
|
return self._runtimes[cache_key]
|
|
97
147
|
|
|
98
|
-
#
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
link = self._create_mem_link(local_rank, spu_world_size)
|
|
148
|
+
# Unified path: get-or-create template link, then spawn for isolation
|
|
149
|
+
template_link = self._get_template_link(
|
|
150
|
+
cache_key,
|
|
151
|
+
local_rank,
|
|
152
|
+
spu_world_size,
|
|
153
|
+
communicator,
|
|
154
|
+
parties,
|
|
155
|
+
spu_endpoints,
|
|
156
|
+
)
|
|
157
|
+
link = template_link.spawn()
|
|
109
158
|
|
|
110
159
|
# Create Runtime and Io
|
|
111
160
|
runtime_config = to_runtime_config(config)
|
|
@@ -183,5 +232,6 @@ class SPUState(DialectState):
|
|
|
183
232
|
return libspu.link.create_brpc(desc, local_rank)
|
|
184
233
|
|
|
185
234
|
def shutdown(self) -> None:
|
|
186
|
-
"""Clear all cached runtimes."""
|
|
235
|
+
"""Clear all cached runtimes and template links."""
|
|
187
236
|
self._runtimes.clear()
|
|
237
|
+
self._template_links.clear()
|
mplang/runtime/interpreter.py
CHANGED
|
@@ -24,7 +24,6 @@ from __future__ import annotations
|
|
|
24
24
|
|
|
25
25
|
import collections
|
|
26
26
|
import concurrent.futures
|
|
27
|
-
import contextlib
|
|
28
27
|
import hashlib
|
|
29
28
|
import json
|
|
30
29
|
import os
|
|
@@ -32,7 +31,7 @@ import pathlib
|
|
|
32
31
|
import queue
|
|
33
32
|
import threading
|
|
34
33
|
import time
|
|
35
|
-
from collections.abc import Callable
|
|
34
|
+
from collections.abc import Callable
|
|
36
35
|
from typing import TYPE_CHECKING, Any, cast
|
|
37
36
|
|
|
38
37
|
from mplang.edsl.context import AbstractInterpreter
|
|
@@ -427,6 +426,9 @@ class Interpreter(AbstractInterpreter):
|
|
|
427
426
|
root_dir: str | pathlib.Path | None = None,
|
|
428
427
|
handlers: dict[str, Callable[..., Any]] | None = None,
|
|
429
428
|
comm_ctx: CommContext | None = None,
|
|
429
|
+
*,
|
|
430
|
+
owns_executor: bool = True,
|
|
431
|
+
owns_tracer: bool = True,
|
|
430
432
|
) -> None:
|
|
431
433
|
# Persistence Root
|
|
432
434
|
self.root_dir = (
|
|
@@ -457,26 +459,6 @@ class Interpreter(AbstractInterpreter):
|
|
|
457
459
|
# all sibling outputs are cached here to avoid re-execution.
|
|
458
460
|
self._execution_cache: dict[Any, InterpObject] = {}
|
|
459
461
|
|
|
460
|
-
# -----------------------------------------------------------------
|
|
461
|
-
# Graph-local op execution ids (for deterministic communication tags)
|
|
462
|
-
# -----------------------------------------------------------------
|
|
463
|
-
# We assign a monotonically increasing exec_id to each op execution
|
|
464
|
-
# within a graph namespace, and keep it deterministic across parties.
|
|
465
|
-
#
|
|
466
|
-
# IMPORTANT:
|
|
467
|
-
# - We intentionally make exec_id grow across repeated executions of the
|
|
468
|
-
# same region graph (e.g., while_loop iterations) to avoid tag/key reuse.
|
|
469
|
-
#
|
|
470
|
-
# Implementation:
|
|
471
|
-
# - Each evaluate_graph(graph, ...) reserves a contiguous exec_id range
|
|
472
|
-
# [base, base + len(graph.operations)).
|
|
473
|
-
# - Op exec_id = base + op_index_in_graph.
|
|
474
|
-
# - Reservation is persisted per graph_exec_key (structural hash).
|
|
475
|
-
# - We forbid concurrent execution of the same graph_hash to avoid
|
|
476
|
-
# message tag confusion when a backend uses only per-op tags.
|
|
477
|
-
self._exec_id_lock = threading.Lock()
|
|
478
|
-
self._graph_next_exec_base: dict[str, int] = {}
|
|
479
|
-
self._active_graph_exec_keys: set[str] = set()
|
|
480
462
|
self._tls = threading.local()
|
|
481
463
|
self.executor = executor
|
|
482
464
|
self.async_ops: set[str] = set()
|
|
@@ -484,45 +466,16 @@ class Interpreter(AbstractInterpreter):
|
|
|
484
466
|
self.trace_pid = trace_pid
|
|
485
467
|
self.store: ObjectStore | None = store
|
|
486
468
|
self.comm_ctx = comm_ctx
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
self
|
|
491
|
-
*,
|
|
492
|
-
graph_exec_key: str | None = None,
|
|
493
|
-
op_exec_id: int | None = None,
|
|
494
|
-
) -> Iterator[None]:
|
|
495
|
-
"""Temporarily set execution context in thread-local storage."""
|
|
496
|
-
|
|
497
|
-
prev_graph_key = getattr(self._tls, "current_graph_exec_key", None)
|
|
498
|
-
prev_exec_id = getattr(self._tls, "current_op_exec_id", None)
|
|
499
|
-
|
|
500
|
-
if graph_exec_key is not None:
|
|
501
|
-
self._tls.current_graph_exec_key = graph_exec_key
|
|
502
|
-
if op_exec_id is not None:
|
|
503
|
-
self._tls.current_op_exec_id = op_exec_id
|
|
504
|
-
|
|
505
|
-
try:
|
|
506
|
-
yield
|
|
507
|
-
finally:
|
|
508
|
-
if graph_exec_key is not None:
|
|
509
|
-
if prev_graph_key is None:
|
|
510
|
-
delattr(self._tls, "current_graph_exec_key")
|
|
511
|
-
else:
|
|
512
|
-
self._tls.current_graph_exec_key = prev_graph_key
|
|
513
|
-
|
|
514
|
-
if op_exec_id is not None:
|
|
515
|
-
if prev_exec_id is None:
|
|
516
|
-
delattr(self._tls, "current_op_exec_id")
|
|
517
|
-
else:
|
|
518
|
-
self._tls.current_op_exec_id = prev_exec_id
|
|
469
|
+
# Ownership flags: when False, shutdown() skips releasing these
|
|
470
|
+
# shared resources. Per-request Interpreters set these to False.
|
|
471
|
+
self._owns_executor = owns_executor
|
|
472
|
+
self._owns_tracer = owns_tracer
|
|
519
473
|
|
|
520
474
|
def _graph_exec_key(self, graph: Graph) -> str:
|
|
521
475
|
"""Return a deterministic, structural hash for a graph.
|
|
522
476
|
|
|
523
477
|
Used for:
|
|
524
|
-
-
|
|
525
|
-
- Communication tag disambiguation (worker ops may include this key)
|
|
478
|
+
- Communication tag disambiguation in async child CommContext IDs
|
|
526
479
|
|
|
527
480
|
Note: we cache on the Graph object assuming graphs are immutable during
|
|
528
481
|
execution (finalized graphs / regions).
|
|
@@ -657,20 +610,20 @@ class Interpreter(AbstractInterpreter):
|
|
|
657
610
|
|
|
658
611
|
This method is idempotent and safe to call multiple times.
|
|
659
612
|
It performs the following cleanup:
|
|
660
|
-
1. Shuts down the internal executor (if
|
|
661
|
-
2. Stops the execution tracer (if
|
|
613
|
+
1. Shuts down the internal executor (if owned).
|
|
614
|
+
2. Stops the execution tracer (if owned).
|
|
662
615
|
3. Shuts down any attached dialect states (e.g., stopping drivers).
|
|
663
616
|
"""
|
|
664
617
|
logger.info("Shutting down Interpreter '%s'", self.name)
|
|
665
618
|
|
|
666
|
-
# 1. Shutdown Executor
|
|
667
|
-
if self.executor:
|
|
619
|
+
# 1. Shutdown Executor (only if we own it)
|
|
620
|
+
if self.executor and self._owns_executor:
|
|
668
621
|
logger.debug("Shutting down executor")
|
|
669
622
|
self.executor.shutdown(wait=True)
|
|
670
623
|
self.executor = None
|
|
671
624
|
|
|
672
|
-
# 2. Stop Tracer
|
|
673
|
-
if self.tracer:
|
|
625
|
+
# 2. Stop Tracer (only if we own it)
|
|
626
|
+
if self.tracer and self._owns_tracer:
|
|
674
627
|
logger.debug("Stopping execution tracer")
|
|
675
628
|
self.tracer.stop()
|
|
676
629
|
# Don't clear self.tracer, as we might want to read stats later
|
|
@@ -923,70 +876,18 @@ class Interpreter(AbstractInterpreter):
|
|
|
923
876
|
Returns:
|
|
924
877
|
List of runtime execution results corresponding to graph.outputs.
|
|
925
878
|
"""
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
self.
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
with self._tls_exec_context(graph_exec_key=graph_exec_key):
|
|
939
|
-
logger.debug(
|
|
940
|
-
"Evaluating graph: %d inputs, %d ops, %d outputs (job_id=%s, async=%s, graph_key=%s)",
|
|
941
|
-
len(inputs),
|
|
942
|
-
len(graph.operations),
|
|
943
|
-
len(graph.outputs),
|
|
944
|
-
job_id,
|
|
945
|
-
self.executor is not None,
|
|
946
|
-
graph_exec_key,
|
|
947
|
-
)
|
|
948
|
-
if self.executor:
|
|
949
|
-
return self._evaluate_graph_async(graph, inputs, job_id)
|
|
950
|
-
else:
|
|
951
|
-
return self._evaluate_graph_sync(graph, inputs, job_id)
|
|
952
|
-
finally:
|
|
953
|
-
with self._exec_id_lock:
|
|
954
|
-
self._active_graph_exec_keys.discard(graph_exec_key)
|
|
955
|
-
|
|
956
|
-
def _reserve_op_exec_base(self, graph: Graph) -> int:
|
|
957
|
-
"""Reserve a contiguous exec_id range for a single evaluate_graph call.
|
|
958
|
-
|
|
959
|
-
Counter is namespaced by the current graph_exec_key.
|
|
960
|
-
"""
|
|
961
|
-
key = self.current_graph_exec_key()
|
|
962
|
-
with self._exec_id_lock:
|
|
963
|
-
base = self._graph_next_exec_base.get(key, 0)
|
|
964
|
-
self._graph_next_exec_base[key] = base + len(graph.operations)
|
|
965
|
-
return base
|
|
966
|
-
|
|
967
|
-
def current_graph_exec_key(self) -> str:
|
|
968
|
-
"""Return current graph execution key during evaluate_graph execution."""
|
|
969
|
-
|
|
970
|
-
key = getattr(self._tls, "current_graph_exec_key", None)
|
|
971
|
-
if key is None:
|
|
972
|
-
raise RuntimeError(
|
|
973
|
-
"current_graph_exec_key() called outside of evaluate_graph execution"
|
|
974
|
-
)
|
|
975
|
-
return cast(str, key)
|
|
976
|
-
|
|
977
|
-
def current_op_exec_id(self) -> int:
|
|
978
|
-
"""Return current op exec_id during graph execution.
|
|
979
|
-
|
|
980
|
-
Worker-side implementations can use this to build deterministic,
|
|
981
|
-
unique communication tags without coupling to any specific op.
|
|
982
|
-
"""
|
|
983
|
-
|
|
984
|
-
exec_id = getattr(self._tls, "current_op_exec_id", None)
|
|
985
|
-
if exec_id is None:
|
|
986
|
-
raise RuntimeError(
|
|
987
|
-
"current_op_exec_id() called outside of evaluate_graph execution"
|
|
988
|
-
)
|
|
989
|
-
return cast(int, exec_id)
|
|
879
|
+
logger.debug(
|
|
880
|
+
"Evaluating graph: %d inputs, %d ops, %d outputs (job_id=%s, async=%s)",
|
|
881
|
+
len(inputs),
|
|
882
|
+
len(graph.operations),
|
|
883
|
+
len(graph.outputs),
|
|
884
|
+
job_id,
|
|
885
|
+
self.executor is not None,
|
|
886
|
+
)
|
|
887
|
+
if self.executor:
|
|
888
|
+
return self._evaluate_graph_async(graph, inputs, job_id)
|
|
889
|
+
else:
|
|
890
|
+
return self._evaluate_graph_sync(graph, inputs, job_id)
|
|
990
891
|
|
|
991
892
|
def current_comm_ctx(self) -> CommContext:
|
|
992
893
|
"""Return the current CommContext for this execution.
|
|
@@ -1043,10 +944,7 @@ class Interpreter(AbstractInterpreter):
|
|
|
1043
944
|
# Build consumer counts for intermediate value GC
|
|
1044
945
|
remaining_consumers = self._build_value_gc_info(graph)
|
|
1045
946
|
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
for op_index, op in enumerate(graph.operations):
|
|
1049
|
-
exec_id = op_exec_base + op_index
|
|
947
|
+
for _op_index, op in enumerate(graph.operations):
|
|
1050
948
|
# Resolve inputs
|
|
1051
949
|
try:
|
|
1052
950
|
args = [env[val] for val in op.inputs]
|
|
@@ -1070,16 +968,15 @@ class Interpreter(AbstractInterpreter):
|
|
|
1070
968
|
if not handler:
|
|
1071
969
|
handler = get_impl(op.opcode)
|
|
1072
970
|
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
|
|
1081
|
-
|
|
1082
|
-
)
|
|
971
|
+
if handler:
|
|
972
|
+
# Pass interpreter to support recursive execution (HOFs)
|
|
973
|
+
# Pass op to access attributes and regions
|
|
974
|
+
# Pass args as runtime values
|
|
975
|
+
results = handler(self, op, *args)
|
|
976
|
+
else:
|
|
977
|
+
raise NotImplementedError(
|
|
978
|
+
f"No implementation registered for opcode: {op.opcode}"
|
|
979
|
+
)
|
|
1083
980
|
|
|
1084
981
|
# Update environment with outputs
|
|
1085
982
|
# Handler should return a single value or a tuple/list of values
|
|
@@ -1120,8 +1017,7 @@ class Interpreter(AbstractInterpreter):
|
|
|
1120
1017
|
self, graph: Graph, inputs: list[Any], job_id: str | None = None
|
|
1121
1018
|
) -> list[Any]:
|
|
1122
1019
|
"""Asynchronous execution with non-blocking DAG scheduling."""
|
|
1123
|
-
graph_exec_key = self.
|
|
1124
|
-
op_exec_base = self._reserve_op_exec_base(graph)
|
|
1020
|
+
graph_exec_key = self._graph_exec_key(graph)
|
|
1125
1021
|
op_to_index = {op: i for i, op in enumerate(graph.operations)}
|
|
1126
1022
|
|
|
1127
1023
|
# CommContext for this execution scope
|
|
@@ -1219,8 +1115,6 @@ class Interpreter(AbstractInterpreter):
|
|
|
1219
1115
|
# Extract args from env (must be ready)
|
|
1220
1116
|
args = [env[val] for val in op.inputs]
|
|
1221
1117
|
|
|
1222
|
-
exec_id = op_exec_base + op_to_index[op]
|
|
1223
|
-
|
|
1224
1118
|
handler = self.handlers.get(op.opcode)
|
|
1225
1119
|
if not handler:
|
|
1226
1120
|
handler = get_impl(op.opcode)
|
|
@@ -1230,20 +1124,17 @@ class Interpreter(AbstractInterpreter):
|
|
|
1230
1124
|
f"No implementation registered for opcode: {op.opcode}"
|
|
1231
1125
|
)
|
|
1232
1126
|
|
|
1233
|
-
#
|
|
1234
|
-
#
|
|
1235
|
-
#
|
|
1236
|
-
#
|
|
1237
|
-
#
|
|
1127
|
+
# Derive a deterministic child CommContext ID for this op.
|
|
1128
|
+
# We include graph_exec_key (structural hash of the graph) to prevent
|
|
1129
|
+
# op_idx collisions between *different nested graphs* within the same
|
|
1130
|
+
# request. For example, two simp.pcall_static ops in the same parent
|
|
1131
|
+
# graph each trigger a sub-graph execution — without graph_exec_key,
|
|
1132
|
+
# their child ops at the same positional index would share the same
|
|
1133
|
+
# CommContext ID and corrupt each other's messages.
|
|
1134
|
+
# Cross-request isolation is already handled by root context_id (= job_id).
|
|
1238
1135
|
op_idx = op_to_index[op]
|
|
1239
1136
|
if root_comm_ctx is not None:
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
child_ctx: CommContext | None = CommContext(
|
|
1243
|
-
root_comm_ctx._comm,
|
|
1244
|
-
f"{root_comm_ctx._id}.{graph_exec_key}.{op_idx}",
|
|
1245
|
-
root_comm_ctx._rank,
|
|
1246
|
-
)
|
|
1137
|
+
child_ctx = root_comm_ctx.spawn(suffix=f"{graph_exec_key}.{op_idx}")
|
|
1247
1138
|
else:
|
|
1248
1139
|
child_ctx = None
|
|
1249
1140
|
|
|
@@ -1254,17 +1145,14 @@ class Interpreter(AbstractInterpreter):
|
|
|
1254
1145
|
|
|
1255
1146
|
# Submit to executor
|
|
1256
1147
|
def task() -> Any:
|
|
1257
|
-
|
|
1258
|
-
|
|
1259
|
-
|
|
1260
|
-
|
|
1261
|
-
|
|
1262
|
-
|
|
1263
|
-
|
|
1264
|
-
|
|
1265
|
-
res = handler(self, op, *args)
|
|
1266
|
-
tracer.log_end(op, start_ts, pid=self.trace_pid)
|
|
1267
|
-
return res
|
|
1148
|
+
if child_ctx is not None:
|
|
1149
|
+
self._tls.comm_ctx = child_ctx
|
|
1150
|
+
start_ts = tracer.log_start(
|
|
1151
|
+
op, pid=self.trace_pid, namespace=self.trace_pid
|
|
1152
|
+
)
|
|
1153
|
+
res = handler(self, op, *args)
|
|
1154
|
+
tracer.log_end(op, start_ts, pid=self.trace_pid)
|
|
1155
|
+
return res
|
|
1268
1156
|
|
|
1269
1157
|
def callback(fut: Any) -> None:
|
|
1270
1158
|
try:
|
|
@@ -1278,17 +1166,14 @@ class Interpreter(AbstractInterpreter):
|
|
|
1278
1166
|
else:
|
|
1279
1167
|
# Sync execution (run immediately)
|
|
1280
1168
|
try:
|
|
1281
|
-
|
|
1282
|
-
|
|
1283
|
-
|
|
1284
|
-
|
|
1285
|
-
|
|
1286
|
-
|
|
1287
|
-
|
|
1288
|
-
|
|
1289
|
-
res = handler(self, op, *args)
|
|
1290
|
-
tracer.log_end(op, start_ts, pid=self.trace_pid)
|
|
1291
|
-
on_op_done(op, res)
|
|
1169
|
+
if child_ctx is not None:
|
|
1170
|
+
self._tls.comm_ctx = child_ctx
|
|
1171
|
+
start_ts = tracer.log_start(
|
|
1172
|
+
op, pid=self.trace_pid, namespace=self.trace_pid
|
|
1173
|
+
)
|
|
1174
|
+
res = handler(self, op, *args)
|
|
1175
|
+
tracer.log_end(op, start_ts, pid=self.trace_pid)
|
|
1176
|
+
on_op_done(op, res)
|
|
1292
1177
|
except Exception as e:
|
|
1293
1178
|
on_op_done(op, None, error=e)
|
|
1294
1179
|
|
|
@@ -11,7 +11,7 @@ mplang/backends/func_impl.py,sha256=vhkSJnvgSzvQWXqOE40O8npqOmtNqLxefzv9ETuQODQ,
|
|
|
11
11
|
mplang/backends/phe_impl.py,sha256=CefRpFrRJvbZm3X58OYOFAyNZXKDp081Q6CT8fhdreE,4118
|
|
12
12
|
mplang/backends/simp_design.md,sha256=u4YeDLKk5avsueXhzPtt4OUBToBmurOVK_BWTKTJ7w0,5246
|
|
13
13
|
mplang/backends/spu_impl.py,sha256=73JTdIFqWoKgVM5jR1lJm_PODlgb9BrQUwLSDcGWi20,11876
|
|
14
|
-
mplang/backends/spu_state.py,sha256=
|
|
14
|
+
mplang/backends/spu_state.py,sha256=XKquhcAND8afEUzzfJ3VFb1wQ3ptFuPptStxHt20JpA,8652
|
|
15
15
|
mplang/backends/store_impl.py,sha256=IeoMxt9Ma21aYv-dKgF7COinvO6uc09aw5gGTp-5Ctg,1599
|
|
16
16
|
mplang/backends/table_impl.py,sha256=g3fjXOdHUoEEE1w330C5n5FwKcM26cQSo0jPgxNQFsE,30104
|
|
17
17
|
mplang/backends/tee_impl.py,sha256=5yed94Lr55hgSDT6j3FjdTboqzhl4HpJ9lGk1YGZBXQ,7026
|
|
@@ -19,18 +19,20 @@ mplang/backends/tensor_impl.py,sha256=SkleS37V5_tW1vKzp8c4EoRyGN590lRJvyIkys1x9o
|
|
|
19
19
|
mplang/backends/util.py,sha256=yrftyGCxOHSfyTawopGhrErUrfgV_5xXqHgEjEY3f8A,3940
|
|
20
20
|
mplang/backends/simp_driver/__init__.py,sha256=pQkHGR8HagnPB4OqagPZY-Ul19y6HjIJ5dY6QXZrBH8,1243
|
|
21
21
|
mplang/backends/simp_driver/http.py,sha256=cPcX2f2SIXy6NtQDjSaCJGr8qKut-WcOSmAP2UQomrk,5674
|
|
22
|
-
mplang/backends/simp_driver/mem.py,sha256=
|
|
22
|
+
mplang/backends/simp_driver/mem.py,sha256=sb4TT61VEOf384PoMK2WAOCI0SN_iH03bVLo0e7DLBU,10213
|
|
23
23
|
mplang/backends/simp_driver/ops.py,sha256=WYObWDRCsiXH0UBWZX5vD5W98ZPkd88U_qBV8SE5rA8,4503
|
|
24
24
|
mplang/backends/simp_driver/state.py,sha256=dNmYMFN2D2BBdgs6C0YLaHrfaBRMgs05UNxMWw6tZIs,1713
|
|
25
25
|
mplang/backends/simp_driver/values.py,sha256=Lz1utNSIzH-dCzZAEjU6JRcxPsfKGfUJrYl6gIuMOGw,1509
|
|
26
|
-
mplang/backends/simp_worker/__init__.py,sha256=
|
|
26
|
+
mplang/backends/simp_worker/__init__.py,sha256=MGfnuvkP6mh9FzSOpLCZ96bgqDydE8GNYCtBhidPiWk,1922
|
|
27
27
|
mplang/backends/simp_worker/base.py,sha256=sRfxktJjCuYgD7tASWMNL_gWQiAkqJkK9N053GWxhM8,12350
|
|
28
28
|
mplang/backends/simp_worker/collective_algorithms.py,sha256=NFUpTu2X1Kxl5nI8ef5l1fhD2t3AffeNFDSo3hOzQYo,4667
|
|
29
29
|
mplang/backends/simp_worker/collectives.py,sha256=5T29LYy9EJePLoKgC47XQezlUaZa7YCfgsQXT7XGIEY,5970
|
|
30
|
-
mplang/backends/simp_worker/comm_context.py,sha256=
|
|
31
|
-
mplang/backends/simp_worker/http.py,sha256=
|
|
30
|
+
mplang/backends/simp_worker/comm_context.py,sha256=EmDNYOpCP5p8nCL1nJcj9Ztua8Kb_SC4cUpuVKHtZOI,3910
|
|
31
|
+
mplang/backends/simp_worker/http.py,sha256=7SEnQE8a-vbNTQEd4xIo2jhI8AsaEqsYZ8IcINsg8qY,35129
|
|
32
|
+
mplang/backends/simp_worker/infra.py,sha256=9f-MpuNbL3vUdkOwGg8EW6ZJPyA41JJTTnTqhvtD4A8,3848
|
|
32
33
|
mplang/backends/simp_worker/mem.py,sha256=2AWTvv1awtGt-GaEpfnGLE9QoeaCCdfz1BgEKklVorA,6171
|
|
33
34
|
mplang/backends/simp_worker/ops.py,sha256=SlZ9bJGkvoTwgXjvWvyrNsinf3B-zMm04XVL8IEKYMs,5628
|
|
35
|
+
mplang/backends/simp_worker/request.py,sha256=J_qPh0Crj8DWvrEEzAMt7ie-F4AvEl_1AaYlfWf3TtQ,2771
|
|
34
36
|
mplang/backends/simp_worker/state.py,sha256=nIu0ybvdYqRqp0TkoSneUF2u31evDHucCRduVBaDals,1445
|
|
35
37
|
mplang/dialects/__init__.py,sha256=CYMmkeQVU0Znr9n3_5clZKb16u7acJ5jl5Zjbx4Tn1U,1478
|
|
36
38
|
mplang/dialects/_jax_utils.py,sha256=LJhFQNvwUB7Kq4YJTfP5MAE4Q7ooNperHn4x2TsKW4s,9056
|
|
@@ -101,7 +103,7 @@ mplang/libs/mpc/vole/ldpc.py,sha256=gOmIbyOjkGE5lewyatl3p6FizNNH8LZ_1oOhp_-TOck,
|
|
|
101
103
|
mplang/libs/mpc/vole/silver.py,sha256=EIxhpFIVNBemgeIZzCu5Cz_4wysxRm9b1Xfu0xiweVQ,12218
|
|
102
104
|
mplang/runtime/__init__.py,sha256=VdUwJ3kDaI46FvGw7iMGwcsjt0HTGmmRmaBwj99xKIw,620
|
|
103
105
|
mplang/runtime/dialect_state.py,sha256=HxO1i4kSOujS2tQzAF9-WmI3nChSaGgupf2_07dHetY,1277
|
|
104
|
-
mplang/runtime/interpreter.py,sha256=
|
|
106
|
+
mplang/runtime/interpreter.py,sha256=boz3LDziXJbgYBDCVG_NwsXD9jGteHYHqn3lksTBVGs,46199
|
|
105
107
|
mplang/runtime/object_store.py,sha256=8Xqr87mkKuQIs-gVZ89Nk62o2GjJDk8meQ4TT66i4aQ,17916
|
|
106
108
|
mplang/runtime/value.py,sha256=EqlhSgxLTJi_FF3ppyKjMe4eHS6-ROx-zK1YesG1U4o,4311
|
|
107
109
|
mplang/tool/__init__.py,sha256=9K-T50W_vClUlyERcVx5xGZaeyv0Ts63SaQX6AZtjIs,1341
|
|
@@ -109,8 +111,8 @@ mplang/tool/program.py,sha256=W3H8bpPirnoJ4ZrmyPYuMCPadJis20o__n_1MKqCsWU,11058
|
|
|
109
111
|
mplang/utils/__init__.py,sha256=Hwrwti2nfPxWUXV8DN6T1QaqXH_Jsd27k8UMSdBGUns,1073
|
|
110
112
|
mplang/utils/func_utils.py,sha256=Jdn_60jN3jcSE_oAqSMTLQjiE8CLyPpY-H1HmPIL5mw,5354
|
|
111
113
|
mplang/utils/logging.py,sha256=9dMhwprVbx1WMGJrgoQbWmV50vyYuLU4NSPnetcl1Go,7237
|
|
112
|
-
mplang_nightly-0.1.
|
|
113
|
-
mplang_nightly-0.1.
|
|
114
|
-
mplang_nightly-0.1.
|
|
115
|
-
mplang_nightly-0.1.
|
|
116
|
-
mplang_nightly-0.1.
|
|
114
|
+
mplang_nightly-0.1.dev332.dist-info/METADATA,sha256=H4LyHhbLgPXLlYVKdf1sNzDai9gAqU_JMw20NA92NEo,16783
|
|
115
|
+
mplang_nightly-0.1.dev332.dist-info/WHEEL,sha256=QccIxa26bgl1E6uMy58deGWi-0aeIkkangHcxk2kWfw,87
|
|
116
|
+
mplang_nightly-0.1.dev332.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
|
|
117
|
+
mplang_nightly-0.1.dev332.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
118
|
+
mplang_nightly-0.1.dev332.dist-info/RECORD,,
|
|
File without changes
|
{mplang_nightly-0.1.dev330.dist-info → mplang_nightly-0.1.dev332.dist-info}/entry_points.txt
RENAMED
|
File without changes
|
{mplang_nightly-0.1.dev330.dist-info → mplang_nightly-0.1.dev332.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|