mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -22,22 +22,30 @@ import logging
22
22
  import re
23
23
  from typing import Any
24
24
 
25
- import cloudpickle as pickle
26
- from fastapi import FastAPI, HTTPException, Request
25
+ from fastapi import (
26
+ FastAPI,
27
+ HTTPException,
28
+ Request,
29
+ )
27
30
  from fastapi.responses import JSONResponse
28
31
  from pydantic import BaseModel
29
32
 
30
- from mplang.core.mpir import Reader
31
- from mplang.core.table import TableType
32
- from mplang.core.tensor import TensorType
33
- from mplang.kernels.base import KernelContext
34
- from mplang.protos.v1alpha1 import mpir_pb2
35
- from mplang.runtime.data_providers import DataProvider, ResolvedURI, register_provider
36
- from mplang.runtime.exceptions import InvalidRequestError, ResourceNotFound
37
- from mplang.runtime.session import (
33
+ from mplang.v1.core import IrReader, TableType, TensorType
34
+ from mplang.v1.core.cluster import ClusterSpec
35
+ from mplang.v1.kernels.base import KernelContext
36
+ from mplang.v1.kernels.value import Value, decode_value, encode_value
37
+ from mplang.v1.protos.v1alpha1 import mpir_pb2
38
+ from mplang.v1.runtime.data_providers import (
39
+ DataProvider,
40
+ ResolvedURI,
41
+ register_provider,
42
+ )
43
+ from mplang.v1.runtime.exceptions import InvalidRequestError, ResourceNotFound
44
+ from mplang.v1.runtime.session import (
38
45
  Computation,
39
46
  Session,
40
47
  Symbol,
48
+ create_session_from_spec,
41
49
  )
42
50
 
43
51
  logger = logging.getLogger(__name__)
@@ -112,19 +120,11 @@ class _SymbolsProvider(DataProvider):
112
120
  ctx: KernelContext,
113
121
  ) -> None: # type: ignore[override]
114
122
  name = self._symbol_name(uri)
115
- try:
116
- data_b64 = base64.b64encode(pickle.dumps(value)).decode("utf-8")
117
- except Exception as e: # pragma: no cover - defensive
123
+ if not isinstance(value, Value):
118
124
  raise InvalidRequestError(
119
- f"Failed to encode value for symbols:// write: {e!s}"
120
- ) from e
121
- try:
122
- obj = pickle.loads(base64.b64decode(data_b64))
123
- except Exception as e: # pragma: no cover - defensive
124
- raise InvalidRequestError(
125
- f"Failed to decode value for symbols:// write: {e!s}"
126
- ) from e
127
- _global_symbols[name] = Symbol(name=name, mptype={}, data=obj)
125
+ f"symbols:// write expects Value instance, got {type(value)}"
126
+ )
127
+ _global_symbols[name] = Symbol(name=name, mptype={}, data=value)
128
128
 
129
129
 
130
130
  # Register symbols provider explicitly for server runtime
@@ -208,17 +208,18 @@ class ComputationResponse(BaseModel):
208
208
 
209
209
  class CreateSymbolRequest(BaseModel):
210
210
  mptype: dict
211
- data: str # Base64 encoded data
211
+ data: str # Base64 encoded Value data
212
212
 
213
213
 
214
214
  class SymbolResponse(BaseModel):
215
215
  name: str
216
216
  mptype: dict
217
- data: str
217
+ data: str # Base64 encoded Value data
218
218
 
219
219
 
220
220
  class CommSendRequest(BaseModel):
221
- data: str # Base64 encoded data
221
+ data: str # Base64 encoded binary data
222
+ is_raw_bytes: bool = False # True for SPU channel raw bytes
222
223
 
223
224
 
224
225
  # Response Models for enhanced status
@@ -233,7 +234,7 @@ class ComputationListResponse(BaseModel):
233
234
  class GlobalSymbolResponse(BaseModel):
234
235
  name: str
235
236
  mptype: dict
236
- data: str
237
+ data: str # Base64 encoded Value data
237
238
 
238
239
 
239
240
  @app.get("/health")
@@ -266,15 +267,12 @@ def list_session_computations(session_name: str) -> ComputationListResponse:
266
267
  def create_session(session_name: str, request: CreateSessionRequest) -> SessionResponse:
267
268
  validate_name(session_name, "session")
268
269
  # Delegate cluster spec parsing & session construction to resource layer
269
- from mplang.core.cluster import ClusterSpec # local import to avoid cycles
270
270
 
271
271
  if session_name in _sessions:
272
272
  sess = _sessions[session_name]
273
273
  else:
274
274
  spec = ClusterSpec.from_dict(request.cluster_spec)
275
- if len(spec.get_devices_by_kind("SPU")) == 0:
276
- raise InvalidRequestError("No SPU device found in cluster_spec for session")
277
- sess = Session(name=session_name, rank=request.rank, cluster_spec=spec)
275
+ sess = create_session_from_spec(name=session_name, rank=request.rank, spec=spec)
278
276
  _sessions[session_name] = sess
279
277
  return SessionResponse(name=sess.name)
280
278
 
@@ -314,7 +312,7 @@ def create_and_execute_computation(
314
312
  f"Invalid base64 or protobuf for mpprogram: {e!s}"
315
313
  ) from e
316
314
 
317
- reader = Reader()
315
+ reader = IrReader()
318
316
  expr = reader.loads(graph_proto)
319
317
 
320
318
  if expr is None:
@@ -359,7 +357,7 @@ def create_session_symbol(
359
357
  if not sess:
360
358
  raise ResourceNotFound(f"Session '{session_name}' not found.")
361
359
  try:
362
- obj = pickle.loads(base64.b64decode(request.data))
360
+ obj = decode_value(base64.b64decode(request.data))
363
361
  except Exception as e:
364
362
  raise InvalidRequestError(f"Invalid symbol data: {e!s}") from e
365
363
  symbol = Symbol(name=symbol_name, mptype=request.mptype, data=obj)
@@ -368,7 +366,7 @@ def create_session_symbol(
368
366
  return SymbolResponse(
369
367
  name=symbol.name,
370
368
  mptype=symbol.mptype,
371
- data=request.data,
369
+ data=base64.b64encode(encode_value(symbol.data)).decode("utf-8"),
372
370
  )
373
371
 
374
372
 
@@ -388,13 +386,16 @@ def get_session_symbol(session_name: str, symbol_name: str) -> SymbolResponse:
388
386
  status_code=404, detail=f"Symbol {symbol_name} not found"
389
387
  )
390
388
 
391
- data_bytes = pickle.dumps(symbol.data)
392
- data_b64 = base64.b64encode(data_bytes).decode("utf-8")
389
+ # symbol data is None means this party does not participate in the computation
390
+ # that produced the symbol.
391
+ if symbol.data is None:
392
+ raise ResourceNotFound(f"Symbol '{symbol_name}' has no data on this party")
393
393
 
394
+ # Serialize using Value envelope
394
395
  return SymbolResponse(
395
396
  name=symbol.name,
396
397
  mptype=symbol.mptype,
397
- data=data_b64,
398
+ data=base64.b64encode(encode_value(symbol.data)).decode("utf-8"),
398
399
  )
399
400
  except ValueError as e:
400
401
  raise HTTPException(status_code=404, detail=str(e)) from e
@@ -430,12 +431,16 @@ def create_global_symbol(
430
431
  ) -> GlobalSymbolResponse:
431
432
  validate_name(symbol_name, "symbol")
432
433
  try:
433
- obj = pickle.loads(base64.b64decode(request.data))
434
+ obj = decode_value(base64.b64decode(request.data))
434
435
  except Exception as e:
435
436
  raise InvalidRequestError(f"Invalid global symbol data: {e!s}") from e
436
437
  sym = Symbol(name=symbol_name, mptype=request.mptype, data=obj)
437
438
  _global_symbols[symbol_name] = sym
438
- return GlobalSymbolResponse(name=sym.name, mptype=sym.mptype, data=request.data)
439
+ return GlobalSymbolResponse(
440
+ name=sym.name,
441
+ mptype=sym.mptype,
442
+ data=base64.b64encode(encode_value(sym.data)).decode("utf-8"),
443
+ )
439
444
 
440
445
 
441
446
  @app.get("/api/v1/symbols/{symbol_name}", response_model=GlobalSymbolResponse)
@@ -443,9 +448,12 @@ def get_global_symbol(symbol_name: str) -> GlobalSymbolResponse: # route handle
443
448
  sym = _global_symbols.get(symbol_name)
444
449
  if not sym:
445
450
  raise ResourceNotFound(f"Global symbol '{symbol_name}' not found")
446
- data_bytes = pickle.dumps(sym.data)
447
- data_b64 = base64.b64encode(data_bytes).decode("utf-8")
448
- return GlobalSymbolResponse(name=sym.name, mptype=sym.mptype, data=data_b64)
451
+ # Serialize using Value envelope
452
+ return GlobalSymbolResponse(
453
+ name=sym.name,
454
+ mptype=sym.mptype,
455
+ data=base64.b64encode(encode_value(sym.data)).decode("utf-8"),
456
+ )
449
457
 
450
458
 
451
459
  @app.get("/api/v1/symbols")
@@ -480,6 +488,14 @@ def comm_send(
480
488
  # The receiver rank should be the rank of the server hosting this endpoint
481
489
  # We don't need to validate to_rank since the request is coming to this server
482
490
 
491
+ # For raw bytes (SPU channel), pass through as dict with flag
492
+ # For normal data, pass the base64 string directly
493
+ data_payload: str | dict[str, object]
494
+ if request.is_raw_bytes:
495
+ data_payload = {"data": request.data, "is_raw_bytes": True}
496
+ else:
497
+ data_payload = request.data
498
+
483
499
  # Use the proper onSent mechanism from CommunicatorBase
484
- sess.communicator.onSent(from_rank, key, request.data)
500
+ sess.communicator.onSent(from_rank, key, data_payload)
485
501
  return {"status": "ok"}
@@ -25,48 +25,28 @@ Process-wide registries (sessions, global symbols) live in the server layer
25
25
 
26
26
  from __future__ import annotations
27
27
 
28
- import logging
29
28
  import time
30
29
  from dataclasses import dataclass, field
31
30
  from functools import cached_property
32
31
  from typing import TYPE_CHECKING, Any, cast
33
- from urllib.parse import urlparse
34
32
 
35
33
  import spu.libspu as libspu
36
34
 
37
- from mplang.core.expr.ast import Expr
38
- from mplang.core.expr.evaluator import IEvaluator, create_evaluator
39
- from mplang.core.mask import Mask
40
- from mplang.kernels.context import RuntimeContext
41
- from mplang.kernels.spu import PFunction # type: ignore
42
- from mplang.runtime.communicator import HttpCommunicator
43
- from mplang.runtime.exceptions import ResourceNotFound
44
- from mplang.runtime.link_comm import LinkCommunicator
45
- from mplang.utils.spu_utils import parse_field, parse_protocol
35
+ from mplang.v1.core.cluster import ClusterSpec
36
+ from mplang.v1.core.comm import ICommunicator
37
+ from mplang.v1.core.expr.ast import Expr
38
+ from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator
39
+ from mplang.v1.core.mask import Mask
40
+ from mplang.v1.kernels.context import RuntimeContext
41
+ from mplang.v1.kernels.spu import PFunction # type: ignore
42
+ from mplang.v1.kernels.value import Value
43
+ from mplang.v1.runtime.communicator import HttpCommunicator
44
+ from mplang.v1.runtime.exceptions import ResourceNotFound
45
+ from mplang.v1.runtime.link_comm import LinkCommunicator
46
+ from mplang.v1.utils.spu_utils import parse_field, parse_protocol
46
47
 
47
48
  if TYPE_CHECKING: # pragma: no cover - import only for type checking
48
- from mplang.core.cluster import ClusterSpec, Node, RuntimeInfo
49
-
50
-
51
- class LinkCommFactory:
52
- """Factory for creating and caching link communicators."""
53
-
54
- def __init__(self) -> None:
55
- self._cache: dict[tuple[int, tuple[str, ...]], LinkCommunicator] = {}
56
-
57
- def create_link(self, rel_rank: int, addrs: list[str]) -> LinkCommunicator:
58
- key = (rel_rank, tuple(addrs))
59
- link = self._cache.get(key)
60
- if link is not None:
61
- return link
62
- logging.info(f"LinkCommunicator created: rel_rank={rel_rank} addrs={addrs}")
63
- link = LinkCommunicator(rel_rank, addrs)
64
- self._cache[key] = link
65
- return link
66
-
67
-
68
- # Shared link factory (module-local, not global registry of sessions)
69
- g_link_factory = LinkCommFactory()
49
+ from mplang.v1.core.cluster import ClusterSpec, Node, RuntimeInfo
70
50
 
71
51
 
72
52
  @dataclass
@@ -95,19 +75,25 @@ class SessionState:
95
75
  class Session:
96
76
  """Represents the per-rank execution context.
97
77
 
98
- Immutable config: name, rank, cluster_spec.
78
+ Immutable config: name, rank, cluster_spec, communicator.
99
79
  Derived: node, runtime_info, endpoints, spu_device, spu_mask, protocol/field, is_spu_party.
100
80
  Mutable: state (runtime object, symbols, computations, seeded flag).
81
+
82
+ Note: communicator is assumed to be initialized with cluster spec info (e.g. endpoints).
101
83
  """
102
84
 
103
- def __init__(self, name: str, rank: int, cluster_spec: ClusterSpec):
85
+ def __init__(
86
+ self,
87
+ name: str,
88
+ rank: int,
89
+ cluster_spec: ClusterSpec,
90
+ communicator: ICommunicator,
91
+ ):
104
92
  self.name = name
105
93
  self.rank = rank
106
94
  self.cluster_spec = cluster_spec
107
95
  self.state = SessionState()
108
- self.communicator = HttpCommunicator(
109
- session_name=name, rank=rank, endpoints=self.endpoints
110
- )
96
+ self.communicator = communicator
111
97
 
112
98
  # --- Derived topology ---
113
99
  @cached_property
@@ -118,18 +104,9 @@ class Session:
118
104
  def runtime_info(self) -> RuntimeInfo:
119
105
  return self.node.runtime_info
120
106
 
121
- @cached_property
107
+ @property
122
108
  def endpoints(self) -> list[str]:
123
- eps: list[str] = []
124
- for n in sorted(
125
- self.cluster_spec.nodes.values(),
126
- key=lambda x: x.rank, # type: ignore[attr-defined]
127
- ):
128
- ep = n.endpoint
129
- if not ep.startswith(("http://", "https://")):
130
- ep = f"http://{ep}"
131
- eps.append(ep)
132
- return eps
109
+ return self.cluster_spec.endpoints
133
110
 
134
111
  @cached_property
135
112
  def spu_device(self): # type: ignore
@@ -184,22 +161,19 @@ class Session:
184
161
  return
185
162
 
186
163
  link_ctx = None
187
- # TODO(jint): reuse same port for mplang and spu.
188
- SPU_PORT_OFFSET = 100
189
164
 
190
165
  if self.is_spu_party:
191
- # Build SPU address list across all endpoints for ranks in mask
192
- spu_addrs: list[str] = []
193
- for r, addr in enumerate(self.communicator.endpoints):
194
- if r in self.spu_mask:
195
- if "//" not in addr:
196
- addr = f"//{addr}"
197
- parsed = urlparse(addr)
198
- assert isinstance(parsed.port, int)
199
- new_addr = f"{parsed.hostname}:{parsed.port + SPU_PORT_OFFSET}"
200
- spu_addrs.append(new_addr)
201
- rel_index = sum(1 for r in range(self.rank) if r in self.spu_mask)
202
- link_ctx = g_link_factory.create_link(rel_index, spu_addrs)
166
+ # Use Channels mode to reuse existing HttpCommunicator
167
+ # This eliminates the need for separate BRPC ports (SPU_PORT_OFFSET)
168
+ from mplang.v1.core.comm import CommunicatorBase
169
+
170
+ # Type assertion: ICommunicator is actually CommunicatorBase
171
+ comm = cast(CommunicatorBase, self.communicator)
172
+ link_ctx = LinkCommunicator(
173
+ rank=self.rank,
174
+ comm=comm,
175
+ spu_mask=self.spu_mask,
176
+ )
203
177
 
204
178
  spu_config = libspu.RuntimeConfig(
205
179
  protocol=parse_protocol(self.spu_protocol),
@@ -271,14 +245,26 @@ class Session:
271
245
  f"Expected {len(output_names)} results, got {len(results)}"
272
246
  )
273
247
  for name, val in zip(output_names, results, strict=True):
248
+ # In pure SIMP model, all nodes should have the same symbol table.
249
+ # Non-participating nodes get None values.
250
+ if val is not None and not isinstance(val, Value):
251
+ raise TypeError(
252
+ "Session executions must produce kernel Value outputs; "
253
+ f"got {type(val).__name__} for symbol '{name}'"
254
+ )
274
255
  self.add_symbol(Symbol(name=name, mptype={}, data=val))
275
256
 
276
- # --- Convenience constructor ---
277
- @classmethod
278
- def from_cluster_spec_dict(cls, name: str, rank: int, spec_dict: dict) -> Session:
279
- from mplang.core.cluster import ClusterSpec # local import to avoid cycles
280
257
 
281
- spec = ClusterSpec.from_dict(spec_dict)
282
- if len(spec.get_devices_by_kind("SPU")) == 0:
283
- raise RuntimeError("No SPU device found in cluster_spec")
284
- return cls(name=name, rank=rank, cluster_spec=spec)
258
+ # --- Convenience constructor using HttpCommunicator ---
259
+ def create_session_from_spec(name: str, rank: int, spec: ClusterSpec) -> Session:
260
+ if len(spec.get_devices_by_kind("SPU")) == 0:
261
+ raise RuntimeError("No SPU device found in cluster_spec")
262
+
263
+ # Create HttpCommunicator for the session
264
+ communicator = HttpCommunicator(
265
+ session_name=name,
266
+ rank=rank,
267
+ endpoints=spec.endpoints,
268
+ )
269
+
270
+ return Session(name=name, rank=rank, cluster_spec=spec, communicator=communicator)
@@ -18,25 +18,32 @@ import concurrent.futures
18
18
  import faulthandler
19
19
  import logging
20
20
  import sys
21
+ import threading
21
22
  import traceback
22
23
  from collections.abc import Sequence
23
24
  from typing import Any, cast
24
25
 
25
26
  import spu.libspu as libspu
26
27
 
27
- from mplang.core.cluster import ClusterSpec
28
- from mplang.core.comm import CollectiveMixin, CommunicatorBase
29
- from mplang.core.expr.ast import Expr
30
- from mplang.core.expr.evaluator import IEvaluator, create_evaluator
31
- from mplang.core.interp import InterpContext, InterpVar
32
- from mplang.core.mask import Mask
33
- from mplang.core.mpir import Reader, Writer
34
- from mplang.core.mpobject import MPObject
35
- from mplang.core.mptype import MPType, TensorLike
36
- from mplang.core.pfunc import PFunction # for spu.seed_env kernel seeding
37
- from mplang.kernels.context import RuntimeContext
38
- from mplang.runtime.link_comm import LinkCommunicator
39
- from mplang.utils.spu_utils import parse_field, parse_protocol
28
+ from mplang.v1.core import (
29
+ ClusterSpec,
30
+ CollectiveMixin,
31
+ CommunicatorBase,
32
+ InterpContext,
33
+ InterpVar,
34
+ IrReader,
35
+ IrWriter,
36
+ Mask,
37
+ MPObject,
38
+ MPType,
39
+ PFunction, # for spu.seed_env kernel seeding
40
+ TensorLike,
41
+ )
42
+ from mplang.v1.core.expr.ast import Expr
43
+ from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator
44
+ from mplang.v1.kernels.context import RuntimeContext
45
+ from mplang.v1.runtime.link_comm import LinkCommunicator
46
+ from mplang.v1.utils.spu_utils import parse_field, parse_protocol
40
47
 
41
48
 
42
49
  class ThreadCommunicator(CommunicatorBase, CollectiveMixin):
@@ -73,8 +80,8 @@ class SimVar(InterpVar):
73
80
 
74
81
  @property
75
82
  def values(self) -> list[Any]:
76
- """The values of this variable across all ranks."""
77
- return self._values
83
+ """Converted values across all ranks for user inspection."""
84
+ return [v.to_numpy() if hasattr(v, "to_numpy") else v for v in self._values]
78
85
 
79
86
  def __repr__(self) -> str:
80
87
  return f"SimVar({self.mptype})"
@@ -123,16 +130,37 @@ class Simulator(InterpContext):
123
130
  comm.set_peers(self._comms)
124
131
 
125
132
  # Prepare link contexts for SPU parties (store for evaluator-time initialization)
126
- spu_addrs = [f"P{spu_rank}" for spu_rank in spu_mask]
133
+ # Use Channels mode to reuse ThreadCommunicator instead of separate mem_link
127
134
  self._spu_link_ctxs: list[LinkCommunicator | None] = [None] * world_size
128
- link_ctx_list = [
129
- LinkCommunicator(idx, spu_addrs, mem_link=True)
130
- for idx in range(spu_mask.num_parties())
135
+
136
+ # Create LinkCommunicators in parallel to avoid deadlock
137
+ # (create_with_channels does handshake via TestSend/TestRecv)
138
+ exceptions: dict[int, Exception] = {}
139
+
140
+ def create_link(g_rank: int) -> None:
141
+ try:
142
+ self._spu_link_ctxs[g_rank] = LinkCommunicator(
143
+ rank=g_rank,
144
+ comm=self._comms[g_rank],
145
+ spu_mask=spu_mask,
146
+ )
147
+ except Exception as e:
148
+ exceptions[g_rank] = e
149
+
150
+ threads = [
151
+ threading.Thread(target=create_link, args=(g_rank,)) for g_rank in spu_mask
131
152
  ]
132
- for g_rank in range(world_size):
133
- if g_rank in spu_mask:
134
- rel = Mask(spu_mask).global_to_relative_rank(g_rank)
135
- self._spu_link_ctxs[g_rank] = link_ctx_list[rel]
153
+ for t in threads:
154
+ t.start()
155
+ for t in threads:
156
+ t.join()
157
+
158
+ # Check for exceptions during link creation
159
+ if exceptions:
160
+ first_exc = next(iter(exceptions.values()))
161
+ raise RuntimeError(
162
+ f"Failed to create SPU link contexts for ranks {list(exceptions.keys())}"
163
+ ) from first_exc
136
164
 
137
165
  self._spu_runtime_cfg = libspu.RuntimeConfig(
138
166
  protocol=spu_protocol, field=spu_field
@@ -187,10 +215,10 @@ class Simulator(InterpContext):
187
215
  This exposes potential MPIR serialization bugs by forcing expressions
188
216
  to go through the full serialize->deserialize cycle.
189
217
  """
190
- writer = Writer()
218
+ writer = IrWriter()
191
219
  graph_proto = writer.dumps(expr)
192
220
 
193
- reader = Reader()
221
+ reader = IrReader()
194
222
  deserialized_expr = reader.loads(graph_proto)
195
223
 
196
224
  if deserialized_expr is None:
@@ -202,8 +230,7 @@ class Simulator(InterpContext):
202
230
  def fetch(self, obj: MPObject) -> list[TensorLike]:
203
231
  if not isinstance(obj, SimVar):
204
232
  raise ValueError(f"Expected SimVar, got {type(obj)}")
205
-
206
- return list(obj.values)
233
+ return [v.to_numpy() if hasattr(v, "to_numpy") else v for v in obj._values]
207
234
 
208
235
  # override
209
236
  def evaluate(self, expr: Expr, bindings: dict[str, MPObject]) -> Sequence[MPObject]:
@@ -213,7 +240,7 @@ class Simulator(InterpContext):
213
240
  raise ValueError(f"Variable {name} not in this context, got {var.ctx}.")
214
241
 
215
242
  pts_env = [
216
- {name: cast(SimVar, var).values[rank] for name, var in bindings.items()}
243
+ {name: cast(SimVar, var)._values[rank] for name, var in bindings.items()}
217
244
  for rank in range(self.world_size())
218
245
  ]
219
246