mplang-nightly 0.1.dev152__py3-none-any.whl → 0.1.dev153__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/core/cluster.py +95 -24
- mplang/runtime/client.py +4 -18
- mplang/runtime/driver.py +1 -6
- mplang/runtime/server.py +95 -49
- mplang/runtime/session.py +285 -0
- mplang/runtime/simulation.py +15 -13
- {mplang_nightly-0.1.dev152.dist-info → mplang_nightly-0.1.dev153.dist-info}/METADATA +1 -1
- {mplang_nightly-0.1.dev152.dist-info → mplang_nightly-0.1.dev153.dist-info}/RECORD +11 -11
- mplang/runtime/resource.py +0 -365
- {mplang_nightly-0.1.dev152.dist-info → mplang_nightly-0.1.dev153.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev152.dist-info → mplang_nightly-0.1.dev153.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev152.dist-info → mplang_nightly-0.1.dev153.dist-info}/licenses/LICENSE +0 -0
mplang/core/cluster.py
CHANGED
@@ -25,23 +25,28 @@ from typing import Any
|
|
25
25
|
|
26
26
|
@dataclass(frozen=True)
|
27
27
|
class RuntimeInfo:
|
28
|
-
"""
|
29
|
-
|
28
|
+
"""Per-physical-node runtime configuration.
|
29
|
+
|
30
|
+
``op_bindings`` is a per-node override map (logical_op -> kernel_id) merged
|
31
|
+
into that node's ``RuntimeContext``. Unknown future / auxiliary fields are
|
32
|
+
preserved in ``extra``.
|
30
33
|
"""
|
31
34
|
|
32
35
|
version: str
|
33
36
|
platform: str
|
34
|
-
|
37
|
+
# Per-node partial override dispatch table (merged over project defaults).
|
38
|
+
op_bindings: dict[str, str] = field(default_factory=dict)
|
35
39
|
|
36
|
-
# A catch-all for any other custom or future properties
|
40
|
+
# A catch-all for any other custom or future properties (must not collide
|
41
|
+
# with reserved keys: version, platform, op_bindings).
|
37
42
|
extra: dict[str, Any] = field(default_factory=dict)
|
38
43
|
|
39
44
|
def to_dict(self) -> dict[str, Any]:
|
40
|
-
"""Convert RuntimeInfo to a dictionary."""
|
45
|
+
"""Convert RuntimeInfo to a dictionary (stable field names)."""
|
41
46
|
result = {
|
42
47
|
"version": self.version,
|
43
48
|
"platform": self.platform,
|
44
|
-
"
|
49
|
+
"op_bindings": self.op_bindings,
|
45
50
|
}
|
46
51
|
result.update(self.extra)
|
47
52
|
return result
|
@@ -175,7 +180,8 @@ class ClusterSpec:
|
|
175
180
|
|
176
181
|
# 2. Parse Physical Nodes, using the list index as the rank
|
177
182
|
nodes_map: dict[str, Node] = {}
|
178
|
-
|
183
|
+
# Reserved runtime info keys we recognize explicitly.
|
184
|
+
known_runtime_fields = {"version", "platform", "op_bindings"}
|
179
185
|
for i, node_cfg in enumerate(config["nodes"]):
|
180
186
|
if "rank" in node_cfg:
|
181
187
|
# Optionally, we can log a warning that the explicit 'rank' is ignored.
|
@@ -187,11 +193,12 @@ class ClusterSpec:
|
|
187
193
|
for k, v in runtime_info_cfg.items()
|
188
194
|
if k not in known_runtime_fields
|
189
195
|
}
|
190
|
-
|
196
|
+
# Gracefully ignore legacy 'backends' if present (treated as extra)
|
197
|
+
# for backward compatibility.
|
191
198
|
runtime_info = RuntimeInfo(
|
192
199
|
version=runtime_info_cfg.get("version", "N/A"),
|
193
200
|
platform=runtime_info_cfg.get("platform", "N/A"),
|
194
|
-
|
201
|
+
op_bindings=runtime_info_cfg.get("op_bindings", {}) or {},
|
195
202
|
extra=extra_runtime_info,
|
196
203
|
)
|
197
204
|
|
@@ -227,32 +234,96 @@ class ClusterSpec:
|
|
227
234
|
return cls(nodes=nodes_map, devices=devices_map)
|
228
235
|
|
229
236
|
@classmethod
|
230
|
-
def simple(
|
231
|
-
|
232
|
-
|
233
|
-
|
237
|
+
def simple(
|
238
|
+
cls,
|
239
|
+
world_size: int,
|
240
|
+
*,
|
241
|
+
endpoints: list[str] | None = None,
|
242
|
+
spu_protocol: str = "SEMI2K",
|
243
|
+
spu_field: str = "FM128",
|
244
|
+
runtime_version: str = "simulated",
|
245
|
+
runtime_platform: str = "simulated",
|
246
|
+
op_bindings: list[dict[str, str]] | None = None,
|
247
|
+
enable_local_device: bool = True,
|
248
|
+
enable_spu_device: bool = True,
|
249
|
+
) -> ClusterSpec:
|
250
|
+
"""Convenience constructor used heavily in tests.
|
251
|
+
|
252
|
+
Parameters
|
253
|
+
----------
|
254
|
+
world_size:
|
255
|
+
Number of parties (physical nodes).
|
256
|
+
endpoints:
|
257
|
+
Optional explicit endpoint list of length ``world_size``. Each element may
|
258
|
+
include scheme (``http://``) or not; stored verbatim. If not provided we
|
259
|
+
synthesize ``localhost:{5000 + i}`` (5000 is a fixed default; pass explicit
|
260
|
+
endpoints for control). Deprecated ``base_port`` legacy kwarg can adjust it.
|
261
|
+
spu_protocol / spu_field:
|
262
|
+
SPU device config values.
|
263
|
+
runtime_version / runtime_platform:
|
264
|
+
Populated into each node's ``RuntimeInfo``.
|
265
|
+
op_bindings:
|
266
|
+
Optional list of length ``world_size`` supplying per-node op_bindings
|
267
|
+
override dicts (defaults to empty dicts).
|
268
|
+
enable_local_device:
|
269
|
+
If True (default), create one ``local_{rank}`` device per node.
|
270
|
+
enable_spu_device:
|
271
|
+
If True (default) create a shared SPU device named ``SP0``.
|
272
|
+
"""
|
273
|
+
base_port = 5000
|
274
|
+
|
275
|
+
if endpoints is not None and len(endpoints) != world_size:
|
276
|
+
raise ValueError(
|
277
|
+
"len(endpoints) must equal world_size when provided: "
|
278
|
+
f"{len(endpoints)} != {world_size}"
|
279
|
+
)
|
280
|
+
|
281
|
+
if op_bindings is not None and len(op_bindings) != world_size:
|
282
|
+
raise ValueError(
|
283
|
+
"len(op_bindings) must equal world_size when provided: "
|
284
|
+
f"{len(op_bindings)} != {world_size}"
|
285
|
+
)
|
286
|
+
|
287
|
+
if not enable_local_device and not enable_spu_device:
|
288
|
+
raise ValueError(
|
289
|
+
"At least one of enable_local_device or enable_spu_device must be True"
|
290
|
+
)
|
291
|
+
|
292
|
+
nodes: dict[str, Node] = {}
|
293
|
+
for i in range(world_size):
|
294
|
+
ep = endpoints[i] if endpoints is not None else f"localhost:{base_port + i}"
|
295
|
+
node_op_bindings = op_bindings[i] if op_bindings is not None else {}
|
296
|
+
nodes[f"node{i}"] = Node(
|
234
297
|
name=f"node{i}",
|
235
298
|
rank=i,
|
236
|
-
endpoint=
|
299
|
+
endpoint=ep,
|
237
300
|
runtime_info=RuntimeInfo(
|
238
|
-
version=
|
239
|
-
platform=
|
240
|
-
|
301
|
+
version=runtime_version,
|
302
|
+
platform=runtime_platform,
|
303
|
+
op_bindings=node_op_bindings,
|
241
304
|
),
|
242
305
|
)
|
243
|
-
for i in range(world_size)
|
244
|
-
}
|
245
306
|
|
246
|
-
devices = {
|
247
|
-
|
307
|
+
devices: dict[str, Device] = {}
|
308
|
+
# Optional per-node local devices
|
309
|
+
if enable_local_device:
|
310
|
+
for i in range(world_size):
|
311
|
+
devices[f"local_{i}"] = Device(
|
312
|
+
name=f"local_{i}",
|
313
|
+
kind="local",
|
314
|
+
members=[nodes[f"node{i}"]],
|
315
|
+
)
|
316
|
+
|
317
|
+
# Shared SPU device
|
318
|
+
if enable_spu_device:
|
319
|
+
devices["SP0"] = Device(
|
248
320
|
name="SP0",
|
249
321
|
kind="SPU",
|
250
322
|
members=list(nodes.values()),
|
251
323
|
config={
|
252
|
-
"protocol":
|
253
|
-
"field":
|
324
|
+
"protocol": spu_protocol,
|
325
|
+
"field": spu_field,
|
254
326
|
},
|
255
327
|
)
|
256
|
-
}
|
257
328
|
|
258
329
|
return cls(nodes=nodes, devices=devices)
|
mplang/runtime/client.py
CHANGED
@@ -81,21 +81,14 @@ class HttpExecutorClient:
|
|
81
81
|
self,
|
82
82
|
name: str,
|
83
83
|
rank: int,
|
84
|
-
|
85
|
-
*,
|
86
|
-
spu_mask: int = 0,
|
87
|
-
spu_protocol: str = "SEMI2K",
|
88
|
-
spu_field: str = "FM64",
|
84
|
+
cluster_spec: dict,
|
89
85
|
) -> str:
|
90
86
|
"""Create a new session.
|
91
87
|
|
92
88
|
Args:
|
93
89
|
name: Session name/ID.
|
94
|
-
rank:
|
95
|
-
|
96
|
-
spu_mask: SPU mask for the session, 0 means no SPU.
|
97
|
-
spu_protocol: SPU protocol for the session (e.g., "SEMI2K", "ABY3").
|
98
|
-
spu_field: SPU field for the session (e.g., "FM64", "FM128").
|
90
|
+
rank: This party's rank.
|
91
|
+
cluster_spec: Full cluster specification dict (ClusterSpec.to_dict()).
|
99
92
|
|
100
93
|
Returns:
|
101
94
|
The session name/ID
|
@@ -104,14 +97,7 @@ class HttpExecutorClient:
|
|
104
97
|
RuntimeError: If session creation fails
|
105
98
|
"""
|
106
99
|
url = f"/sessions/{name}"
|
107
|
-
|
108
|
-
payload: dict[str, Any] = {
|
109
|
-
"rank": rank,
|
110
|
-
"endpoints": endpoints,
|
111
|
-
"spu_mask": spu_mask,
|
112
|
-
"spu_protocol": spu_protocol,
|
113
|
-
"spu_field": spu_field,
|
114
|
-
}
|
100
|
+
payload: dict[str, Any] = {"rank": rank, "cluster_spec": cluster_spec}
|
115
101
|
|
116
102
|
try:
|
117
103
|
response = await self._client.put(url, json=payload)
|
mplang/runtime/driver.py
CHANGED
@@ -145,8 +145,6 @@ class Driver(InterpContext):
|
|
145
145
|
"""Get existing session or create a new one across all HTTP servers."""
|
146
146
|
if self._session_id is None:
|
147
147
|
new_session_id = new_uuid()
|
148
|
-
endpoints_list = list(self.node_addrs.values())
|
149
|
-
|
150
148
|
# Create temporary clients for session creation
|
151
149
|
clients = self._create_clients()
|
152
150
|
try:
|
@@ -158,10 +156,7 @@ class Driver(InterpContext):
|
|
158
156
|
task = client.create_session(
|
159
157
|
name=new_session_id,
|
160
158
|
rank=rank,
|
161
|
-
|
162
|
-
spu_mask=self.spu_mask_int,
|
163
|
-
spu_protocol=self.spu_protocol_str,
|
164
|
-
spu_field=self.spu_field_str,
|
159
|
+
cluster_spec=self.cluster_spec.to_dict(),
|
165
160
|
)
|
166
161
|
tasks.append(task)
|
167
162
|
|
mplang/runtime/server.py
CHANGED
@@ -32,14 +32,30 @@ from mplang.core.table import TableType
|
|
32
32
|
from mplang.core.tensor import TensorType
|
33
33
|
from mplang.kernels.base import KernelContext
|
34
34
|
from mplang.protos.v1alpha1 import mpir_pb2
|
35
|
-
from mplang.runtime import resource
|
36
35
|
from mplang.runtime.data_providers import DataProvider, ResolvedURI, register_provider
|
37
36
|
from mplang.runtime.exceptions import InvalidRequestError, ResourceNotFound
|
37
|
+
from mplang.runtime.session import (
|
38
|
+
Computation,
|
39
|
+
Session,
|
40
|
+
Symbol,
|
41
|
+
)
|
38
42
|
|
39
43
|
logger = logging.getLogger(__name__)
|
40
44
|
|
41
45
|
app = FastAPI()
|
42
46
|
|
47
|
+
# per-server global state
|
48
|
+
_sessions: dict[str, Session] = {}
|
49
|
+
_global_symbols: dict[str, Symbol] = {}
|
50
|
+
|
51
|
+
|
52
|
+
def register_session(session: Session) -> Session: # pragma: no cover - test helper
|
53
|
+
existing = _sessions.get(session.name)
|
54
|
+
if existing:
|
55
|
+
return existing
|
56
|
+
_sessions[session.name] = session
|
57
|
+
return session
|
58
|
+
|
43
59
|
|
44
60
|
class _SymbolsProvider(DataProvider):
|
45
61
|
"""Server-local symbols provider backed by BackendRuntime.state."""
|
@@ -83,7 +99,7 @@ class _SymbolsProvider(DataProvider):
|
|
83
99
|
ctx: KernelContext,
|
84
100
|
) -> Any: # type: ignore[override]
|
85
101
|
name = self._symbol_name(uri)
|
86
|
-
sym =
|
102
|
+
sym = _global_symbols.get(name)
|
87
103
|
if sym is None:
|
88
104
|
raise ResourceNotFound(f"Global symbol '{name}' not found")
|
89
105
|
return sym.data
|
@@ -102,8 +118,13 @@ class _SymbolsProvider(DataProvider):
|
|
102
118
|
raise InvalidRequestError(
|
103
119
|
f"Failed to encode value for symbols:// write: {e!s}"
|
104
120
|
) from e
|
105
|
-
|
106
|
-
|
121
|
+
try:
|
122
|
+
obj = pickle.loads(base64.b64decode(data_b64))
|
123
|
+
except Exception as e: # pragma: no cover - defensive
|
124
|
+
raise InvalidRequestError(
|
125
|
+
f"Failed to decode value for symbols:// write: {e!s}"
|
126
|
+
) from e
|
127
|
+
_global_symbols[name] = Symbol(name=name, mptype={}, data=obj)
|
107
128
|
|
108
129
|
|
109
130
|
# Register symbols provider explicitly for server runtime
|
@@ -168,11 +189,7 @@ def validate_name(name: str, name_type: str) -> None:
|
|
168
189
|
# Request/Response Models
|
169
190
|
class CreateSessionRequest(BaseModel):
|
170
191
|
rank: int
|
171
|
-
|
172
|
-
# SPU related
|
173
|
-
spu_mask: int
|
174
|
-
spu_protocol: str
|
175
|
-
spu_field: str
|
192
|
+
cluster_spec: dict
|
176
193
|
|
177
194
|
|
178
195
|
class SessionResponse(BaseModel):
|
@@ -229,7 +246,7 @@ async def health_check() -> dict[str, str]:
|
|
229
246
|
@app.get("/sessions", response_model=SessionListResponse)
|
230
247
|
def list_sessions() -> SessionListResponse:
|
231
248
|
"""List all session names."""
|
232
|
-
return SessionListResponse(sessions=
|
249
|
+
return SessionListResponse(sessions=list(_sessions.keys()))
|
233
250
|
|
234
251
|
|
235
252
|
# List all computations in a session
|
@@ -238,39 +255,44 @@ def list_sessions() -> SessionListResponse:
|
|
238
255
|
)
|
239
256
|
def list_session_computations(session_name: str) -> ComputationListResponse:
|
240
257
|
"""List all computation names in a session."""
|
241
|
-
|
242
|
-
if not
|
258
|
+
sess = _sessions.get(session_name)
|
259
|
+
if not sess:
|
243
260
|
raise ResourceNotFound(f"Session '{session_name}' not found")
|
244
|
-
return ComputationListResponse(computations=
|
261
|
+
return ComputationListResponse(computations=sess.list_computations())
|
245
262
|
|
246
263
|
|
247
264
|
# Session endpoints
|
248
265
|
@app.put("/sessions/{session_name}", response_model=SessionResponse)
|
249
266
|
def create_session(session_name: str, request: CreateSessionRequest) -> SessionResponse:
|
250
267
|
validate_name(session_name, "session")
|
251
|
-
session
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
268
|
+
# Delegate cluster spec parsing & session construction to resource layer
|
269
|
+
from mplang.core.cluster import ClusterSpec # local import to avoid cycles
|
270
|
+
|
271
|
+
if session_name in _sessions:
|
272
|
+
sess = _sessions[session_name]
|
273
|
+
else:
|
274
|
+
spec = ClusterSpec.from_dict(request.cluster_spec)
|
275
|
+
if len(spec.get_devices_by_kind("SPU")) == 0:
|
276
|
+
raise InvalidRequestError("No SPU device found in cluster_spec for session")
|
277
|
+
sess = Session(name=session_name, rank=request.rank, cluster_spec=spec)
|
278
|
+
_sessions[session_name] = sess
|
279
|
+
return SessionResponse(name=sess.name)
|
260
280
|
|
261
281
|
|
262
282
|
@app.get("/sessions/{session_name}", response_model=SessionResponse)
|
263
283
|
def get_session(session_name: str) -> SessionResponse:
|
264
|
-
|
265
|
-
if not
|
284
|
+
sess = _sessions.get(session_name)
|
285
|
+
if not sess:
|
266
286
|
raise ResourceNotFound(f"Session '{session_name}' not found")
|
267
|
-
return SessionResponse(name=
|
287
|
+
return SessionResponse(name=sess.name)
|
268
288
|
|
269
289
|
|
270
290
|
@app.delete("/sessions/{session_name}")
|
271
291
|
def delete_session(session_name: str) -> dict[str, str]:
|
272
292
|
"""Delete a session and all its associated resources."""
|
273
|
-
if
|
293
|
+
if session_name in _sessions:
|
294
|
+
del _sessions[session_name]
|
295
|
+
logging.info(f"Session {session_name} deleted successfully")
|
274
296
|
return {"message": f"Session '{session_name}' deleted successfully"}
|
275
297
|
else:
|
276
298
|
raise ResourceNotFound(f"Session '{session_name}' not found")
|
@@ -299,18 +321,25 @@ def create_and_execute_computation(
|
|
299
321
|
raise InvalidRequestError("Failed to parse expression from protobuf")
|
300
322
|
|
301
323
|
# Create the computation resource
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
324
|
+
sess = _sessions.get(session_name)
|
325
|
+
if not sess:
|
326
|
+
raise ResourceNotFound(f"Session '{session_name}' not found.")
|
327
|
+
comp = sess.get_computation(computation_id)
|
328
|
+
if not comp:
|
329
|
+
comp = Computation(name=computation_id, expr=expr)
|
330
|
+
sess.add_computation(comp)
|
331
|
+
sess.execute(comp, request.input_names, request.output_names)
|
332
|
+
return ComputationResponse(name=computation_id)
|
308
333
|
|
309
334
|
|
310
335
|
@app.delete("/sessions/{session_name}/computations/{computation_id}")
|
311
336
|
def delete_computation(session_name: str, computation_id: str) -> dict[str, str]:
|
312
337
|
"""Delete a specific computation."""
|
313
|
-
|
338
|
+
sess = _sessions.get(session_name)
|
339
|
+
if sess and sess.delete_computation(computation_id):
|
340
|
+
logging.info(
|
341
|
+
f"Computation {computation_id} deleted from session {session_name}"
|
342
|
+
)
|
314
343
|
return {"message": f"Computation '{computation_id}' deleted successfully"}
|
315
344
|
else:
|
316
345
|
raise ResourceNotFound(
|
@@ -326,9 +355,15 @@ def create_session_symbol(
|
|
326
355
|
session_name: str, symbol_name: str, request: CreateSymbolRequest
|
327
356
|
) -> SymbolResponse:
|
328
357
|
"""Create a symbol in a session."""
|
329
|
-
|
330
|
-
|
331
|
-
|
358
|
+
sess = _sessions.get(session_name)
|
359
|
+
if not sess:
|
360
|
+
raise ResourceNotFound(f"Session '{session_name}' not found.")
|
361
|
+
try:
|
362
|
+
obj = pickle.loads(base64.b64decode(request.data))
|
363
|
+
except Exception as e:
|
364
|
+
raise InvalidRequestError(f"Invalid symbol data: {e!s}") from e
|
365
|
+
symbol = Symbol(name=symbol_name, mptype=request.mptype, data=obj)
|
366
|
+
sess.add_symbol(symbol)
|
332
367
|
# Return the base64 data back to client; server stores Python object
|
333
368
|
return SymbolResponse(
|
334
369
|
name=symbol.name,
|
@@ -346,8 +381,8 @@ def get_session_symbol(session_name: str, symbol_name: str) -> SymbolResponse:
|
|
346
381
|
logger.debug(
|
347
382
|
f"Looking for symbol: '{symbol_name}' in session: '{session_name}'"
|
348
383
|
)
|
349
|
-
|
350
|
-
symbol =
|
384
|
+
sess = _sessions.get(session_name)
|
385
|
+
symbol = sess.get_symbol(symbol_name) if sess else None
|
351
386
|
if not symbol:
|
352
387
|
raise HTTPException(
|
353
388
|
status_code=404, detail=f"Symbol {symbol_name} not found"
|
@@ -368,14 +403,19 @@ def get_session_symbol(session_name: str, symbol_name: str) -> SymbolResponse:
|
|
368
403
|
@app.get("/sessions/{session_name}/symbols")
|
369
404
|
def list_session_symbols(session_name: str) -> dict[str, list[str]]:
|
370
405
|
"""List all symbols in a session."""
|
371
|
-
|
406
|
+
sess = _sessions.get(session_name)
|
407
|
+
if not sess:
|
408
|
+
raise ResourceNotFound(f"Session '{session_name}' not found.")
|
409
|
+
symbols = sess.list_symbols()
|
372
410
|
return {"symbols": symbols}
|
373
411
|
|
374
412
|
|
375
413
|
@app.delete("/sessions/{session_name}/symbols/{symbol_name}")
|
376
414
|
def delete_symbol(session_name: str, symbol_name: str) -> dict[str, str]:
|
377
415
|
"""Delete a specific symbol."""
|
378
|
-
|
416
|
+
sess = _sessions.get(session_name)
|
417
|
+
if sess and sess.delete_symbol(symbol_name):
|
418
|
+
logging.info(f"Symbol {symbol_name} deleted from session {session_name}")
|
379
419
|
return {"message": f"Symbol '{symbol_name}' deleted successfully"}
|
380
420
|
else:
|
381
421
|
raise ResourceNotFound(
|
@@ -389,13 +429,18 @@ def create_global_symbol(
|
|
389
429
|
symbol_name: str, request: CreateSymbolRequest
|
390
430
|
) -> GlobalSymbolResponse:
|
391
431
|
validate_name(symbol_name, "symbol")
|
392
|
-
|
432
|
+
try:
|
433
|
+
obj = pickle.loads(base64.b64decode(request.data))
|
434
|
+
except Exception as e:
|
435
|
+
raise InvalidRequestError(f"Invalid global symbol data: {e!s}") from e
|
436
|
+
sym = Symbol(name=symbol_name, mptype=request.mptype, data=obj)
|
437
|
+
_global_symbols[symbol_name] = sym
|
393
438
|
return GlobalSymbolResponse(name=sym.name, mptype=sym.mptype, data=request.data)
|
394
439
|
|
395
440
|
|
396
441
|
@app.get("/api/v1/symbols/{symbol_name}", response_model=GlobalSymbolResponse)
|
397
|
-
def get_global_symbol(symbol_name: str) -> GlobalSymbolResponse:
|
398
|
-
sym =
|
442
|
+
def get_global_symbol(symbol_name: str) -> GlobalSymbolResponse: # route handler
|
443
|
+
sym = _global_symbols.get(symbol_name)
|
399
444
|
if not sym:
|
400
445
|
raise ResourceNotFound(f"Global symbol '{symbol_name}' not found")
|
401
446
|
data_bytes = pickle.dumps(sym.data)
|
@@ -405,12 +450,13 @@ def get_global_symbol(symbol_name: str) -> GlobalSymbolResponse:
|
|
405
450
|
|
406
451
|
@app.get("/api/v1/symbols")
|
407
452
|
def list_global_symbols() -> dict[str, list[str]]:
|
408
|
-
return {"symbols":
|
453
|
+
return {"symbols": list(_global_symbols.keys())}
|
409
454
|
|
410
455
|
|
411
456
|
@app.delete("/api/v1/symbols/{symbol_name}")
|
412
|
-
def delete_global_symbol(symbol_name: str) -> dict[str, str]:
|
413
|
-
if
|
457
|
+
def delete_global_symbol(symbol_name: str) -> dict[str, str]: # route handler
|
458
|
+
if symbol_name in _global_symbols:
|
459
|
+
del _global_symbols[symbol_name]
|
414
460
|
return {"message": f"Global symbol '{symbol_name}' deleted successfully"}
|
415
461
|
else:
|
416
462
|
raise ResourceNotFound(f"Global symbol '{symbol_name}' not found")
|
@@ -426,8 +472,8 @@ def comm_send(
|
|
426
472
|
Receive a message from another party and deliver it to the session's communicator.
|
427
473
|
This endpoint runs on the receiver's server.
|
428
474
|
"""
|
429
|
-
|
430
|
-
if not
|
475
|
+
sess = _sessions.get(session_name)
|
476
|
+
if not sess or not sess.communicator:
|
431
477
|
logger.error(f"Session or communicator not found: session={session_name}")
|
432
478
|
raise HTTPException(status_code=404, detail="Session or communicator not found")
|
433
479
|
|
@@ -435,5 +481,5 @@ def comm_send(
|
|
435
481
|
# We don't need to validate to_rank since the request is coming to this server
|
436
482
|
|
437
483
|
# Use the proper onSent mechanism from CommunicatorBase
|
438
|
-
|
484
|
+
sess.communicator.onSent(from_rank, key, request.data)
|
439
485
|
return {"status": "ok"}
|
@@ -0,0 +1,285 @@
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""Core Session model (pure, no global registries).
|
16
|
+
|
17
|
+
Contents:
|
18
|
+
* SessionState dataclass
|
19
|
+
* LinkCommFactory (SPU link reuse cache)
|
20
|
+
* Session (topology derivation, runtime init, SPU env seeding, local symbol/computation storage)
|
21
|
+
|
22
|
+
Process-wide registries (sessions, global symbols) live in the server layer
|
23
|
+
(`server.py`) so this module remains portable and easy to unit test.
|
24
|
+
"""
|
25
|
+
|
26
|
+
from __future__ import annotations
|
27
|
+
|
28
|
+
import logging
|
29
|
+
import time
|
30
|
+
from dataclasses import dataclass, field
|
31
|
+
from functools import cached_property
|
32
|
+
from typing import TYPE_CHECKING, Any, cast
|
33
|
+
from urllib.parse import urlparse
|
34
|
+
|
35
|
+
import spu.libspu as libspu
|
36
|
+
|
37
|
+
from mplang.core.expr.ast import Expr
|
38
|
+
from mplang.core.expr.evaluator import IEvaluator, create_evaluator
|
39
|
+
from mplang.core.mask import Mask
|
40
|
+
from mplang.kernels.context import RuntimeContext
|
41
|
+
from mplang.kernels.spu import PFunction # type: ignore
|
42
|
+
from mplang.runtime.communicator import HttpCommunicator
|
43
|
+
from mplang.runtime.exceptions import ResourceNotFound
|
44
|
+
from mplang.runtime.link_comm import LinkCommunicator
|
45
|
+
from mplang.utils.spu_utils import parse_field, parse_protocol
|
46
|
+
|
47
|
+
if TYPE_CHECKING: # pragma: no cover - import only for type checking
|
48
|
+
from mplang.core.cluster import ClusterSpec, Node, RuntimeInfo
|
49
|
+
|
50
|
+
|
51
|
+
class LinkCommFactory:
|
52
|
+
"""Factory for creating and caching link communicators."""
|
53
|
+
|
54
|
+
def __init__(self) -> None:
|
55
|
+
self._cache: dict[tuple[int, tuple[str, ...]], LinkCommunicator] = {}
|
56
|
+
|
57
|
+
def create_link(self, rel_rank: int, addrs: list[str]) -> LinkCommunicator:
|
58
|
+
key = (rel_rank, tuple(addrs))
|
59
|
+
link = self._cache.get(key)
|
60
|
+
if link is not None:
|
61
|
+
return link
|
62
|
+
logging.info(f"LinkCommunicator created: rel_rank={rel_rank} addrs={addrs}")
|
63
|
+
link = LinkCommunicator(rel_rank, addrs)
|
64
|
+
self._cache[key] = link
|
65
|
+
return link
|
66
|
+
|
67
|
+
|
68
|
+
# Shared link factory (module-local, not global registry of sessions)
|
69
|
+
g_link_factory = LinkCommFactory()
|
70
|
+
|
71
|
+
|
72
|
+
@dataclass
|
73
|
+
class Symbol:
|
74
|
+
name: str
|
75
|
+
mptype: Any
|
76
|
+
data: Any
|
77
|
+
|
78
|
+
|
79
|
+
@dataclass
|
80
|
+
class Computation:
|
81
|
+
name: str
|
82
|
+
expr: Expr
|
83
|
+
|
84
|
+
|
85
|
+
@dataclass
|
86
|
+
class SessionState:
|
87
|
+
runtime: RuntimeContext | None = None
|
88
|
+
computations: dict[str, Computation] = field(default_factory=dict)
|
89
|
+
symbols: dict[str, Symbol] = field(default_factory=dict)
|
90
|
+
spu_seeded: bool = False
|
91
|
+
created_ts: float = field(default_factory=time.time)
|
92
|
+
last_access_ts: float = field(default_factory=time.time)
|
93
|
+
|
94
|
+
|
95
|
+
class Session:
|
96
|
+
"""Represents the per-rank execution context.
|
97
|
+
|
98
|
+
Immutable config: name, rank, cluster_spec.
|
99
|
+
Derived: node, runtime_info, endpoints, spu_device, spu_mask, protocol/field, is_spu_party.
|
100
|
+
Mutable: state (runtime object, symbols, computations, seeded flag).
|
101
|
+
"""
|
102
|
+
|
103
|
+
def __init__(self, name: str, rank: int, cluster_spec: ClusterSpec):
|
104
|
+
self.name = name
|
105
|
+
self.rank = rank
|
106
|
+
self.cluster_spec = cluster_spec
|
107
|
+
self.state = SessionState()
|
108
|
+
self.communicator = HttpCommunicator(
|
109
|
+
session_name=name, rank=rank, endpoints=self.endpoints
|
110
|
+
)
|
111
|
+
|
112
|
+
# --- Derived topology ---
|
113
|
+
@cached_property
|
114
|
+
def node(self) -> Node:
|
115
|
+
return self.cluster_spec.get_node_by_rank(self.rank)
|
116
|
+
|
117
|
+
@property
|
118
|
+
def runtime_info(self) -> RuntimeInfo:
|
119
|
+
return self.node.runtime_info
|
120
|
+
|
121
|
+
@cached_property
|
122
|
+
def endpoints(self) -> list[str]:
|
123
|
+
eps: list[str] = []
|
124
|
+
for n in sorted(
|
125
|
+
self.cluster_spec.nodes.values(),
|
126
|
+
key=lambda x: x.rank, # type: ignore[attr-defined]
|
127
|
+
):
|
128
|
+
ep = n.endpoint
|
129
|
+
if not ep.startswith(("http://", "https://")):
|
130
|
+
ep = f"http://{ep}"
|
131
|
+
eps.append(ep)
|
132
|
+
return eps
|
133
|
+
|
134
|
+
@cached_property
|
135
|
+
def spu_device(self): # type: ignore
|
136
|
+
devs = self.cluster_spec.get_devices_by_kind("SPU")
|
137
|
+
if len(devs) != 1:
|
138
|
+
raise RuntimeError(
|
139
|
+
f"Expected exactly one SPU device, got {len(devs)} (session={self.name})"
|
140
|
+
)
|
141
|
+
return devs[0]
|
142
|
+
|
143
|
+
@cached_property
|
144
|
+
def spu_mask(self) -> Mask:
|
145
|
+
return Mask.from_ranks([m.rank for m in self.spu_device.members])
|
146
|
+
|
147
|
+
@property
|
148
|
+
def spu_protocol(self) -> str:
|
149
|
+
return cast(str, self.spu_device.config.get("protocol", "SEMI2K"))
|
150
|
+
|
151
|
+
@property
|
152
|
+
def spu_field(self) -> str:
|
153
|
+
return cast(str, self.spu_device.config.get("field", "FM64"))
|
154
|
+
|
155
|
+
@property
|
156
|
+
def is_spu_party(self) -> bool:
|
157
|
+
return self.rank in self.spu_mask
|
158
|
+
|
159
|
+
# --- Runtime helpers ---
|
160
|
+
def ensure_runtime(self) -> RuntimeContext:
|
161
|
+
if self.state.runtime is None:
|
162
|
+
self.state.runtime = RuntimeContext(
|
163
|
+
rank=self.rank,
|
164
|
+
world_size=len(self.cluster_spec.nodes), # type: ignore[attr-defined]
|
165
|
+
initial_bindings=(
|
166
|
+
self.runtime_info.op_bindings if self.runtime_info else {}
|
167
|
+
),
|
168
|
+
)
|
169
|
+
return self.state.runtime
|
170
|
+
|
171
|
+
def ensure_spu_env(self) -> None:
|
172
|
+
"""Ensure SPU kernel env (config/world[/link]) registered on this runtime.
|
173
|
+
|
174
|
+
Previous logic only seeded SPU parties; non-participating ranks then raised
|
175
|
+
a hard error when the evaluator encountered SPU ops in the global program,
|
176
|
+
because the kernel pocket lacked config/world. For now we register the
|
177
|
+
config/world on ALL parties (idempotent) and only attach a link context for
|
178
|
+
participating SPU ranks. Non-parties will still error later if they try to
|
179
|
+
execute a link-dependent SPU kernel (which should be guarded by masks in the
|
180
|
+
IR), but they will no longer fail early with a misleading
|
181
|
+
"SPU kernel state not initialized" message.
|
182
|
+
"""
|
183
|
+
if self.state.spu_seeded:
|
184
|
+
return
|
185
|
+
|
186
|
+
link_ctx = None
|
187
|
+
# Fixed port offset for SPU runtime link services (legacy value retained).
|
188
|
+
# TODO: make configurable if future deployments require dynamic offset.
|
189
|
+
SPU_PORT_OFFSET = 100
|
190
|
+
|
191
|
+
if self.is_spu_party:
|
192
|
+
# Build SPU address list across all endpoints for ranks in mask
|
193
|
+
spu_addrs: list[str] = []
|
194
|
+
for r, addr in enumerate(self.communicator.endpoints):
|
195
|
+
if r in self.spu_mask:
|
196
|
+
if "//" not in addr:
|
197
|
+
addr = f"//{addr}"
|
198
|
+
parsed = urlparse(addr)
|
199
|
+
assert isinstance(parsed.port, int)
|
200
|
+
new_addr = f"{parsed.hostname}:{parsed.port + SPU_PORT_OFFSET}"
|
201
|
+
spu_addrs.append(new_addr)
|
202
|
+
rel_index = sum(1 for r in range(self.rank) if r in self.spu_mask)
|
203
|
+
link_ctx = g_link_factory.create_link(rel_index, spu_addrs)
|
204
|
+
|
205
|
+
spu_config = libspu.RuntimeConfig(
|
206
|
+
protocol=parse_protocol(self.spu_protocol),
|
207
|
+
field=parse_field(self.spu_field),
|
208
|
+
fxp_fraction_bits=18,
|
209
|
+
)
|
210
|
+
seed_pfunc = PFunction(
|
211
|
+
fn_type="spu.seed_env",
|
212
|
+
ins_info=(),
|
213
|
+
outs_info=(),
|
214
|
+
config=spu_config,
|
215
|
+
world=self.spu_mask.num_parties(),
|
216
|
+
link=link_ctx,
|
217
|
+
)
|
218
|
+
self.ensure_runtime().run_kernel(seed_pfunc, [])
|
219
|
+
self.state.spu_seeded = True
|
220
|
+
|
221
|
+
# --- Computations & Symbols (instance-local) ---
|
222
|
+
def add_computation(self, computation: Computation) -> None:
|
223
|
+
self.state.computations[computation.name] = computation
|
224
|
+
|
225
|
+
def get_computation(self, name: str) -> Computation | None:
|
226
|
+
return self.state.computations.get(name)
|
227
|
+
|
228
|
+
def add_symbol(self, symbol: Symbol) -> None:
|
229
|
+
self.state.symbols[symbol.name] = symbol
|
230
|
+
|
231
|
+
def get_symbol(self, name: str) -> Symbol | None:
|
232
|
+
return self.state.symbols.get(name)
|
233
|
+
|
234
|
+
def list_symbols(self) -> list[str]: # pragma: no cover - trivial
|
235
|
+
return list(self.state.symbols.keys())
|
236
|
+
|
237
|
+
def delete_symbol(self, name: str) -> bool:
|
238
|
+
if name in self.state.symbols:
|
239
|
+
del self.state.symbols[name]
|
240
|
+
return True
|
241
|
+
return False
|
242
|
+
|
243
|
+
def list_computations(self) -> list[str]: # pragma: no cover - trivial
|
244
|
+
return list(self.state.computations.keys())
|
245
|
+
|
246
|
+
def delete_computation(self, name: str) -> bool:
|
247
|
+
if name in self.state.computations:
|
248
|
+
del self.state.computations[name]
|
249
|
+
return True
|
250
|
+
return False
|
251
|
+
|
252
|
+
# --- Execution ---
|
253
|
+
def execute(
|
254
|
+
self, computation: Computation, input_names: list[str], output_names: list[str]
|
255
|
+
) -> None:
|
256
|
+
env: dict[str, Any] = {}
|
257
|
+
for in_name in input_names:
|
258
|
+
sym = self.get_symbol(in_name)
|
259
|
+
if sym is None:
|
260
|
+
raise ResourceNotFound(
|
261
|
+
f"Input symbol '{in_name}' not found in session '{self.name}'"
|
262
|
+
)
|
263
|
+
env[in_name] = sym.data
|
264
|
+
rt = self.ensure_runtime()
|
265
|
+
self.ensure_spu_env()
|
266
|
+
evaluator: IEvaluator = create_evaluator(
|
267
|
+
rank=self.rank, env=env, comm=self.communicator, runtime=rt
|
268
|
+
)
|
269
|
+
results = evaluator.evaluate(computation.expr)
|
270
|
+
if results and len(results) != len(output_names):
|
271
|
+
raise RuntimeError(
|
272
|
+
f"Expected {len(output_names)} results, got {len(results)}"
|
273
|
+
)
|
274
|
+
for name, val in zip(output_names, results, strict=True):
|
275
|
+
self.add_symbol(Symbol(name=name, mptype={}, data=val))
|
276
|
+
|
277
|
+
# --- Convenience constructor ---
|
278
|
+
@classmethod
|
279
|
+
def from_cluster_spec_dict(cls, name: str, rank: int, spec_dict: dict) -> Session:
|
280
|
+
from mplang.core.cluster import ClusterSpec # local import to avoid cycles
|
281
|
+
|
282
|
+
spec = ClusterSpec.from_dict(spec_dict)
|
283
|
+
if len(spec.get_devices_by_kind("SPU")) == 0:
|
284
|
+
raise RuntimeError("No SPU device found in cluster_spec")
|
285
|
+
return cls(name=name, rank=rank, cluster_spec=spec)
|
mplang/runtime/simulation.py
CHANGED
@@ -86,20 +86,17 @@ class Simulator(InterpContext):
|
|
86
86
|
cluster_spec: ClusterSpec,
|
87
87
|
*,
|
88
88
|
trace_ranks: list[int] | None = None,
|
89
|
-
op_bindings: dict[str, str] | None = None,
|
90
89
|
) -> None:
|
91
90
|
"""Initialize a simulator with the given cluster specification.
|
92
91
|
|
93
92
|
Args:
|
94
93
|
cluster_spec: The cluster specification defining the simulation environment.
|
95
94
|
trace_ranks: List of ranks to trace execution for debugging.
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
variable ``bindings`` dict passed into ``evaluate``.
|
95
|
+
Per-node op binding overrides should now be provided via
|
96
|
+
each node's `runtime_info.op_bindings` in the supplied
|
97
|
+
`cluster_spec`.
|
100
98
|
"""
|
101
99
|
super().__init__(cluster_spec)
|
102
|
-
self._op_bindings_template = op_bindings or {}
|
103
100
|
self._trace_ranks = trace_ranks or []
|
104
101
|
|
105
102
|
spu_devices = cluster_spec.get_devices_by_kind("SPU")
|
@@ -145,21 +142,22 @@ class Simulator(InterpContext):
|
|
145
142
|
|
146
143
|
# Persistent per-rank RuntimeContext instances (reused across evaluates).
|
147
144
|
# We no longer pre-create evaluators since each evaluate has different env bindings.
|
148
|
-
|
149
|
-
|
145
|
+
# Build per-rank runtime contexts.
|
146
|
+
self._runtimes: list[RuntimeContext] = []
|
147
|
+
for rank in range(self.world_size()):
|
148
|
+
node = self.cluster_spec.get_node_by_rank(rank)
|
149
|
+
rt = RuntimeContext(
|
150
150
|
rank=rank,
|
151
151
|
world_size=self.world_size(),
|
152
|
-
|
153
|
-
# dispatch mappings, not per-evaluate variable bindings.
|
154
|
-
initial_bindings=self._op_bindings_template,
|
152
|
+
initial_bindings=node.runtime_info.op_bindings,
|
155
153
|
)
|
156
|
-
|
157
|
-
]
|
154
|
+
self._runtimes.append(rt)
|
158
155
|
|
159
156
|
@classmethod
|
160
157
|
def simple(
|
161
158
|
cls,
|
162
159
|
world_size: int,
|
160
|
+
op_bindings: dict[str, str] | None = None,
|
163
161
|
**kwargs: Any,
|
164
162
|
) -> Simulator:
|
165
163
|
"""Create a simple simulator with the given number of parties.
|
@@ -175,6 +173,10 @@ class Simulator(InterpContext):
|
|
175
173
|
A Simulator instance with a simple cluster configuration.
|
176
174
|
"""
|
177
175
|
cluster_spec = ClusterSpec.simple(world_size)
|
176
|
+
if op_bindings:
|
177
|
+
# Apply the same op_bindings to every node's runtime_info for convenience
|
178
|
+
for node in cluster_spec.nodes.values():
|
179
|
+
node.runtime_info.op_bindings.update(op_bindings)
|
178
180
|
return cls(cluster_spec, **kwargs)
|
179
181
|
|
180
182
|
def _do_evaluate(self, expr: Expr, evaluator_engine: IEvaluator) -> Any:
|
@@ -4,7 +4,7 @@ mplang/device.py,sha256=RmjnhzHxJkkNmtBKtYMEbpQYBZpuC43qlllkCOp-QD8,12548
|
|
4
4
|
mplang/analysis/__init__.py,sha256=CTHFvRsi-nFngojqjn08UaR3RY9i7CJ7T2UdR95kCrk,1056
|
5
5
|
mplang/analysis/diagram.py,sha256=ffwgD12gL1_KH1uJ_EYkjmIlDrfxYJJkWj-wHl09_Xk,19520
|
6
6
|
mplang/core/__init__.py,sha256=lWxlEKfRwX7FNDzgyKZ1fiDMaCiqkyg0j5mKlZD_v7g,2244
|
7
|
-
mplang/core/cluster.py,sha256=
|
7
|
+
mplang/core/cluster.py,sha256=IqXHLogetegUEEAzmD8cWRash-UID06Wo3OBeZFwatg,11800
|
8
8
|
mplang/core/comm.py,sha256=MByyu3etlQh_TkP1vKCFLIAPPuJOpl9Kjs6hOj6m4Yc,8843
|
9
9
|
mplang/core/context_mgr.py,sha256=R0QJAod-1nYduVoOknLfAsxZiy-RtmuQcp-07HABYZU,1541
|
10
10
|
mplang/core/dtype.py,sha256=0rZqFaFikFu9RxtdO36JLEgFL-E-lo3hH10whwkTVVY,10213
|
@@ -51,16 +51,16 @@ mplang/protos/v1alpha1/mpir_pb2.pyi,sha256=GwXR4wPB_kB_36iYS9x-cGI9KDKFMq89KhdLh
|
|
51
51
|
mplang/protos/v1alpha1/mpir_pb2_grpc.py,sha256=xYOs94SXiNYAlFodACnsXW5QovLsHY5tCk3p76RH5Zc,158
|
52
52
|
mplang/runtime/__init__.py,sha256=IRPP3TtpFC4iSt7_uaq-S4dL7CwrXL0XBMeaBoEYLlg,948
|
53
53
|
mplang/runtime/cli.py,sha256=WehDodeVB4AukSWx1LJxxtKUqGmLPY4qjayrPlOg3bE,14438
|
54
|
-
mplang/runtime/client.py,sha256=
|
54
|
+
mplang/runtime/client.py,sha256=vkJUFSDcKIdbKiGUM5AosCKTZygl9g8uZFEjw2xwKig,15249
|
55
55
|
mplang/runtime/communicator.py,sha256=Lek6_h_Wmr_W-_JpT-vMxL3CHxcVZdtf7jdaLGuxPgQ,3199
|
56
56
|
mplang/runtime/data_providers.py,sha256=hH2butEOYNGq2rRZjVBDfXLxe3YUin2ftAF6htbTfLA,8226
|
57
|
-
mplang/runtime/driver.py,sha256=
|
57
|
+
mplang/runtime/driver.py,sha256=pq2EQFZK9tH90Idops_yeF6fj0cfFVD_5mFcmy4Hzco,11089
|
58
58
|
mplang/runtime/exceptions.py,sha256=c18U0xK20dRmgZo0ogTf5vXlkix9y3VAFuzkHxaXPEk,981
|
59
59
|
mplang/runtime/http_api.md,sha256=-re1DhEqMplAkv_wnqEU-PSs8tTzf4-Ml0Gq0f3Go6s,4883
|
60
60
|
mplang/runtime/link_comm.py,sha256=uNqTCGZVwWeuHAb7yXXQf0DUsMXLa8leHCkrcZdzYMU,4559
|
61
|
-
mplang/runtime/
|
62
|
-
mplang/runtime/
|
63
|
-
mplang/runtime/simulation.py,sha256=
|
61
|
+
mplang/runtime/server.py,sha256=vYjuWTWhhSLHUpsO8FDnOQ8kFzPhE-fXDDyL8GHVPj4,16673
|
62
|
+
mplang/runtime/session.py,sha256=4TQ_RPRmriv0H0S6rl_GSabxS7XrwMkdZIdcnyE8bHw,10374
|
63
|
+
mplang/runtime/simulation.py,sha256=WyIs8ta3ZM5o3RB0Bcb0MUu6Yh88Iujr27KvZFqGxig,11497
|
64
64
|
mplang/simp/__init__.py,sha256=xNXnA8-jZAANa2A1W39b3lYO7D02zdCXl0TpivkTGS4,11579
|
65
65
|
mplang/simp/mpi.py,sha256=Wv_Q16TQ3rdLam6OzqXiefIGSMmagGkso09ycyOkHEs,4774
|
66
66
|
mplang/simp/random.py,sha256=7PVgWNL1j7Sf3MqT5PRiWplUu-0dyhF3Ub566iqX86M,3898
|
@@ -70,8 +70,8 @@ mplang/utils/crypto.py,sha256=rvPomBFtznRHc3RPi6Aip9lsU8zW2oxBqGv1K3vn7Rs,1052
|
|
70
70
|
mplang/utils/func_utils.py,sha256=vCJcZmu0bEbqhOQKdpttV2_MBllIcPSN0b8U4WjNGGo,5164
|
71
71
|
mplang/utils/spu_utils.py,sha256=S3L9RBkBe2AvSuMSQQ12cBY5Y1NPthubvErSX_7nj1A,4158
|
72
72
|
mplang/utils/table_utils.py,sha256=aC-IZOKkSmFkpr3NZchLM0Wt0GOn-rg_xHBHREWBwAU,2202
|
73
|
-
mplang_nightly-0.1.
|
74
|
-
mplang_nightly-0.1.
|
75
|
-
mplang_nightly-0.1.
|
76
|
-
mplang_nightly-0.1.
|
77
|
-
mplang_nightly-0.1.
|
73
|
+
mplang_nightly-0.1.dev153.dist-info/METADATA,sha256=4dEwwbuB0n0oRHxO09vuMY2Al57Ol8O8KdXGlDpEZqo,16547
|
74
|
+
mplang_nightly-0.1.dev153.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
75
|
+
mplang_nightly-0.1.dev153.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
|
76
|
+
mplang_nightly-0.1.dev153.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
77
|
+
mplang_nightly-0.1.dev153.dist-info/RECORD,,
|
mplang/runtime/resource.py
DELETED
@@ -1,365 +0,0 @@
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
|
15
|
-
"""
|
16
|
-
This module provides the resource management for the HTTP backend.
|
17
|
-
It is a simplified, in-memory version of the original executor's resource manager.
|
18
|
-
"""
|
19
|
-
|
20
|
-
import base64
|
21
|
-
import logging
|
22
|
-
from dataclasses import dataclass, field
|
23
|
-
from typing import Any
|
24
|
-
from urllib.parse import urlparse
|
25
|
-
|
26
|
-
import cloudpickle as pickle
|
27
|
-
import spu.libspu as libspu
|
28
|
-
|
29
|
-
from mplang.core.expr.ast import Expr
|
30
|
-
from mplang.core.expr.evaluator import IEvaluator, create_evaluator
|
31
|
-
from mplang.core.mask import Mask
|
32
|
-
from mplang.kernels.context import RuntimeContext
|
33
|
-
from mplang.kernels.spu import PFunction # type: ignore
|
34
|
-
from mplang.runtime.communicator import HttpCommunicator
|
35
|
-
from mplang.runtime.exceptions import InvalidRequestError, ResourceNotFound
|
36
|
-
from mplang.runtime.link_comm import LinkCommunicator
|
37
|
-
from mplang.utils.spu_utils import parse_field, parse_protocol
|
38
|
-
|
39
|
-
|
40
|
-
class LinkCommFactory:
|
41
|
-
"""Factory for creating and caching link communicators."""
|
42
|
-
|
43
|
-
def __init__(self) -> None:
|
44
|
-
self._cache: dict[tuple[int, tuple[str, ...]], LinkCommunicator] = {}
|
45
|
-
|
46
|
-
def create_link(self, rank: int, addrs: list[str]) -> LinkCommunicator:
|
47
|
-
key = (rank, tuple(addrs))
|
48
|
-
val = self._cache.get(key, None)
|
49
|
-
if val is not None:
|
50
|
-
return val
|
51
|
-
|
52
|
-
logging.info(f"LinkCommunicator created: {rank} {addrs}")
|
53
|
-
new_link = LinkCommunicator(rank, addrs)
|
54
|
-
self._cache[key] = new_link
|
55
|
-
return new_link
|
56
|
-
|
57
|
-
|
58
|
-
# Global link factory instance
|
59
|
-
g_link_factory = LinkCommFactory()
|
60
|
-
|
61
|
-
|
62
|
-
@dataclass
|
63
|
-
class Symbol:
|
64
|
-
name: str
|
65
|
-
mptype: Any # More flexible type to handle dict or MPType
|
66
|
-
data: Any # More flexible data type
|
67
|
-
|
68
|
-
|
69
|
-
@dataclass
|
70
|
-
class Computation:
|
71
|
-
name: str
|
72
|
-
expr: Expr # The computation expression
|
73
|
-
|
74
|
-
|
75
|
-
@dataclass
|
76
|
-
class Session:
|
77
|
-
name: str
|
78
|
-
communicator: HttpCommunicator
|
79
|
-
computations: dict[str, Computation] = field(default_factory=dict)
|
80
|
-
symbols: dict[str, Symbol] = field(default_factory=dict) # Session-level symbols
|
81
|
-
|
82
|
-
# spu related
|
83
|
-
spu_mask: int = -1
|
84
|
-
spu_protocol: str = "SEMI2K"
|
85
|
-
spu_field: str = "FM64"
|
86
|
-
|
87
|
-
|
88
|
-
# Global session storage
|
89
|
-
_sessions: dict[str, Session] = {}
|
90
|
-
|
91
|
-
|
92
|
-
# Session Management
|
93
|
-
def create_session(
|
94
|
-
name: str,
|
95
|
-
rank: int,
|
96
|
-
endpoints: list[str],
|
97
|
-
# SPU related
|
98
|
-
spu_mask: int = 0,
|
99
|
-
spu_protocol: str = "SEMI2K",
|
100
|
-
spu_field: str = "FM64",
|
101
|
-
) -> Session:
|
102
|
-
logging.info(f"Creating session: {name}, rank: {rank}, spu_mask: {spu_mask}")
|
103
|
-
if name in _sessions:
|
104
|
-
# Return existing session (idempotent operation)
|
105
|
-
logging.info(f"Session {name} already exists, returning existing session")
|
106
|
-
return _sessions[name]
|
107
|
-
session = Session(
|
108
|
-
name, HttpCommunicator(session_name=name, rank=rank, endpoints=endpoints)
|
109
|
-
)
|
110
|
-
|
111
|
-
session.spu_mask = spu_mask
|
112
|
-
session.spu_protocol = spu_protocol
|
113
|
-
session.spu_field = spu_field
|
114
|
-
|
115
|
-
_sessions[name] = session
|
116
|
-
logging.info(f"Session {name} created successfully")
|
117
|
-
return session
|
118
|
-
|
119
|
-
|
120
|
-
def get_session(name: str) -> Session | None:
|
121
|
-
return _sessions.get(name)
|
122
|
-
|
123
|
-
|
124
|
-
def delete_session(name: str) -> bool:
|
125
|
-
"""Delete a session and all associated resources.
|
126
|
-
|
127
|
-
Returns:
|
128
|
-
True if session was deleted, False if session was not found.
|
129
|
-
"""
|
130
|
-
if name in _sessions:
|
131
|
-
del _sessions[name]
|
132
|
-
logging.info(f"Session {name} deleted successfully")
|
133
|
-
return True
|
134
|
-
return False
|
135
|
-
|
136
|
-
|
137
|
-
# Global symbol management (process-wide, not per-session)
|
138
|
-
_global_symbols: dict[str, Symbol] = {}
|
139
|
-
|
140
|
-
|
141
|
-
def create_global_symbol(name: str, mptype: dict[str, Any], data_b64: str) -> Symbol:
|
142
|
-
"""Create or replace a global symbol.
|
143
|
-
|
144
|
-
Args:
|
145
|
-
name: Symbol identifier
|
146
|
-
mptype: Metadata dict (shape/dtype, etc.)
|
147
|
-
data_b64: Base64-encoded pickled data
|
148
|
-
"""
|
149
|
-
try:
|
150
|
-
raw = base64.b64decode(data_b64)
|
151
|
-
data = pickle.loads(raw)
|
152
|
-
except Exception as e: # pragma: no cover - defensive
|
153
|
-
raise InvalidRequestError(f"Failed to decode symbol payload: {e}") from e
|
154
|
-
sym = Symbol(name=name, mptype=mptype, data=data)
|
155
|
-
_global_symbols[name] = sym
|
156
|
-
return sym
|
157
|
-
|
158
|
-
|
159
|
-
def get_global_symbol(name: str) -> Symbol:
|
160
|
-
sym = _global_symbols.get(name)
|
161
|
-
if sym is None:
|
162
|
-
raise ResourceNotFound(f"Global symbol '{name}' not found")
|
163
|
-
return sym
|
164
|
-
|
165
|
-
|
166
|
-
def delete_global_symbol(name: str) -> bool:
|
167
|
-
return _global_symbols.pop(name, None) is not None
|
168
|
-
|
169
|
-
|
170
|
-
def list_global_symbols() -> list[str]: # pragma: no cover - trivial
|
171
|
-
return sorted(_global_symbols.keys())
|
172
|
-
|
173
|
-
|
174
|
-
# Computation Management
|
175
|
-
def create_computation(
|
176
|
-
session_name: str, computation_name: str, expr: Expr
|
177
|
-
) -> Computation:
|
178
|
-
"""Creates a computation resource within a session."""
|
179
|
-
session = get_session(session_name)
|
180
|
-
if not session:
|
181
|
-
raise ResourceNotFound(f"Session '{session_name}' not found.")
|
182
|
-
computation = Computation(computation_name, expr)
|
183
|
-
session.computations[computation_name] = computation
|
184
|
-
logging.info(f"Computation {computation_name} created for session {session_name}")
|
185
|
-
return computation
|
186
|
-
|
187
|
-
|
188
|
-
def get_computation(session_name: str, comp_name: str) -> Computation | None:
|
189
|
-
session = get_session(session_name)
|
190
|
-
if session:
|
191
|
-
return session.computations.get(comp_name)
|
192
|
-
return None
|
193
|
-
|
194
|
-
|
195
|
-
def delete_computation(session_name: str, comp_name: str) -> bool:
|
196
|
-
"""Delete a computation from a session.
|
197
|
-
|
198
|
-
Returns:
|
199
|
-
True if computation was deleted, False if not found.
|
200
|
-
"""
|
201
|
-
session = get_session(session_name)
|
202
|
-
if not session:
|
203
|
-
return False
|
204
|
-
|
205
|
-
if comp_name in session.computations:
|
206
|
-
del session.computations[comp_name]
|
207
|
-
logging.info(f"Computation {comp_name} deleted from session {session_name}")
|
208
|
-
return True
|
209
|
-
return False
|
210
|
-
|
211
|
-
|
212
|
-
def execute_computation(
|
213
|
-
session_name: str, comp_name: str, input_names: list[str], output_names: list[str]
|
214
|
-
) -> None:
|
215
|
-
"""Execute a computation using the Evaluator."""
|
216
|
-
session = get_session(session_name)
|
217
|
-
if not session:
|
218
|
-
raise ResourceNotFound(f"Session '{session_name}' not found.")
|
219
|
-
|
220
|
-
computation = get_computation(session_name, comp_name)
|
221
|
-
if not computation:
|
222
|
-
raise ResourceNotFound(
|
223
|
-
f"Computation '{comp_name}' not found in session '{session_name}'."
|
224
|
-
)
|
225
|
-
|
226
|
-
if not session.communicator:
|
227
|
-
raise InvalidRequestError(
|
228
|
-
f"Communicator not initialized for session '{session_name}'."
|
229
|
-
)
|
230
|
-
|
231
|
-
# Get rank from session communicator
|
232
|
-
rank = session.communicator.rank
|
233
|
-
|
234
|
-
# Prepare input bindings from session symbols
|
235
|
-
bindings = {}
|
236
|
-
for input_name in input_names:
|
237
|
-
symbol = get_symbol(session_name, input_name)
|
238
|
-
if not symbol:
|
239
|
-
raise ResourceNotFound(
|
240
|
-
f"Input symbol '{input_name}' not found in session '{session_name}'"
|
241
|
-
)
|
242
|
-
bindings[input_name] = symbol.data
|
243
|
-
|
244
|
-
spu_mask = (
|
245
|
-
Mask(session.spu_mask)
|
246
|
-
if session.spu_mask != -1
|
247
|
-
else Mask.all(session.communicator.world_size)
|
248
|
-
)
|
249
|
-
|
250
|
-
# Build evaluator
|
251
|
-
# Explicit per-rank backend runtime (deglobalized)
|
252
|
-
runtime = RuntimeContext(rank=rank, world_size=session.communicator.world_size)
|
253
|
-
evaluator: IEvaluator = create_evaluator(
|
254
|
-
rank=rank, env=bindings, comm=session.communicator, runtime=runtime
|
255
|
-
)
|
256
|
-
|
257
|
-
# Initialize SPU runtime state for flat kernels (once per evaluator invocation)
|
258
|
-
if rank in spu_mask:
|
259
|
-
# Build SPU address list (only once per rank; consistent ordering of participating ranks)
|
260
|
-
spu_addrs: list[str] = []
|
261
|
-
for r, addr in enumerate(session.communicator.endpoints):
|
262
|
-
if r in spu_mask:
|
263
|
-
if "://" not in addr:
|
264
|
-
addr = f"//{addr}"
|
265
|
-
parsed = urlparse(addr)
|
266
|
-
assert isinstance(parsed.port, int)
|
267
|
-
new_addr = f"{parsed.hostname}:{parsed.port + 100}"
|
268
|
-
spu_addrs.append(new_addr)
|
269
|
-
# Determine this rank's relative index among participating ranks
|
270
|
-
rel_index = sum(1 for r in range(rank) if r in spu_mask)
|
271
|
-
link_ctx = g_link_factory.create_link(rel_index, spu_addrs)
|
272
|
-
else:
|
273
|
-
link_ctx = None
|
274
|
-
# Always seed config/world; provide per-rank link (may be None if not participating)
|
275
|
-
spu_config = libspu.RuntimeConfig(
|
276
|
-
protocol=parse_protocol(session.spu_protocol),
|
277
|
-
field=parse_field(session.spu_field),
|
278
|
-
fxp_fraction_bits=18,
|
279
|
-
)
|
280
|
-
# Seed SPU env via backend kernel (inside evaluator's kernel context)
|
281
|
-
seed_pfunc = PFunction(
|
282
|
-
fn_type="spu.seed_env",
|
283
|
-
ins_info=(),
|
284
|
-
outs_info=(),
|
285
|
-
config=spu_config,
|
286
|
-
world=spu_mask.num_parties(),
|
287
|
-
link=link_ctx,
|
288
|
-
)
|
289
|
-
# Run seeding kernel with evaluator (no inputs, no outputs)
|
290
|
-
evaluator.runtime.run_kernel(seed_pfunc, [])
|
291
|
-
|
292
|
-
results = evaluator.evaluate(computation.expr)
|
293
|
-
|
294
|
-
# Store results in session symbols using output_names
|
295
|
-
if results:
|
296
|
-
if len(results) != len(output_names):
|
297
|
-
raise RuntimeError(
|
298
|
-
f"Expected {len(output_names)} results, got {len(results)}"
|
299
|
-
)
|
300
|
-
for name, val in zip(output_names, results, strict=True):
|
301
|
-
session.symbols[name] = Symbol(name=name, mptype={}, data=val)
|
302
|
-
|
303
|
-
|
304
|
-
# Symbol Management
|
305
|
-
def create_symbol(session_name: str, name: str, mptype: Any, data: Any) -> Symbol:
|
306
|
-
"""Create a symbol in a session's symbol table.
|
307
|
-
|
308
|
-
The `data` is expected to be a base64-encoded pickled Python object.
|
309
|
-
"""
|
310
|
-
session = get_session(session_name)
|
311
|
-
if not session:
|
312
|
-
raise ResourceNotFound(f"Session '{session_name}' not found.")
|
313
|
-
|
314
|
-
# Deserialize base64-encoded data to Python object
|
315
|
-
try:
|
316
|
-
data_bytes = base64.b64decode(data)
|
317
|
-
obj = pickle.loads(data_bytes)
|
318
|
-
except Exception as e:
|
319
|
-
raise InvalidRequestError(f"Invalid symbol data encoding: {e!s}") from e
|
320
|
-
|
321
|
-
symbol = Symbol(name, mptype, obj)
|
322
|
-
session.symbols[name] = symbol
|
323
|
-
return symbol
|
324
|
-
|
325
|
-
|
326
|
-
def get_symbol(session_name: str, name: str) -> Symbol | None:
|
327
|
-
"""Get a symbol from a session's symbol table (session-level only)."""
|
328
|
-
session = get_session(session_name)
|
329
|
-
if not session:
|
330
|
-
return None
|
331
|
-
|
332
|
-
# Only session-level symbols are supported now
|
333
|
-
return session.symbols.get(name)
|
334
|
-
|
335
|
-
|
336
|
-
def list_symbols(session_name: str) -> list[str]:
|
337
|
-
"""List all symbols in a session's symbol table."""
|
338
|
-
session = get_session(session_name)
|
339
|
-
if not session:
|
340
|
-
raise ResourceNotFound(f"Session '{session_name}' not found.")
|
341
|
-
|
342
|
-
# Only session-level symbols are supported now
|
343
|
-
return list(session.symbols.keys())
|
344
|
-
|
345
|
-
|
346
|
-
def delete_symbol(session_name: str, symbol_name: str) -> bool:
|
347
|
-
"""Delete a symbol from a session.
|
348
|
-
|
349
|
-
Returns:
|
350
|
-
True if symbol was deleted, False if not found.
|
351
|
-
"""
|
352
|
-
session = get_session(session_name)
|
353
|
-
if not session:
|
354
|
-
return False
|
355
|
-
|
356
|
-
if symbol_name in session.symbols:
|
357
|
-
del session.symbols[symbol_name]
|
358
|
-
logging.info(f"Symbol {symbol_name} deleted from session {session_name}")
|
359
|
-
return True
|
360
|
-
return False
|
361
|
-
|
362
|
-
|
363
|
-
def list_all_sessions() -> list[str]:
|
364
|
-
"""List all session names."""
|
365
|
-
return list(_sessions.keys())
|
File without changes
|
{mplang_nightly-0.1.dev152.dist-info → mplang_nightly-0.1.dev153.dist-info}/entry_points.txt
RENAMED
File without changes
|
{mplang_nightly-0.1.dev152.dist-info → mplang_nightly-0.1.dev153.dist-info}/licenses/LICENSE
RENAMED
File without changes
|