mplang-nightly 0.1.dev152__py3-none-any.whl → 0.1.dev153__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mplang/core/cluster.py CHANGED
@@ -25,23 +25,28 @@ from typing import Any
25
25
 
26
26
  @dataclass(frozen=True)
27
27
  class RuntimeInfo:
28
- """
29
- Structured representation of a Physical Node's runtime capabilities.
28
+ """Per-physical-node runtime configuration.
29
+
30
+ ``op_bindings`` is a per-node override map (logical_op -> kernel_id) merged
31
+ into that node's ``RuntimeContext``. Unknown future / auxiliary fields are
32
+ preserved in ``extra``.
30
33
  """
31
34
 
32
35
  version: str
33
36
  platform: str
34
- backends: list[str]
37
+ # Per-node partial override dispatch table (merged over project defaults).
38
+ op_bindings: dict[str, str] = field(default_factory=dict)
35
39
 
36
- # A catch-all for any other custom or future properties
40
+ # A catch-all for any other custom or future properties (must not collide
41
+ # with reserved keys: version, platform, op_bindings).
37
42
  extra: dict[str, Any] = field(default_factory=dict)
38
43
 
39
44
  def to_dict(self) -> dict[str, Any]:
40
- """Convert RuntimeInfo to a dictionary."""
45
+ """Convert RuntimeInfo to a dictionary (stable field names)."""
41
46
  result = {
42
47
  "version": self.version,
43
48
  "platform": self.platform,
44
- "backends": self.backends,
49
+ "op_bindings": self.op_bindings,
45
50
  }
46
51
  result.update(self.extra)
47
52
  return result
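A minimal sketch (not part of the package) of how the reworked RuntimeInfo round-trips after this change; the op name "basic.mul", the kernel id "mock.mul" and the "accelerator" extra are invented for illustration:

    from mplang.core.cluster import RuntimeInfo

    info = RuntimeInfo(
        version="0.1.dev153",
        platform="linux",
        op_bindings={"basic.mul": "mock.mul"},  # hypothetical logical-op -> kernel-id override
        extra={"accelerator": "gpu"},           # illustrative auxiliary field
    )
    d = info.to_dict()
    assert d["op_bindings"] == {"basic.mul": "mock.mul"}
    assert d["accelerator"] == "gpu"            # extra keys are flattened into the result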
@@ -175,7 +180,8 @@ class ClusterSpec:
175
180
 
176
181
  # 2. Parse Physical Nodes, using the list index as the rank
177
182
  nodes_map: dict[str, Node] = {}
178
- known_runtime_fields = {"version", "platform", "backends"}
183
+ # Reserved runtime info keys we recognize explicitly.
184
+ known_runtime_fields = {"version", "platform", "op_bindings"}
179
185
  for i, node_cfg in enumerate(config["nodes"]):
180
186
  if "rank" in node_cfg:
181
187
  # Optionally, we can log a warning that the explicit 'rank' is ignored.
@@ -187,11 +193,12 @@ class ClusterSpec:
187
193
  for k, v in runtime_info_cfg.items()
188
194
  if k not in known_runtime_fields
189
195
  }
190
-
196
+ # Gracefully ignore legacy 'backends' if present (treated as extra)
197
+ # for backward compatibility.
191
198
  runtime_info = RuntimeInfo(
192
199
  version=runtime_info_cfg.get("version", "N/A"),
193
200
  platform=runtime_info_cfg.get("platform", "N/A"),
194
- backends=runtime_info_cfg.get("backends", []),
201
+ op_bindings=runtime_info_cfg.get("op_bindings", {}) or {},
195
202
  extra=extra_runtime_info,
196
203
  )
197
204
 
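For illustration, a hedged re-run of the parsing shown in this hunk with a made-up runtime_info block ("basic.encrypt" -> "crypto.aes" and the legacy "backends" key are invented); it shows the legacy key landing in `extra` rather than being rejected:

    from mplang.core.cluster import RuntimeInfo

    runtime_info_cfg = {
        "version": "0.1.dev153",
        "platform": "linux",
        "op_bindings": {"basic.encrypt": "crypto.aes"},
        "backends": ["legacy"],  # no longer reserved, so it is preserved as an extra field
    }
    known_runtime_fields = {"version", "platform", "op_bindings"}
    extra = {k: v for k, v in runtime_info_cfg.items() if k not in known_runtime_fields}
    info = RuntimeInfo(
        version=runtime_info_cfg.get("version", "N/A"),
        platform=runtime_info_cfg.get("platform", "N/A"),
        op_bindings=runtime_info_cfg.get("op_bindings", {}) or {},
        extra=extra,
    )
    assert info.extra == {"backends": ["legacy"]}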
@@ -227,32 +234,96 @@ class ClusterSpec:
227
234
  return cls(nodes=nodes_map, devices=devices_map)
228
235
 
229
236
  @classmethod
230
- def simple(cls, world_size: int) -> ClusterSpec:
231
- """Creates a simple cluster spec for simulation with the given number of parties."""
232
- nodes = {
233
- f"node{i}": Node(
237
+ def simple(
238
+ cls,
239
+ world_size: int,
240
+ *,
241
+ endpoints: list[str] | None = None,
242
+ spu_protocol: str = "SEMI2K",
243
+ spu_field: str = "FM128",
244
+ runtime_version: str = "simulated",
245
+ runtime_platform: str = "simulated",
246
+ op_bindings: list[dict[str, str]] | None = None,
247
+ enable_local_device: bool = True,
248
+ enable_spu_device: bool = True,
249
+ ) -> ClusterSpec:
250
+ """Convenience constructor used heavily in tests.
251
+
252
+ Parameters
253
+ ----------
254
+ world_size:
255
+ Number of parties (physical nodes).
256
+ endpoints:
257
+ Optional explicit endpoint list of length ``world_size``. Each element may
258
+ include scheme (``http://``) or not; stored verbatim. If not provided we
259
+ synthesize ``localhost:{5000 + i}`` (5000 is a fixed default; pass explicit
260
+ endpoints for control). The 5000 base port is an internal default; there is no public kwarg to change it.
261
+ spu_protocol / spu_field:
262
+ SPU device config values.
263
+ runtime_version / runtime_platform:
264
+ Populated into each node's ``RuntimeInfo``.
265
+ op_bindings:
266
+ Optional list of length ``world_size`` supplying per-node op_bindings
267
+ override dicts (defaults to empty dicts).
268
+ enable_local_device:
269
+ If True (default), create one ``local_{rank}`` device per node.
270
+ enable_spu_device:
271
+ If True (default) create a shared SPU device named ``SP0``.
272
+ """
273
+ base_port = 5000
274
+
275
+ if endpoints is not None and len(endpoints) != world_size:
276
+ raise ValueError(
277
+ "len(endpoints) must equal world_size when provided: "
278
+ f"{len(endpoints)} != {world_size}"
279
+ )
280
+
281
+ if op_bindings is not None and len(op_bindings) != world_size:
282
+ raise ValueError(
283
+ "len(op_bindings) must equal world_size when provided: "
284
+ f"{len(op_bindings)} != {world_size}"
285
+ )
286
+
287
+ if not enable_local_device and not enable_spu_device:
288
+ raise ValueError(
289
+ "At least one of enable_local_device or enable_spu_device must be True"
290
+ )
291
+
292
+ nodes: dict[str, Node] = {}
293
+ for i in range(world_size):
294
+ ep = endpoints[i] if endpoints is not None else f"localhost:{base_port + i}"
295
+ node_op_bindings = op_bindings[i] if op_bindings is not None else {}
296
+ nodes[f"node{i}"] = Node(
234
297
  name=f"node{i}",
235
298
  rank=i,
236
- endpoint=f"localhost:{5000 + i}",
299
+ endpoint=ep,
237
300
  runtime_info=RuntimeInfo(
238
- version="simulated",
239
- platform="simulated",
240
- backends=["__all__"],
301
+ version=runtime_version,
302
+ platform=runtime_platform,
303
+ op_bindings=node_op_bindings,
241
304
  ),
242
305
  )
243
- for i in range(world_size)
244
- }
245
306
 
246
- devices = {
247
- "SP0": Device(
307
+ devices: dict[str, Device] = {}
308
+ # Optional per-node local devices
309
+ if enable_local_device:
310
+ for i in range(world_size):
311
+ devices[f"local_{i}"] = Device(
312
+ name=f"local_{i}",
313
+ kind="local",
314
+ members=[nodes[f"node{i}"]],
315
+ )
316
+
317
+ # Shared SPU device
318
+ if enable_spu_device:
319
+ devices["SP0"] = Device(
248
320
  name="SP0",
249
321
  kind="SPU",
250
322
  members=list(nodes.values()),
251
323
  config={
252
- "protocol": "SEMI2K",
253
- "field": "FM128",
324
+ "protocol": spu_protocol,
325
+ "field": spu_field,
254
326
  },
255
327
  )
256
- }
257
328
 
258
329
  return cls(nodes=nodes, devices=devices)
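A usage sketch of the extended ClusterSpec.simple signature above; the endpoints and the "basic.rand" -> "mock.rand" binding are invented, and the assertions only restate what the constructor shown here builds:

    from mplang.core.cluster import ClusterSpec

    spec = ClusterSpec.simple(
        2,
        endpoints=["127.0.0.1:9001", "127.0.0.1:9002"],  # arbitrary example addresses
        op_bindings=[{"basic.rand": "mock.rand"}, {}],   # hypothetical per-node override
        spu_field="FM64",
    )
    assert set(spec.devices) == {"local_0", "local_1", "SP0"}
    assert spec.nodes["node0"].runtime_info.op_bindings == {"basic.rand": "mock.rand"}
    assert spec.devices["SP0"].config == {"protocol": "SEMI2K", "field": "FM64"}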
mplang/runtime/client.py CHANGED
@@ -81,21 +81,14 @@ class HttpExecutorClient:
81
81
  self,
82
82
  name: str,
83
83
  rank: int,
84
- endpoints: list[str],
85
- *,
86
- spu_mask: int = 0,
87
- spu_protocol: str = "SEMI2K",
88
- spu_field: str = "FM64",
84
+ cluster_spec: dict,
89
85
  ) -> str:
90
86
  """Create a new session.
91
87
 
92
88
  Args:
93
89
  name: Session name/ID.
94
- rank: The rank of this party in the session.
95
- endpoints: List of endpoint URLs for all parties, indexed by rank.
96
- spu_mask: SPU mask for the session, 0 means no SPU.
97
- spu_protocol: SPU protocol for the session (e.g., "SEMI2K", "ABY3").
98
- spu_field: SPU field for the session (e.g., "FM64", "FM128").
90
+ rank: This party's rank.
91
+ cluster_spec: Full cluster specification dict (ClusterSpec.to_dict()).
99
92
 
100
93
  Returns:
101
94
  The session name/ID
@@ -104,14 +97,7 @@ class HttpExecutorClient:
104
97
  RuntimeError: If session creation fails
105
98
  """
106
99
  url = f"/sessions/{name}"
107
-
108
- payload: dict[str, Any] = {
109
- "rank": rank,
110
- "endpoints": endpoints,
111
- "spu_mask": spu_mask,
112
- "spu_protocol": spu_protocol,
113
- "spu_field": spu_field,
114
- }
100
+ payload: dict[str, Any] = {"rank": rank, "cluster_spec": cluster_spec}
115
101
 
116
102
  try:
117
103
  response = await self._client.put(url, json=payload)
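A hedged caller-side sketch of the new create_session keyword; only the payload change is taken from this diff, the surrounding client construction is assumed and not shown:

    from mplang.core.cluster import ClusterSpec

    async def open_session(client) -> str:
        # `client` is an already-constructed HttpExecutorClient for one party.
        spec = ClusterSpec.simple(2)
        return await client.create_session(
            name="sess0",
            rank=0,
            cluster_spec=spec.to_dict(),  # full topology replaces the old endpoints/spu_* kwargs
        )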
mplang/runtime/driver.py CHANGED
@@ -145,8 +145,6 @@ class Driver(InterpContext):
145
145
  """Get existing session or create a new one across all HTTP servers."""
146
146
  if self._session_id is None:
147
147
  new_session_id = new_uuid()
148
- endpoints_list = list(self.node_addrs.values())
149
-
150
148
  # Create temporary clients for session creation
151
149
  clients = self._create_clients()
152
150
  try:
@@ -158,10 +156,7 @@ class Driver(InterpContext):
158
156
  task = client.create_session(
159
157
  name=new_session_id,
160
158
  rank=rank,
161
- endpoints=endpoints_list,
162
- spu_mask=self.spu_mask_int,
163
- spu_protocol=self.spu_protocol_str,
164
- spu_field=self.spu_field_str,
159
+ cluster_spec=self.cluster_spec.to_dict(),
165
160
  )
166
161
  tasks.append(task)
167
162
 
mplang/runtime/server.py CHANGED
@@ -32,14 +32,30 @@ from mplang.core.table import TableType
32
32
  from mplang.core.tensor import TensorType
33
33
  from mplang.kernels.base import KernelContext
34
34
  from mplang.protos.v1alpha1 import mpir_pb2
35
- from mplang.runtime import resource
36
35
  from mplang.runtime.data_providers import DataProvider, ResolvedURI, register_provider
37
36
  from mplang.runtime.exceptions import InvalidRequestError, ResourceNotFound
37
+ from mplang.runtime.session import (
38
+ Computation,
39
+ Session,
40
+ Symbol,
41
+ )
38
42
 
39
43
  logger = logging.getLogger(__name__)
40
44
 
41
45
  app = FastAPI()
42
46
 
47
+ # per-server global state
48
+ _sessions: dict[str, Session] = {}
49
+ _global_symbols: dict[str, Symbol] = {}
50
+
51
+
52
+ def register_session(session: Session) -> Session: # pragma: no cover - test helper
53
+ existing = _sessions.get(session.name)
54
+ if existing:
55
+ return existing
56
+ _sessions[session.name] = session
57
+ return session
58
+
43
59
 
44
60
  class _SymbolsProvider(DataProvider):
45
61
  """Server-local symbols provider backed by BackendRuntime.state."""
@@ -83,7 +99,7 @@ class _SymbolsProvider(DataProvider):
83
99
  ctx: KernelContext,
84
100
  ) -> Any: # type: ignore[override]
85
101
  name = self._symbol_name(uri)
86
- sym = resource.get_global_symbol(name)
102
+ sym = _global_symbols.get(name)
87
103
  if sym is None:
88
104
  raise ResourceNotFound(f"Global symbol '{name}' not found")
89
105
  return sym.data
@@ -102,8 +118,13 @@ class _SymbolsProvider(DataProvider):
102
118
  raise InvalidRequestError(
103
119
  f"Failed to encode value for symbols:// write: {e!s}"
104
120
  ) from e
105
-
106
- resource.create_global_symbol(name, {}, data_b64)
121
+ try:
122
+ obj = pickle.loads(base64.b64decode(data_b64))
123
+ except Exception as e: # pragma: no cover - defensive
124
+ raise InvalidRequestError(
125
+ f"Failed to decode value for symbols:// write: {e!s}"
126
+ ) from e
127
+ _global_symbols[name] = Symbol(name=name, mptype={}, data=obj)
107
128
 
108
129
 
109
130
  # Register symbols provider explicitly for server runtime
@@ -168,11 +189,7 @@ def validate_name(name: str, name_type: str) -> None:
168
189
  # Request/Response Models
169
190
  class CreateSessionRequest(BaseModel):
170
191
  rank: int
171
- endpoints: list[str]
172
- # SPU related
173
- spu_mask: int
174
- spu_protocol: str
175
- spu_field: str
192
+ cluster_spec: dict
176
193
 
177
194
 
178
195
  class SessionResponse(BaseModel):
@@ -229,7 +246,7 @@ async def health_check() -> dict[str, str]:
229
246
  @app.get("/sessions", response_model=SessionListResponse)
230
247
  def list_sessions() -> SessionListResponse:
231
248
  """List all session names."""
232
- return SessionListResponse(sessions=resource.list_all_sessions())
249
+ return SessionListResponse(sessions=list(_sessions.keys()))
233
250
 
234
251
 
235
252
  # List all computations in a session
@@ -238,39 +255,44 @@ def list_sessions() -> SessionListResponse:
238
255
  )
239
256
  def list_session_computations(session_name: str) -> ComputationListResponse:
240
257
  """List all computation names in a session."""
241
- session = resource.get_session(session_name)
242
- if not session:
258
+ sess = _sessions.get(session_name)
259
+ if not sess:
243
260
  raise ResourceNotFound(f"Session '{session_name}' not found")
244
- return ComputationListResponse(computations=list(session.computations.keys()))
261
+ return ComputationListResponse(computations=sess.list_computations())
245
262
 
246
263
 
247
264
  # Session endpoints
248
265
  @app.put("/sessions/{session_name}", response_model=SessionResponse)
249
266
  def create_session(session_name: str, request: CreateSessionRequest) -> SessionResponse:
250
267
  validate_name(session_name, "session")
251
- session = resource.create_session(
252
- name=session_name,
253
- rank=request.rank,
254
- endpoints=request.endpoints,
255
- spu_mask=request.spu_mask,
256
- spu_protocol=request.spu_protocol,
257
- spu_field=request.spu_field,
258
- )
259
- return SessionResponse(name=session.name)
268
+ # Delegate cluster spec parsing & session construction to resource layer
269
+ from mplang.core.cluster import ClusterSpec # local import to avoid cycles
270
+
271
+ if session_name in _sessions:
272
+ sess = _sessions[session_name]
273
+ else:
274
+ spec = ClusterSpec.from_dict(request.cluster_spec)
275
+ if len(spec.get_devices_by_kind("SPU")) == 0:
276
+ raise InvalidRequestError("No SPU device found in cluster_spec for session")
277
+ sess = Session(name=session_name, rank=request.rank, cluster_spec=spec)
278
+ _sessions[session_name] = sess
279
+ return SessionResponse(name=sess.name)
260
280
 
261
281
 
262
282
  @app.get("/sessions/{session_name}", response_model=SessionResponse)
263
283
  def get_session(session_name: str) -> SessionResponse:
264
- session = resource.get_session(session_name)
265
- if not session:
284
+ sess = _sessions.get(session_name)
285
+ if not sess:
266
286
  raise ResourceNotFound(f"Session '{session_name}' not found")
267
- return SessionResponse(name=session.name)
287
+ return SessionResponse(name=sess.name)
268
288
 
269
289
 
270
290
  @app.delete("/sessions/{session_name}")
271
291
  def delete_session(session_name: str) -> dict[str, str]:
272
292
  """Delete a session and all its associated resources."""
273
- if resource.delete_session(session_name):
293
+ if session_name in _sessions:
294
+ del _sessions[session_name]
295
+ logging.info(f"Session {session_name} deleted successfully")
274
296
  return {"message": f"Session '{session_name}' deleted successfully"}
275
297
  else:
276
298
  raise ResourceNotFound(f"Session '{session_name}' not found")
@@ -299,18 +321,25 @@ def create_and_execute_computation(
299
321
  raise InvalidRequestError("Failed to parse expression from protobuf")
300
322
 
301
323
  # Create the computation resource
302
- computation = resource.create_computation(session_name, computation_id, expr)
303
- # Execute with input/output names
304
- resource.execute_computation(
305
- session_name, computation.name, request.input_names, request.output_names
306
- )
307
- return ComputationResponse(name=computation.name)
324
+ sess = _sessions.get(session_name)
325
+ if not sess:
326
+ raise ResourceNotFound(f"Session '{session_name}' not found.")
327
+ comp = sess.get_computation(computation_id)
328
+ if not comp:
329
+ comp = Computation(name=computation_id, expr=expr)
330
+ sess.add_computation(comp)
331
+ sess.execute(comp, request.input_names, request.output_names)
332
+ return ComputationResponse(name=computation_id)
308
333
 
309
334
 
310
335
  @app.delete("/sessions/{session_name}/computations/{computation_id}")
311
336
  def delete_computation(session_name: str, computation_id: str) -> dict[str, str]:
312
337
  """Delete a specific computation."""
313
- if resource.delete_computation(session_name, computation_id):
338
+ sess = _sessions.get(session_name)
339
+ if sess and sess.delete_computation(computation_id):
340
+ logging.info(
341
+ f"Computation {computation_id} deleted from session {session_name}"
342
+ )
314
343
  return {"message": f"Computation '{computation_id}' deleted successfully"}
315
344
  else:
316
345
  raise ResourceNotFound(
@@ -326,9 +355,15 @@ def create_session_symbol(
326
355
  session_name: str, symbol_name: str, request: CreateSymbolRequest
327
356
  ) -> SymbolResponse:
328
357
  """Create a symbol in a session."""
329
- symbol = resource.create_symbol(
330
- session_name, symbol_name, request.mptype, request.data
331
- )
358
+ sess = _sessions.get(session_name)
359
+ if not sess:
360
+ raise ResourceNotFound(f"Session '{session_name}' not found.")
361
+ try:
362
+ obj = pickle.loads(base64.b64decode(request.data))
363
+ except Exception as e:
364
+ raise InvalidRequestError(f"Invalid symbol data: {e!s}") from e
365
+ symbol = Symbol(name=symbol_name, mptype=request.mptype, data=obj)
366
+ sess.add_symbol(symbol)
332
367
  # Return the base64 data back to client; server stores Python object
333
368
  return SymbolResponse(
334
369
  name=symbol.name,
@@ -346,8 +381,8 @@ def get_session_symbol(session_name: str, symbol_name: str) -> SymbolResponse:
346
381
  logger.debug(
347
382
  f"Looking for symbol: '{symbol_name}' in session: '{session_name}'"
348
383
  )
349
-
350
- symbol = resource.get_symbol(session_name, symbol_name)
384
+ sess = _sessions.get(session_name)
385
+ symbol = sess.get_symbol(symbol_name) if sess else None
351
386
  if not symbol:
352
387
  raise HTTPException(
353
388
  status_code=404, detail=f"Symbol {symbol_name} not found"
@@ -368,14 +403,19 @@ def get_session_symbol(session_name: str, symbol_name: str) -> SymbolResponse:
368
403
  @app.get("/sessions/{session_name}/symbols")
369
404
  def list_session_symbols(session_name: str) -> dict[str, list[str]]:
370
405
  """List all symbols in a session."""
371
- symbols = resource.list_symbols(session_name)
406
+ sess = _sessions.get(session_name)
407
+ if not sess:
408
+ raise ResourceNotFound(f"Session '{session_name}' not found.")
409
+ symbols = sess.list_symbols()
372
410
  return {"symbols": symbols}
373
411
 
374
412
 
375
413
  @app.delete("/sessions/{session_name}/symbols/{symbol_name}")
376
414
  def delete_symbol(session_name: str, symbol_name: str) -> dict[str, str]:
377
415
  """Delete a specific symbol."""
378
- if resource.delete_symbol(session_name, symbol_name):
416
+ sess = _sessions.get(session_name)
417
+ if sess and sess.delete_symbol(symbol_name):
418
+ logging.info(f"Symbol {symbol_name} deleted from session {session_name}")
379
419
  return {"message": f"Symbol '{symbol_name}' deleted successfully"}
380
420
  else:
381
421
  raise ResourceNotFound(
@@ -389,13 +429,18 @@ def create_global_symbol(
389
429
  symbol_name: str, request: CreateSymbolRequest
390
430
  ) -> GlobalSymbolResponse:
391
431
  validate_name(symbol_name, "symbol")
392
- sym = resource.create_global_symbol(symbol_name, request.mptype, request.data)
432
+ try:
433
+ obj = pickle.loads(base64.b64decode(request.data))
434
+ except Exception as e:
435
+ raise InvalidRequestError(f"Invalid global symbol data: {e!s}") from e
436
+ sym = Symbol(name=symbol_name, mptype=request.mptype, data=obj)
437
+ _global_symbols[symbol_name] = sym
393
438
  return GlobalSymbolResponse(name=sym.name, mptype=sym.mptype, data=request.data)
394
439
 
395
440
 
396
441
  @app.get("/api/v1/symbols/{symbol_name}", response_model=GlobalSymbolResponse)
397
- def get_global_symbol(symbol_name: str) -> GlobalSymbolResponse:
398
- sym = resource.get_global_symbol(symbol_name)
442
+ def get_global_symbol(symbol_name: str) -> GlobalSymbolResponse: # route handler
443
+ sym = _global_symbols.get(symbol_name)
399
444
  if not sym:
400
445
  raise ResourceNotFound(f"Global symbol '{symbol_name}' not found")
401
446
  data_bytes = pickle.dumps(sym.data)
@@ -405,12 +450,13 @@ def get_global_symbol(symbol_name: str) -> GlobalSymbolResponse:
405
450
 
406
451
  @app.get("/api/v1/symbols")
407
452
  def list_global_symbols() -> dict[str, list[str]]:
408
- return {"symbols": resource.list_global_symbols()}
453
+ return {"symbols": list(_global_symbols.keys())}
409
454
 
410
455
 
411
456
  @app.delete("/api/v1/symbols/{symbol_name}")
412
- def delete_global_symbol(symbol_name: str) -> dict[str, str]:
413
- if resource.delete_global_symbol(symbol_name):
457
+ def delete_global_symbol(symbol_name: str) -> dict[str, str]: # route handler
458
+ if symbol_name in _global_symbols:
459
+ del _global_symbols[symbol_name]
414
460
  return {"message": f"Global symbol '{symbol_name}' deleted successfully"}
415
461
  else:
416
462
  raise ResourceNotFound(f"Global symbol '{symbol_name}' not found")
@@ -426,8 +472,8 @@ def comm_send(
426
472
  Receive a message from another party and deliver it to the session's communicator.
427
473
  This endpoint runs on the receiver's server.
428
474
  """
429
- session = resource.get_session(session_name)
430
- if not session or not session.communicator:
475
+ sess = _sessions.get(session_name)
476
+ if not sess or not sess.communicator:
431
477
  logger.error(f"Session or communicator not found: session={session_name}")
432
478
  raise HTTPException(status_code=404, detail="Session or communicator not found")
433
479
 
@@ -435,5 +481,5 @@ def comm_send(
435
481
  # We don't need to validate to_rank since the request is coming to this server
436
482
 
437
483
  # Use the proper onSent mechanism from CommunicatorBase
438
- session.communicator.onSent(from_rank, key, request.data)
484
+ sess.communicator.onSent(from_rank, key, request.data)
439
485
  return {"status": "ok"}
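An end-to-end sketch of the new request schema using FastAPI's standard TestClient; this assumes Session/HttpCommunicator construction has no network side effects and that ClusterSpec.to_dict() is JSON-serializable, neither of which is shown in this diff:

    from fastapi.testclient import TestClient

    from mplang.core.cluster import ClusterSpec
    from mplang.runtime.server import app

    def create_demo_session() -> None:
        spec = ClusterSpec.simple(2)  # ships the SPU device the handler checks for
        with TestClient(app) as http:
            resp = http.put("/sessions/demo", json={"rank": 0, "cluster_spec": spec.to_dict()})
            assert resp.status_code == 200
            assert resp.json()["name"] == "demo"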
mplang/runtime/session.py ADDED
@@ -0,0 +1,285 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Core Session model (pure, no global registries).
16
+
17
+ Contents:
18
+ * SessionState dataclass
19
+ * LinkCommFactory (SPU link reuse cache)
20
+ * Session (topology derivation, runtime init, SPU env seeding, local symbol/computation storage)
21
+
22
+ Process-wide registries (sessions, global symbols) live in the server layer
23
+ (`server.py`) so this module remains portable and easy to unit test.
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import logging
29
+ import time
30
+ from dataclasses import dataclass, field
31
+ from functools import cached_property
32
+ from typing import TYPE_CHECKING, Any, cast
33
+ from urllib.parse import urlparse
34
+
35
+ import spu.libspu as libspu
36
+
37
+ from mplang.core.expr.ast import Expr
38
+ from mplang.core.expr.evaluator import IEvaluator, create_evaluator
39
+ from mplang.core.mask import Mask
40
+ from mplang.kernels.context import RuntimeContext
41
+ from mplang.kernels.spu import PFunction # type: ignore
42
+ from mplang.runtime.communicator import HttpCommunicator
43
+ from mplang.runtime.exceptions import ResourceNotFound
44
+ from mplang.runtime.link_comm import LinkCommunicator
45
+ from mplang.utils.spu_utils import parse_field, parse_protocol
46
+
47
+ if TYPE_CHECKING: # pragma: no cover - import only for type checking
48
+ from mplang.core.cluster import ClusterSpec, Node, RuntimeInfo
49
+
50
+
51
+ class LinkCommFactory:
52
+ """Factory for creating and caching link communicators."""
53
+
54
+ def __init__(self) -> None:
55
+ self._cache: dict[tuple[int, tuple[str, ...]], LinkCommunicator] = {}
56
+
57
+ def create_link(self, rel_rank: int, addrs: list[str]) -> LinkCommunicator:
58
+ key = (rel_rank, tuple(addrs))
59
+ link = self._cache.get(key)
60
+ if link is not None:
61
+ return link
62
+ logging.info(f"LinkCommunicator created: rel_rank={rel_rank} addrs={addrs}")
63
+ link = LinkCommunicator(rel_rank, addrs)
64
+ self._cache[key] = link
65
+ return link
66
+
67
+
68
+ # Shared link factory (module-local, not global registry of sessions)
69
+ g_link_factory = LinkCommFactory()
70
+
71
+
72
+ @dataclass
73
+ class Symbol:
74
+ name: str
75
+ mptype: Any
76
+ data: Any
77
+
78
+
79
+ @dataclass
80
+ class Computation:
81
+ name: str
82
+ expr: Expr
83
+
84
+
85
+ @dataclass
86
+ class SessionState:
87
+ runtime: RuntimeContext | None = None
88
+ computations: dict[str, Computation] = field(default_factory=dict)
89
+ symbols: dict[str, Symbol] = field(default_factory=dict)
90
+ spu_seeded: bool = False
91
+ created_ts: float = field(default_factory=time.time)
92
+ last_access_ts: float = field(default_factory=time.time)
93
+
94
+
95
+ class Session:
96
+ """Represents the per-rank execution context.
97
+
98
+ Immutable config: name, rank, cluster_spec.
99
+ Derived: node, runtime_info, endpoints, spu_device, spu_mask, protocol/field, is_spu_party.
100
+ Mutable: state (runtime object, symbols, computations, seeded flag).
101
+ """
102
+
103
+ def __init__(self, name: str, rank: int, cluster_spec: ClusterSpec):
104
+ self.name = name
105
+ self.rank = rank
106
+ self.cluster_spec = cluster_spec
107
+ self.state = SessionState()
108
+ self.communicator = HttpCommunicator(
109
+ session_name=name, rank=rank, endpoints=self.endpoints
110
+ )
111
+
112
+ # --- Derived topology ---
113
+ @cached_property
114
+ def node(self) -> Node:
115
+ return self.cluster_spec.get_node_by_rank(self.rank)
116
+
117
+ @property
118
+ def runtime_info(self) -> RuntimeInfo:
119
+ return self.node.runtime_info
120
+
121
+ @cached_property
122
+ def endpoints(self) -> list[str]:
123
+ eps: list[str] = []
124
+ for n in sorted(
125
+ self.cluster_spec.nodes.values(),
126
+ key=lambda x: x.rank, # type: ignore[attr-defined]
127
+ ):
128
+ ep = n.endpoint
129
+ if not ep.startswith(("http://", "https://")):
130
+ ep = f"http://{ep}"
131
+ eps.append(ep)
132
+ return eps
133
+
134
+ @cached_property
135
+ def spu_device(self): # type: ignore
136
+ devs = self.cluster_spec.get_devices_by_kind("SPU")
137
+ if len(devs) != 1:
138
+ raise RuntimeError(
139
+ f"Expected exactly one SPU device, got {len(devs)} (session={self.name})"
140
+ )
141
+ return devs[0]
142
+
143
+ @cached_property
144
+ def spu_mask(self) -> Mask:
145
+ return Mask.from_ranks([m.rank for m in self.spu_device.members])
146
+
147
+ @property
148
+ def spu_protocol(self) -> str:
149
+ return cast(str, self.spu_device.config.get("protocol", "SEMI2K"))
150
+
151
+ @property
152
+ def spu_field(self) -> str:
153
+ return cast(str, self.spu_device.config.get("field", "FM64"))
154
+
155
+ @property
156
+ def is_spu_party(self) -> bool:
157
+ return self.rank in self.spu_mask
158
+
159
+ # --- Runtime helpers ---
160
+ def ensure_runtime(self) -> RuntimeContext:
161
+ if self.state.runtime is None:
162
+ self.state.runtime = RuntimeContext(
163
+ rank=self.rank,
164
+ world_size=len(self.cluster_spec.nodes), # type: ignore[attr-defined]
165
+ initial_bindings=(
166
+ self.runtime_info.op_bindings if self.runtime_info else {}
167
+ ),
168
+ )
169
+ return self.state.runtime
170
+
171
+ def ensure_spu_env(self) -> None:
172
+ """Ensure SPU kernel env (config/world[/link]) registered on this runtime.
173
+
174
+ Previous logic only seeded SPU parties; non-participating ranks then raised
175
+ a hard error when the evaluator encountered SPU ops in the global program,
176
+ because the kernel pocket lacked config/world. For now we register the
177
+ config/world on ALL parties (idempotent) and only attach a link context for
178
+ participating SPU ranks. Non-parties will still error later if they try to
179
+ execute a link-dependent SPU kernel (which should be guarded by masks in the
180
+ IR), but they will no longer fail early with a misleading
181
+ "SPU kernel state not initialized" message.
182
+ """
183
+ if self.state.spu_seeded:
184
+ return
185
+
186
+ link_ctx = None
187
+ # Fixed port offset for SPU runtime link services (legacy value retained).
188
+ # TODO: make configurable if future deployments require dynamic offset.
189
+ SPU_PORT_OFFSET = 100
190
+
191
+ if self.is_spu_party:
192
+ # Build SPU address list across all endpoints for ranks in mask
193
+ spu_addrs: list[str] = []
194
+ for r, addr in enumerate(self.communicator.endpoints):
195
+ if r in self.spu_mask:
196
+ if "//" not in addr:
197
+ addr = f"//{addr}"
198
+ parsed = urlparse(addr)
199
+ assert isinstance(parsed.port, int)
200
+ new_addr = f"{parsed.hostname}:{parsed.port + SPU_PORT_OFFSET}"
201
+ spu_addrs.append(new_addr)
202
+ rel_index = sum(1 for r in range(self.rank) if r in self.spu_mask)
203
+ link_ctx = g_link_factory.create_link(rel_index, spu_addrs)
204
+
205
+ spu_config = libspu.RuntimeConfig(
206
+ protocol=parse_protocol(self.spu_protocol),
207
+ field=parse_field(self.spu_field),
208
+ fxp_fraction_bits=18,
209
+ )
210
+ seed_pfunc = PFunction(
211
+ fn_type="spu.seed_env",
212
+ ins_info=(),
213
+ outs_info=(),
214
+ config=spu_config,
215
+ world=self.spu_mask.num_parties(),
216
+ link=link_ctx,
217
+ )
218
+ self.ensure_runtime().run_kernel(seed_pfunc, [])
219
+ self.state.spu_seeded = True
220
+
221
+ # --- Computations & Symbols (instance-local) ---
222
+ def add_computation(self, computation: Computation) -> None:
223
+ self.state.computations[computation.name] = computation
224
+
225
+ def get_computation(self, name: str) -> Computation | None:
226
+ return self.state.computations.get(name)
227
+
228
+ def add_symbol(self, symbol: Symbol) -> None:
229
+ self.state.symbols[symbol.name] = symbol
230
+
231
+ def get_symbol(self, name: str) -> Symbol | None:
232
+ return self.state.symbols.get(name)
233
+
234
+ def list_symbols(self) -> list[str]: # pragma: no cover - trivial
235
+ return list(self.state.symbols.keys())
236
+
237
+ def delete_symbol(self, name: str) -> bool:
238
+ if name in self.state.symbols:
239
+ del self.state.symbols[name]
240
+ return True
241
+ return False
242
+
243
+ def list_computations(self) -> list[str]: # pragma: no cover - trivial
244
+ return list(self.state.computations.keys())
245
+
246
+ def delete_computation(self, name: str) -> bool:
247
+ if name in self.state.computations:
248
+ del self.state.computations[name]
249
+ return True
250
+ return False
251
+
252
+ # --- Execution ---
253
+ def execute(
254
+ self, computation: Computation, input_names: list[str], output_names: list[str]
255
+ ) -> None:
256
+ env: dict[str, Any] = {}
257
+ for in_name in input_names:
258
+ sym = self.get_symbol(in_name)
259
+ if sym is None:
260
+ raise ResourceNotFound(
261
+ f"Input symbol '{in_name}' not found in session '{self.name}'"
262
+ )
263
+ env[in_name] = sym.data
264
+ rt = self.ensure_runtime()
265
+ self.ensure_spu_env()
266
+ evaluator: IEvaluator = create_evaluator(
267
+ rank=self.rank, env=env, comm=self.communicator, runtime=rt
268
+ )
269
+ results = evaluator.evaluate(computation.expr)
270
+ if results and len(results) != len(output_names):
271
+ raise RuntimeError(
272
+ f"Expected {len(output_names)} results, got {len(results)}"
273
+ )
274
+ for name, val in zip(output_names, results, strict=True):
275
+ self.add_symbol(Symbol(name=name, mptype={}, data=val))
276
+
277
+ # --- Convenience constructor ---
278
+ @classmethod
279
+ def from_cluster_spec_dict(cls, name: str, rank: int, spec_dict: dict) -> Session:
280
+ from mplang.core.cluster import ClusterSpec # local import to avoid cycles
281
+
282
+ spec = ClusterSpec.from_dict(spec_dict)
283
+ if len(spec.get_devices_by_kind("SPU")) == 0:
284
+ raise RuntimeError("No SPU device found in cluster_spec")
285
+ return cls(name=name, rank=rank, cluster_spec=spec)
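A small sketch of the in-memory bookkeeping surface Session exposes (no runtime init or SPU seeding is triggered); the symbol name and payload are invented, and it assumes the HttpCommunicator constructor performs no network I/O:

    from mplang.core.cluster import ClusterSpec
    from mplang.runtime.session import Session, Symbol

    spec = ClusterSpec.simple(2)
    sess = Session(name="demo", rank=0, cluster_spec=spec)

    sess.add_symbol(Symbol(name="x", mptype={}, data=[1, 2, 3]))
    assert sess.list_symbols() == ["x"]
    assert sess.is_spu_party                     # rank 0 is a member of SP0 in the simple() spec
    assert sess.spu_protocol == "SEMI2K" and sess.spu_field == "FM128"
    assert sess.delete_symbol("x") and not sess.list_symbols()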
mplang/runtime/simulation.py CHANGED
@@ -86,20 +86,17 @@ class Simulator(InterpContext):
86
86
  cluster_spec: ClusterSpec,
87
87
  *,
88
88
  trace_ranks: list[int] | None = None,
89
- op_bindings: dict[str, str] | None = None,
90
89
  ) -> None:
91
90
  """Initialize a simulator with the given cluster specification.
92
91
 
93
92
  Args:
94
93
  cluster_spec: The cluster specification defining the simulation environment.
95
94
  trace_ranks: List of ranks to trace execution for debugging.
96
- op_bindings: Optional op->kernel binding template applied to all
97
- RuntimeContexts. These are static dispatch overrides (merged
98
- with project defaults) and are orthogonal to the per-evaluate
99
- variable ``bindings`` dict passed into ``evaluate``.
95
+ Per-node op binding overrides should now be provided via
96
+ each node's `runtime_info.op_bindings` in the supplied
97
+ `cluster_spec`.
100
98
  """
101
99
  super().__init__(cluster_spec)
102
- self._op_bindings_template = op_bindings or {}
103
100
  self._trace_ranks = trace_ranks or []
104
101
 
105
102
  spu_devices = cluster_spec.get_devices_by_kind("SPU")
@@ -145,21 +142,22 @@ class Simulator(InterpContext):
145
142
 
146
143
  # Persistent per-rank RuntimeContext instances (reused across evaluates).
147
144
  # We no longer pre-create evaluators since each evaluate has different env bindings.
148
- self._runtimes: list[RuntimeContext] = [
149
- RuntimeContext(
145
+ # Build per-rank runtime contexts.
146
+ self._runtimes: list[RuntimeContext] = []
147
+ for rank in range(self.world_size()):
148
+ node = self.cluster_spec.get_node_by_rank(rank)
149
+ rt = RuntimeContext(
150
150
  rank=rank,
151
151
  world_size=self.world_size(),
152
- # Static op bindings template cloned into each runtime. These are kernel
153
- # dispatch mappings, not per-evaluate variable bindings.
154
- initial_bindings=self._op_bindings_template,
152
+ initial_bindings=node.runtime_info.op_bindings,
155
153
  )
156
- for rank in range(self.world_size())
157
- ]
154
+ self._runtimes.append(rt)
158
155
 
159
156
  @classmethod
160
157
  def simple(
161
158
  cls,
162
159
  world_size: int,
160
+ op_bindings: dict[str, str] | None = None,
163
161
  **kwargs: Any,
164
162
  ) -> Simulator:
165
163
  """Create a simple simulator with the given number of parties.
@@ -175,6 +173,10 @@ class Simulator(InterpContext):
175
173
  A Simulator instance with a simple cluster configuration.
176
174
  """
177
175
  cluster_spec = ClusterSpec.simple(world_size)
176
+ if op_bindings:
177
+ # Apply the same op_bindings to every node's runtime_info for convenience
178
+ for node in cluster_spec.nodes.values():
179
+ node.runtime_info.op_bindings.update(op_bindings)
178
180
  return cls(cluster_spec, **kwargs)
179
181
 
180
182
  def _do_evaluate(self, expr: Expr, evaluator_engine: IEvaluator) -> Any:
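A hedged sketch of the relocated op_bindings plumbing in Simulator.simple; the "basic.rand" -> "mock.rand" mapping is invented, and the loop only restates the broadcast applied above:

    from mplang.runtime.simulation import Simulator

    sim = Simulator.simple(2, op_bindings={"basic.rand": "mock.rand"})
    for node in sim.cluster_spec.nodes.values():
        assert node.runtime_info.op_bindings["basic.rand"] == "mock.rand"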
{mplang_nightly-0.1.dev152.dist-info → mplang_nightly-0.1.dev153.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mplang-nightly
3
- Version: 0.1.dev152
3
+ Version: 0.1.dev153
4
4
  Summary: Multi-Party Programming Language
5
5
  Author-email: SecretFlow Team <secretflow-contact@service.alipay.com>
6
6
  License: Apache License
{mplang_nightly-0.1.dev152.dist-info → mplang_nightly-0.1.dev153.dist-info}/RECORD RENAMED
@@ -4,7 +4,7 @@ mplang/device.py,sha256=RmjnhzHxJkkNmtBKtYMEbpQYBZpuC43qlllkCOp-QD8,12548
4
4
  mplang/analysis/__init__.py,sha256=CTHFvRsi-nFngojqjn08UaR3RY9i7CJ7T2UdR95kCrk,1056
5
5
  mplang/analysis/diagram.py,sha256=ffwgD12gL1_KH1uJ_EYkjmIlDrfxYJJkWj-wHl09_Xk,19520
6
6
  mplang/core/__init__.py,sha256=lWxlEKfRwX7FNDzgyKZ1fiDMaCiqkyg0j5mKlZD_v7g,2244
7
- mplang/core/cluster.py,sha256=gqMJenvXUfHhE181Dd5JiUkD4nT07RLoicBnvsGmRkE,8598
7
+ mplang/core/cluster.py,sha256=IqXHLogetegUEEAzmD8cWRash-UID06Wo3OBeZFwatg,11800
8
8
  mplang/core/comm.py,sha256=MByyu3etlQh_TkP1vKCFLIAPPuJOpl9Kjs6hOj6m4Yc,8843
9
9
  mplang/core/context_mgr.py,sha256=R0QJAod-1nYduVoOknLfAsxZiy-RtmuQcp-07HABYZU,1541
10
10
  mplang/core/dtype.py,sha256=0rZqFaFikFu9RxtdO36JLEgFL-E-lo3hH10whwkTVVY,10213
@@ -51,16 +51,16 @@ mplang/protos/v1alpha1/mpir_pb2.pyi,sha256=GwXR4wPB_kB_36iYS9x-cGI9KDKFMq89KhdLh
51
51
  mplang/protos/v1alpha1/mpir_pb2_grpc.py,sha256=xYOs94SXiNYAlFodACnsXW5QovLsHY5tCk3p76RH5Zc,158
52
52
  mplang/runtime/__init__.py,sha256=IRPP3TtpFC4iSt7_uaq-S4dL7CwrXL0XBMeaBoEYLlg,948
53
53
  mplang/runtime/cli.py,sha256=WehDodeVB4AukSWx1LJxxtKUqGmLPY4qjayrPlOg3bE,14438
54
- mplang/runtime/client.py,sha256=w8sPuQzqaJI5uS_3JHu2mf0tLaFmZH3f6-SeUBfMLMY,15737
54
+ mplang/runtime/client.py,sha256=vkJUFSDcKIdbKiGUM5AosCKTZygl9g8uZFEjw2xwKig,15249
55
55
  mplang/runtime/communicator.py,sha256=Lek6_h_Wmr_W-_JpT-vMxL3CHxcVZdtf7jdaLGuxPgQ,3199
56
56
  mplang/runtime/data_providers.py,sha256=hH2butEOYNGq2rRZjVBDfXLxe3YUin2ftAF6htbTfLA,8226
57
- mplang/runtime/driver.py,sha256=Ok1jY301ctN1_KTb4jwSxOdB0lI_xhx9AwhtEGJ-VLQ,11300
57
+ mplang/runtime/driver.py,sha256=pq2EQFZK9tH90Idops_yeF6fj0cfFVD_5mFcmy4Hzco,11089
58
58
  mplang/runtime/exceptions.py,sha256=c18U0xK20dRmgZo0ogTf5vXlkix9y3VAFuzkHxaXPEk,981
59
59
  mplang/runtime/http_api.md,sha256=-re1DhEqMplAkv_wnqEU-PSs8tTzf4-Ml0Gq0f3Go6s,4883
60
60
  mplang/runtime/link_comm.py,sha256=uNqTCGZVwWeuHAb7yXXQf0DUsMXLa8leHCkrcZdzYMU,4559
61
- mplang/runtime/resource.py,sha256=xNke4UpNDjsjWcr09oXWNBXsMfSZFOwsKD7FWdCVPbc,11688
62
- mplang/runtime/server.py,sha256=LQ5uJi95tYrKmgHwZaxUQi-aiqwSsT3W4z7pZ9dQaUQ,14716
63
- mplang/runtime/simulation.py,sha256=_cmUsYL58mvc6msHZ2fDjFAEHHLdJ-TRzJV8BxOP_WA,11473
61
+ mplang/runtime/server.py,sha256=vYjuWTWhhSLHUpsO8FDnOQ8kFzPhE-fXDDyL8GHVPj4,16673
62
+ mplang/runtime/session.py,sha256=4TQ_RPRmriv0H0S6rl_GSabxS7XrwMkdZIdcnyE8bHw,10374
63
+ mplang/runtime/simulation.py,sha256=WyIs8ta3ZM5o3RB0Bcb0MUu6Yh88Iujr27KvZFqGxig,11497
64
64
  mplang/simp/__init__.py,sha256=xNXnA8-jZAANa2A1W39b3lYO7D02zdCXl0TpivkTGS4,11579
65
65
  mplang/simp/mpi.py,sha256=Wv_Q16TQ3rdLam6OzqXiefIGSMmagGkso09ycyOkHEs,4774
66
66
  mplang/simp/random.py,sha256=7PVgWNL1j7Sf3MqT5PRiWplUu-0dyhF3Ub566iqX86M,3898
@@ -70,8 +70,8 @@ mplang/utils/crypto.py,sha256=rvPomBFtznRHc3RPi6Aip9lsU8zW2oxBqGv1K3vn7Rs,1052
70
70
  mplang/utils/func_utils.py,sha256=vCJcZmu0bEbqhOQKdpttV2_MBllIcPSN0b8U4WjNGGo,5164
71
71
  mplang/utils/spu_utils.py,sha256=S3L9RBkBe2AvSuMSQQ12cBY5Y1NPthubvErSX_7nj1A,4158
72
72
  mplang/utils/table_utils.py,sha256=aC-IZOKkSmFkpr3NZchLM0Wt0GOn-rg_xHBHREWBwAU,2202
73
- mplang_nightly-0.1.dev152.dist-info/METADATA,sha256=5Mt-R98IyuopKWLvWdxbCio3Yo8KmmmYqs92FyETS5M,16547
74
- mplang_nightly-0.1.dev152.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
75
- mplang_nightly-0.1.dev152.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
76
- mplang_nightly-0.1.dev152.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
77
- mplang_nightly-0.1.dev152.dist-info/RECORD,,
73
+ mplang_nightly-0.1.dev153.dist-info/METADATA,sha256=4dEwwbuB0n0oRHxO09vuMY2Al57Ol8O8KdXGlDpEZqo,16547
74
+ mplang_nightly-0.1.dev153.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
75
+ mplang_nightly-0.1.dev153.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
76
+ mplang_nightly-0.1.dev153.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
77
+ mplang_nightly-0.1.dev153.dist-info/RECORD,,
mplang/runtime/resource.py DELETED
@@ -1,365 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- This module provides the resource management for the HTTP backend.
17
- It is a simplified, in-memory version of the original executor's resource manager.
18
- """
19
-
20
- import base64
21
- import logging
22
- from dataclasses import dataclass, field
23
- from typing import Any
24
- from urllib.parse import urlparse
25
-
26
- import cloudpickle as pickle
27
- import spu.libspu as libspu
28
-
29
- from mplang.core.expr.ast import Expr
30
- from mplang.core.expr.evaluator import IEvaluator, create_evaluator
31
- from mplang.core.mask import Mask
32
- from mplang.kernels.context import RuntimeContext
33
- from mplang.kernels.spu import PFunction # type: ignore
34
- from mplang.runtime.communicator import HttpCommunicator
35
- from mplang.runtime.exceptions import InvalidRequestError, ResourceNotFound
36
- from mplang.runtime.link_comm import LinkCommunicator
37
- from mplang.utils.spu_utils import parse_field, parse_protocol
38
-
39
-
40
- class LinkCommFactory:
41
- """Factory for creating and caching link communicators."""
42
-
43
- def __init__(self) -> None:
44
- self._cache: dict[tuple[int, tuple[str, ...]], LinkCommunicator] = {}
45
-
46
- def create_link(self, rank: int, addrs: list[str]) -> LinkCommunicator:
47
- key = (rank, tuple(addrs))
48
- val = self._cache.get(key, None)
49
- if val is not None:
50
- return val
51
-
52
- logging.info(f"LinkCommunicator created: {rank} {addrs}")
53
- new_link = LinkCommunicator(rank, addrs)
54
- self._cache[key] = new_link
55
- return new_link
56
-
57
-
58
- # Global link factory instance
59
- g_link_factory = LinkCommFactory()
60
-
61
-
62
- @dataclass
63
- class Symbol:
64
- name: str
65
- mptype: Any # More flexible type to handle dict or MPType
66
- data: Any # More flexible data type
67
-
68
-
69
- @dataclass
70
- class Computation:
71
- name: str
72
- expr: Expr # The computation expression
73
-
74
-
75
- @dataclass
76
- class Session:
77
- name: str
78
- communicator: HttpCommunicator
79
- computations: dict[str, Computation] = field(default_factory=dict)
80
- symbols: dict[str, Symbol] = field(default_factory=dict) # Session-level symbols
81
-
82
- # spu related
83
- spu_mask: int = -1
84
- spu_protocol: str = "SEMI2K"
85
- spu_field: str = "FM64"
86
-
87
-
88
- # Global session storage
89
- _sessions: dict[str, Session] = {}
90
-
91
-
92
- # Session Management
93
- def create_session(
94
- name: str,
95
- rank: int,
96
- endpoints: list[str],
97
- # SPU related
98
- spu_mask: int = 0,
99
- spu_protocol: str = "SEMI2K",
100
- spu_field: str = "FM64",
101
- ) -> Session:
102
- logging.info(f"Creating session: {name}, rank: {rank}, spu_mask: {spu_mask}")
103
- if name in _sessions:
104
- # Return existing session (idempotent operation)
105
- logging.info(f"Session {name} already exists, returning existing session")
106
- return _sessions[name]
107
- session = Session(
108
- name, HttpCommunicator(session_name=name, rank=rank, endpoints=endpoints)
109
- )
110
-
111
- session.spu_mask = spu_mask
112
- session.spu_protocol = spu_protocol
113
- session.spu_field = spu_field
114
-
115
- _sessions[name] = session
116
- logging.info(f"Session {name} created successfully")
117
- return session
118
-
119
-
120
- def get_session(name: str) -> Session | None:
121
- return _sessions.get(name)
122
-
123
-
124
- def delete_session(name: str) -> bool:
125
- """Delete a session and all associated resources.
126
-
127
- Returns:
128
- True if session was deleted, False if session was not found.
129
- """
130
- if name in _sessions:
131
- del _sessions[name]
132
- logging.info(f"Session {name} deleted successfully")
133
- return True
134
- return False
135
-
136
-
137
- # Global symbol management (process-wide, not per-session)
138
- _global_symbols: dict[str, Symbol] = {}
139
-
140
-
141
- def create_global_symbol(name: str, mptype: dict[str, Any], data_b64: str) -> Symbol:
142
- """Create or replace a global symbol.
143
-
144
- Args:
145
- name: Symbol identifier
146
- mptype: Metadata dict (shape/dtype, etc.)
147
- data_b64: Base64-encoded pickled data
148
- """
149
- try:
150
- raw = base64.b64decode(data_b64)
151
- data = pickle.loads(raw)
152
- except Exception as e: # pragma: no cover - defensive
153
- raise InvalidRequestError(f"Failed to decode symbol payload: {e}") from e
154
- sym = Symbol(name=name, mptype=mptype, data=data)
155
- _global_symbols[name] = sym
156
- return sym
157
-
158
-
159
- def get_global_symbol(name: str) -> Symbol:
160
- sym = _global_symbols.get(name)
161
- if sym is None:
162
- raise ResourceNotFound(f"Global symbol '{name}' not found")
163
- return sym
164
-
165
-
166
- def delete_global_symbol(name: str) -> bool:
167
- return _global_symbols.pop(name, None) is not None
168
-
169
-
170
- def list_global_symbols() -> list[str]: # pragma: no cover - trivial
171
- return sorted(_global_symbols.keys())
172
-
173
-
174
- # Computation Management
175
- def create_computation(
176
- session_name: str, computation_name: str, expr: Expr
177
- ) -> Computation:
178
- """Creates a computation resource within a session."""
179
- session = get_session(session_name)
180
- if not session:
181
- raise ResourceNotFound(f"Session '{session_name}' not found.")
182
- computation = Computation(computation_name, expr)
183
- session.computations[computation_name] = computation
184
- logging.info(f"Computation {computation_name} created for session {session_name}")
185
- return computation
186
-
187
-
188
- def get_computation(session_name: str, comp_name: str) -> Computation | None:
189
- session = get_session(session_name)
190
- if session:
191
- return session.computations.get(comp_name)
192
- return None
193
-
194
-
195
- def delete_computation(session_name: str, comp_name: str) -> bool:
196
- """Delete a computation from a session.
197
-
198
- Returns:
199
- True if computation was deleted, False if not found.
200
- """
201
- session = get_session(session_name)
202
- if not session:
203
- return False
204
-
205
- if comp_name in session.computations:
206
- del session.computations[comp_name]
207
- logging.info(f"Computation {comp_name} deleted from session {session_name}")
208
- return True
209
- return False
210
-
211
-
212
- def execute_computation(
213
- session_name: str, comp_name: str, input_names: list[str], output_names: list[str]
214
- ) -> None:
215
- """Execute a computation using the Evaluator."""
216
- session = get_session(session_name)
217
- if not session:
218
- raise ResourceNotFound(f"Session '{session_name}' not found.")
219
-
220
- computation = get_computation(session_name, comp_name)
221
- if not computation:
222
- raise ResourceNotFound(
223
- f"Computation '{comp_name}' not found in session '{session_name}'."
224
- )
225
-
226
- if not session.communicator:
227
- raise InvalidRequestError(
228
- f"Communicator not initialized for session '{session_name}'."
229
- )
230
-
231
- # Get rank from session communicator
232
- rank = session.communicator.rank
233
-
234
- # Prepare input bindings from session symbols
235
- bindings = {}
236
- for input_name in input_names:
237
- symbol = get_symbol(session_name, input_name)
238
- if not symbol:
239
- raise ResourceNotFound(
240
- f"Input symbol '{input_name}' not found in session '{session_name}'"
241
- )
242
- bindings[input_name] = symbol.data
243
-
244
- spu_mask = (
245
- Mask(session.spu_mask)
246
- if session.spu_mask != -1
247
- else Mask.all(session.communicator.world_size)
248
- )
249
-
250
- # Build evaluator
251
- # Explicit per-rank backend runtime (deglobalized)
252
- runtime = RuntimeContext(rank=rank, world_size=session.communicator.world_size)
253
- evaluator: IEvaluator = create_evaluator(
254
- rank=rank, env=bindings, comm=session.communicator, runtime=runtime
255
- )
256
-
257
- # Initialize SPU runtime state for flat kernels (once per evaluator invocation)
258
- if rank in spu_mask:
259
- # Build SPU address list (only once per rank; consistent ordering of participating ranks)
260
- spu_addrs: list[str] = []
261
- for r, addr in enumerate(session.communicator.endpoints):
262
- if r in spu_mask:
263
- if "://" not in addr:
264
- addr = f"//{addr}"
265
- parsed = urlparse(addr)
266
- assert isinstance(parsed.port, int)
267
- new_addr = f"{parsed.hostname}:{parsed.port + 100}"
268
- spu_addrs.append(new_addr)
269
- # Determine this rank's relative index among participating ranks
270
- rel_index = sum(1 for r in range(rank) if r in spu_mask)
271
- link_ctx = g_link_factory.create_link(rel_index, spu_addrs)
272
- else:
273
- link_ctx = None
274
- # Always seed config/world; provide per-rank link (may be None if not participating)
275
- spu_config = libspu.RuntimeConfig(
276
- protocol=parse_protocol(session.spu_protocol),
277
- field=parse_field(session.spu_field),
278
- fxp_fraction_bits=18,
279
- )
280
- # Seed SPU env via backend kernel (inside evaluator's kernel context)
281
- seed_pfunc = PFunction(
282
- fn_type="spu.seed_env",
283
- ins_info=(),
284
- outs_info=(),
285
- config=spu_config,
286
- world=spu_mask.num_parties(),
287
- link=link_ctx,
288
- )
289
- # Run seeding kernel with evaluator (no inputs, no outputs)
290
- evaluator.runtime.run_kernel(seed_pfunc, [])
291
-
292
- results = evaluator.evaluate(computation.expr)
293
-
294
- # Store results in session symbols using output_names
295
- if results:
296
- if len(results) != len(output_names):
297
- raise RuntimeError(
298
- f"Expected {len(output_names)} results, got {len(results)}"
299
- )
300
- for name, val in zip(output_names, results, strict=True):
301
- session.symbols[name] = Symbol(name=name, mptype={}, data=val)
302
-
303
-
304
- # Symbol Management
305
- def create_symbol(session_name: str, name: str, mptype: Any, data: Any) -> Symbol:
306
- """Create a symbol in a session's symbol table.
307
-
308
- The `data` is expected to be a base64-encoded pickled Python object.
309
- """
310
- session = get_session(session_name)
311
- if not session:
312
- raise ResourceNotFound(f"Session '{session_name}' not found.")
313
-
314
- # Deserialize base64-encoded data to Python object
315
- try:
316
- data_bytes = base64.b64decode(data)
317
- obj = pickle.loads(data_bytes)
318
- except Exception as e:
319
- raise InvalidRequestError(f"Invalid symbol data encoding: {e!s}") from e
320
-
321
- symbol = Symbol(name, mptype, obj)
322
- session.symbols[name] = symbol
323
- return symbol
324
-
325
-
326
- def get_symbol(session_name: str, name: str) -> Symbol | None:
327
- """Get a symbol from a session's symbol table (session-level only)."""
328
- session = get_session(session_name)
329
- if not session:
330
- return None
331
-
332
- # Only session-level symbols are supported now
333
- return session.symbols.get(name)
334
-
335
-
336
- def list_symbols(session_name: str) -> list[str]:
337
- """List all symbols in a session's symbol table."""
338
- session = get_session(session_name)
339
- if not session:
340
- raise ResourceNotFound(f"Session '{session_name}' not found.")
341
-
342
- # Only session-level symbols are supported now
343
- return list(session.symbols.keys())
344
-
345
-
346
- def delete_symbol(session_name: str, symbol_name: str) -> bool:
347
- """Delete a symbol from a session.
348
-
349
- Returns:
350
- True if symbol was deleted, False if not found.
351
- """
352
- session = get_session(session_name)
353
- if not session:
354
- return False
355
-
356
- if symbol_name in session.symbols:
357
- del session.symbols[symbol_name]
358
- logging.info(f"Symbol {symbol_name} deleted from session {session_name}")
359
- return True
360
- return False
361
-
362
-
363
- def list_all_sessions() -> list[str]:
364
- """List all session names."""
365
- return list(_sessions.keys())