mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -22,22 +22,30 @@ import logging
|
|
|
22
22
|
import re
|
|
23
23
|
from typing import Any
|
|
24
24
|
|
|
25
|
-
|
|
26
|
-
|
|
25
|
+
from fastapi import (
|
|
26
|
+
FastAPI,
|
|
27
|
+
HTTPException,
|
|
28
|
+
Request,
|
|
29
|
+
)
|
|
27
30
|
from fastapi.responses import JSONResponse
|
|
28
31
|
from pydantic import BaseModel
|
|
29
32
|
|
|
30
|
-
from mplang.core
|
|
31
|
-
from mplang.core.
|
|
32
|
-
from mplang.
|
|
33
|
-
from mplang.kernels.
|
|
34
|
-
from mplang.protos.v1alpha1 import mpir_pb2
|
|
35
|
-
from mplang.runtime.data_providers import
|
|
36
|
-
|
|
37
|
-
|
|
33
|
+
from mplang.v1.core import IrReader, TableType, TensorType
|
|
34
|
+
from mplang.v1.core.cluster import ClusterSpec
|
|
35
|
+
from mplang.v1.kernels.base import KernelContext
|
|
36
|
+
from mplang.v1.kernels.value import Value, decode_value, encode_value
|
|
37
|
+
from mplang.v1.protos.v1alpha1 import mpir_pb2
|
|
38
|
+
from mplang.v1.runtime.data_providers import (
|
|
39
|
+
DataProvider,
|
|
40
|
+
ResolvedURI,
|
|
41
|
+
register_provider,
|
|
42
|
+
)
|
|
43
|
+
from mplang.v1.runtime.exceptions import InvalidRequestError, ResourceNotFound
|
|
44
|
+
from mplang.v1.runtime.session import (
|
|
38
45
|
Computation,
|
|
39
46
|
Session,
|
|
40
47
|
Symbol,
|
|
48
|
+
create_session_from_spec,
|
|
41
49
|
)
|
|
42
50
|
|
|
43
51
|
logger = logging.getLogger(__name__)
|
|
@@ -112,19 +120,11 @@ class _SymbolsProvider(DataProvider):
|
|
|
112
120
|
ctx: KernelContext,
|
|
113
121
|
) -> None: # type: ignore[override]
|
|
114
122
|
name = self._symbol_name(uri)
|
|
115
|
-
|
|
116
|
-
data_b64 = base64.b64encode(pickle.dumps(value)).decode("utf-8")
|
|
117
|
-
except Exception as e: # pragma: no cover - defensive
|
|
123
|
+
if not isinstance(value, Value):
|
|
118
124
|
raise InvalidRequestError(
|
|
119
|
-
f"
|
|
120
|
-
)
|
|
121
|
-
|
|
122
|
-
obj = pickle.loads(base64.b64decode(data_b64))
|
|
123
|
-
except Exception as e: # pragma: no cover - defensive
|
|
124
|
-
raise InvalidRequestError(
|
|
125
|
-
f"Failed to decode value for symbols:// write: {e!s}"
|
|
126
|
-
) from e
|
|
127
|
-
_global_symbols[name] = Symbol(name=name, mptype={}, data=obj)
|
|
125
|
+
f"symbols:// write expects Value instance, got {type(value)}"
|
|
126
|
+
)
|
|
127
|
+
_global_symbols[name] = Symbol(name=name, mptype={}, data=value)
|
|
128
128
|
|
|
129
129
|
|
|
130
130
|
# Register symbols provider explicitly for server runtime
|
|
@@ -208,17 +208,18 @@ class ComputationResponse(BaseModel):
|
|
|
208
208
|
|
|
209
209
|
class CreateSymbolRequest(BaseModel):
|
|
210
210
|
mptype: dict
|
|
211
|
-
data: str # Base64 encoded data
|
|
211
|
+
data: str # Base64 encoded Value data
|
|
212
212
|
|
|
213
213
|
|
|
214
214
|
class SymbolResponse(BaseModel):
|
|
215
215
|
name: str
|
|
216
216
|
mptype: dict
|
|
217
|
-
data: str
|
|
217
|
+
data: str # Base64 encoded Value data
|
|
218
218
|
|
|
219
219
|
|
|
220
220
|
class CommSendRequest(BaseModel):
|
|
221
|
-
data: str # Base64 encoded data
|
|
221
|
+
data: str # Base64 encoded binary data
|
|
222
|
+
is_raw_bytes: bool = False # True for SPU channel raw bytes
|
|
222
223
|
|
|
223
224
|
|
|
224
225
|
# Response Models for enhanced status
|
|
@@ -233,7 +234,7 @@ class ComputationListResponse(BaseModel):
|
|
|
233
234
|
class GlobalSymbolResponse(BaseModel):
|
|
234
235
|
name: str
|
|
235
236
|
mptype: dict
|
|
236
|
-
data: str
|
|
237
|
+
data: str # Base64 encoded Value data
|
|
237
238
|
|
|
238
239
|
|
|
239
240
|
@app.get("/health")
|
|
@@ -266,15 +267,12 @@ def list_session_computations(session_name: str) -> ComputationListResponse:
|
|
|
266
267
|
def create_session(session_name: str, request: CreateSessionRequest) -> SessionResponse:
|
|
267
268
|
validate_name(session_name, "session")
|
|
268
269
|
# Delegate cluster spec parsing & session construction to resource layer
|
|
269
|
-
from mplang.core.cluster import ClusterSpec # local import to avoid cycles
|
|
270
270
|
|
|
271
271
|
if session_name in _sessions:
|
|
272
272
|
sess = _sessions[session_name]
|
|
273
273
|
else:
|
|
274
274
|
spec = ClusterSpec.from_dict(request.cluster_spec)
|
|
275
|
-
|
|
276
|
-
raise InvalidRequestError("No SPU device found in cluster_spec for session")
|
|
277
|
-
sess = Session(name=session_name, rank=request.rank, cluster_spec=spec)
|
|
275
|
+
sess = create_session_from_spec(name=session_name, rank=request.rank, spec=spec)
|
|
278
276
|
_sessions[session_name] = sess
|
|
279
277
|
return SessionResponse(name=sess.name)
|
|
280
278
|
|
|
@@ -314,7 +312,7 @@ def create_and_execute_computation(
|
|
|
314
312
|
f"Invalid base64 or protobuf for mpprogram: {e!s}"
|
|
315
313
|
) from e
|
|
316
314
|
|
|
317
|
-
reader =
|
|
315
|
+
reader = IrReader()
|
|
318
316
|
expr = reader.loads(graph_proto)
|
|
319
317
|
|
|
320
318
|
if expr is None:
|
|
@@ -359,7 +357,7 @@ def create_session_symbol(
|
|
|
359
357
|
if not sess:
|
|
360
358
|
raise ResourceNotFound(f"Session '{session_name}' not found.")
|
|
361
359
|
try:
|
|
362
|
-
obj =
|
|
360
|
+
obj = decode_value(base64.b64decode(request.data))
|
|
363
361
|
except Exception as e:
|
|
364
362
|
raise InvalidRequestError(f"Invalid symbol data: {e!s}") from e
|
|
365
363
|
symbol = Symbol(name=symbol_name, mptype=request.mptype, data=obj)
|
|
@@ -368,7 +366,7 @@ def create_session_symbol(
|
|
|
368
366
|
return SymbolResponse(
|
|
369
367
|
name=symbol.name,
|
|
370
368
|
mptype=symbol.mptype,
|
|
371
|
-
data=
|
|
369
|
+
data=base64.b64encode(encode_value(symbol.data)).decode("utf-8"),
|
|
372
370
|
)
|
|
373
371
|
|
|
374
372
|
|
|
@@ -388,13 +386,16 @@ def get_session_symbol(session_name: str, symbol_name: str) -> SymbolResponse:
|
|
|
388
386
|
status_code=404, detail=f"Symbol {symbol_name} not found"
|
|
389
387
|
)
|
|
390
388
|
|
|
391
|
-
|
|
392
|
-
|
|
389
|
+
# symbol data is None means this party does not participate the computation
|
|
390
|
+
# that produced the symbol.
|
|
391
|
+
if symbol.data is None:
|
|
392
|
+
raise ResourceNotFound(f"Symbol '{symbol_name}' has no data on this party")
|
|
393
393
|
|
|
394
|
+
# Serialize using Value envelope
|
|
394
395
|
return SymbolResponse(
|
|
395
396
|
name=symbol.name,
|
|
396
397
|
mptype=symbol.mptype,
|
|
397
|
-
data=
|
|
398
|
+
data=base64.b64encode(encode_value(symbol.data)).decode("utf-8"),
|
|
398
399
|
)
|
|
399
400
|
except ValueError as e:
|
|
400
401
|
raise HTTPException(status_code=404, detail=str(e)) from e
|
|
@@ -430,12 +431,16 @@ def create_global_symbol(
|
|
|
430
431
|
) -> GlobalSymbolResponse:
|
|
431
432
|
validate_name(symbol_name, "symbol")
|
|
432
433
|
try:
|
|
433
|
-
obj =
|
|
434
|
+
obj = decode_value(base64.b64decode(request.data))
|
|
434
435
|
except Exception as e:
|
|
435
436
|
raise InvalidRequestError(f"Invalid global symbol data: {e!s}") from e
|
|
436
437
|
sym = Symbol(name=symbol_name, mptype=request.mptype, data=obj)
|
|
437
438
|
_global_symbols[symbol_name] = sym
|
|
438
|
-
return GlobalSymbolResponse(
|
|
439
|
+
return GlobalSymbolResponse(
|
|
440
|
+
name=sym.name,
|
|
441
|
+
mptype=sym.mptype,
|
|
442
|
+
data=base64.b64encode(encode_value(sym.data)).decode("utf-8"),
|
|
443
|
+
)
|
|
439
444
|
|
|
440
445
|
|
|
441
446
|
@app.get("/api/v1/symbols/{symbol_name}", response_model=GlobalSymbolResponse)
|
|
@@ -443,9 +448,12 @@ def get_global_symbol(symbol_name: str) -> GlobalSymbolResponse: # route handle
|
|
|
443
448
|
sym = _global_symbols.get(symbol_name)
|
|
444
449
|
if not sym:
|
|
445
450
|
raise ResourceNotFound(f"Global symbol '{symbol_name}' not found")
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
451
|
+
# Serialize using Value envelope
|
|
452
|
+
return GlobalSymbolResponse(
|
|
453
|
+
name=sym.name,
|
|
454
|
+
mptype=sym.mptype,
|
|
455
|
+
data=base64.b64encode(encode_value(sym.data)).decode("utf-8"),
|
|
456
|
+
)
|
|
449
457
|
|
|
450
458
|
|
|
451
459
|
@app.get("/api/v1/symbols")
|
|
@@ -480,6 +488,14 @@ def comm_send(
|
|
|
480
488
|
# The receiver rank should be the rank of the server hosting this endpoint
|
|
481
489
|
# We don't need to validate to_rank since the request is coming to this server
|
|
482
490
|
|
|
491
|
+
# For raw bytes (SPU channel), pass through as dict with flag
|
|
492
|
+
# For normal data, pass the base64 string directly
|
|
493
|
+
data_payload: str | dict[str, object]
|
|
494
|
+
if request.is_raw_bytes:
|
|
495
|
+
data_payload = {"data": request.data, "is_raw_bytes": True}
|
|
496
|
+
else:
|
|
497
|
+
data_payload = request.data
|
|
498
|
+
|
|
483
499
|
# Use the proper onSent mechanism from CommunicatorBase
|
|
484
|
-
sess.communicator.onSent(from_rank, key,
|
|
500
|
+
sess.communicator.onSent(from_rank, key, data_payload)
|
|
485
501
|
return {"status": "ok"}
|
|
@@ -25,48 +25,28 @@ Process-wide registries (sessions, global symbols) live in the server layer
|
|
|
25
25
|
|
|
26
26
|
from __future__ import annotations
|
|
27
27
|
|
|
28
|
-
import logging
|
|
29
28
|
import time
|
|
30
29
|
from dataclasses import dataclass, field
|
|
31
30
|
from functools import cached_property
|
|
32
31
|
from typing import TYPE_CHECKING, Any, cast
|
|
33
|
-
from urllib.parse import urlparse
|
|
34
32
|
|
|
35
33
|
import spu.libspu as libspu
|
|
36
34
|
|
|
37
|
-
from mplang.core.
|
|
38
|
-
from mplang.core.
|
|
39
|
-
from mplang.core.
|
|
40
|
-
from mplang.
|
|
41
|
-
from mplang.
|
|
42
|
-
from mplang.
|
|
43
|
-
from mplang.
|
|
44
|
-
from mplang.
|
|
45
|
-
from mplang.
|
|
35
|
+
from mplang.v1.core.cluster import ClusterSpec
|
|
36
|
+
from mplang.v1.core.comm import ICommunicator
|
|
37
|
+
from mplang.v1.core.expr.ast import Expr
|
|
38
|
+
from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator
|
|
39
|
+
from mplang.v1.core.mask import Mask
|
|
40
|
+
from mplang.v1.kernels.context import RuntimeContext
|
|
41
|
+
from mplang.v1.kernels.spu import PFunction # type: ignore
|
|
42
|
+
from mplang.v1.kernels.value import Value
|
|
43
|
+
from mplang.v1.runtime.communicator import HttpCommunicator
|
|
44
|
+
from mplang.v1.runtime.exceptions import ResourceNotFound
|
|
45
|
+
from mplang.v1.runtime.link_comm import LinkCommunicator
|
|
46
|
+
from mplang.v1.utils.spu_utils import parse_field, parse_protocol
|
|
46
47
|
|
|
47
48
|
if TYPE_CHECKING: # pragma: no cover - import only for type checking
|
|
48
|
-
from mplang.core.cluster import ClusterSpec, Node, RuntimeInfo
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
class LinkCommFactory:
|
|
52
|
-
"""Factory for creating and caching link communicators."""
|
|
53
|
-
|
|
54
|
-
def __init__(self) -> None:
|
|
55
|
-
self._cache: dict[tuple[int, tuple[str, ...]], LinkCommunicator] = {}
|
|
56
|
-
|
|
57
|
-
def create_link(self, rel_rank: int, addrs: list[str]) -> LinkCommunicator:
|
|
58
|
-
key = (rel_rank, tuple(addrs))
|
|
59
|
-
link = self._cache.get(key)
|
|
60
|
-
if link is not None:
|
|
61
|
-
return link
|
|
62
|
-
logging.info(f"LinkCommunicator created: rel_rank={rel_rank} addrs={addrs}")
|
|
63
|
-
link = LinkCommunicator(rel_rank, addrs)
|
|
64
|
-
self._cache[key] = link
|
|
65
|
-
return link
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
# Shared link factory (module-local, not global registry of sessions)
|
|
69
|
-
g_link_factory = LinkCommFactory()
|
|
49
|
+
from mplang.v1.core.cluster import ClusterSpec, Node, RuntimeInfo
|
|
70
50
|
|
|
71
51
|
|
|
72
52
|
@dataclass
|
|
@@ -95,19 +75,25 @@ class SessionState:
|
|
|
95
75
|
class Session:
|
|
96
76
|
"""Represents the per-rank execution context.
|
|
97
77
|
|
|
98
|
-
Immutable config: name, rank, cluster_spec.
|
|
78
|
+
Immutable config: name, rank, cluster_spec, communicator.
|
|
99
79
|
Derived: node, runtime_info, endpoints, spu_device, spu_mask, protocol/field, is_spu_party.
|
|
100
80
|
Mutable: state (runtime object, symbols, computations, seeded flag).
|
|
81
|
+
|
|
82
|
+
Note: communicator is assumed to be initialized with cluster spec info (e.g. endpoints).
|
|
101
83
|
"""
|
|
102
84
|
|
|
103
|
-
def __init__(
|
|
85
|
+
def __init__(
|
|
86
|
+
self,
|
|
87
|
+
name: str,
|
|
88
|
+
rank: int,
|
|
89
|
+
cluster_spec: ClusterSpec,
|
|
90
|
+
communicator: ICommunicator,
|
|
91
|
+
):
|
|
104
92
|
self.name = name
|
|
105
93
|
self.rank = rank
|
|
106
94
|
self.cluster_spec = cluster_spec
|
|
107
95
|
self.state = SessionState()
|
|
108
|
-
self.communicator =
|
|
109
|
-
session_name=name, rank=rank, endpoints=self.endpoints
|
|
110
|
-
)
|
|
96
|
+
self.communicator = communicator
|
|
111
97
|
|
|
112
98
|
# --- Derived topology ---
|
|
113
99
|
@cached_property
|
|
@@ -118,18 +104,9 @@ class Session:
|
|
|
118
104
|
def runtime_info(self) -> RuntimeInfo:
|
|
119
105
|
return self.node.runtime_info
|
|
120
106
|
|
|
121
|
-
@
|
|
107
|
+
@property
|
|
122
108
|
def endpoints(self) -> list[str]:
|
|
123
|
-
|
|
124
|
-
for n in sorted(
|
|
125
|
-
self.cluster_spec.nodes.values(),
|
|
126
|
-
key=lambda x: x.rank, # type: ignore[attr-defined]
|
|
127
|
-
):
|
|
128
|
-
ep = n.endpoint
|
|
129
|
-
if not ep.startswith(("http://", "https://")):
|
|
130
|
-
ep = f"http://{ep}"
|
|
131
|
-
eps.append(ep)
|
|
132
|
-
return eps
|
|
109
|
+
return self.cluster_spec.endpoints
|
|
133
110
|
|
|
134
111
|
@cached_property
|
|
135
112
|
def spu_device(self): # type: ignore
|
|
@@ -184,22 +161,19 @@ class Session:
|
|
|
184
161
|
return
|
|
185
162
|
|
|
186
163
|
link_ctx = None
|
|
187
|
-
# TODO(jint): reuse same port for mplang and spu.
|
|
188
|
-
SPU_PORT_OFFSET = 100
|
|
189
164
|
|
|
190
165
|
if self.is_spu_party:
|
|
191
|
-
#
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
link_ctx = g_link_factory.create_link(rel_index, spu_addrs)
|
|
166
|
+
# Use Channels mode to reuse existing HttpCommunicator
|
|
167
|
+
# This eliminates the need for separate BRPC ports (SPU_PORT_OFFSET)
|
|
168
|
+
from mplang.v1.core.comm import CommunicatorBase
|
|
169
|
+
|
|
170
|
+
# Type assertion: ICommunicator is actually CommunicatorBase
|
|
171
|
+
comm = cast(CommunicatorBase, self.communicator)
|
|
172
|
+
link_ctx = LinkCommunicator(
|
|
173
|
+
rank=self.rank,
|
|
174
|
+
comm=comm,
|
|
175
|
+
spu_mask=self.spu_mask,
|
|
176
|
+
)
|
|
203
177
|
|
|
204
178
|
spu_config = libspu.RuntimeConfig(
|
|
205
179
|
protocol=parse_protocol(self.spu_protocol),
|
|
@@ -271,14 +245,26 @@ class Session:
|
|
|
271
245
|
f"Expected {len(output_names)} results, got {len(results)}"
|
|
272
246
|
)
|
|
273
247
|
for name, val in zip(output_names, results, strict=True):
|
|
248
|
+
# In pure SIMP model, all nodes should have the same symbol table.
|
|
249
|
+
# Non-participating nodes get None values.
|
|
250
|
+
if val is not None and not isinstance(val, Value):
|
|
251
|
+
raise TypeError(
|
|
252
|
+
"Session executions must produce kernel Value outputs; "
|
|
253
|
+
f"got {type(val).__name__} for symbol '{name}'"
|
|
254
|
+
)
|
|
274
255
|
self.add_symbol(Symbol(name=name, mptype={}, data=val))
|
|
275
256
|
|
|
276
|
-
# --- Convenience constructor ---
|
|
277
|
-
@classmethod
|
|
278
|
-
def from_cluster_spec_dict(cls, name: str, rank: int, spec_dict: dict) -> Session:
|
|
279
|
-
from mplang.core.cluster import ClusterSpec # local import to avoid cycles
|
|
280
257
|
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
258
|
+
# --- Convenience constructor use HttpCommunicator---
|
|
259
|
+
def create_session_from_spec(name: str, rank: int, spec: ClusterSpec) -> Session:
|
|
260
|
+
if len(spec.get_devices_by_kind("SPU")) == 0:
|
|
261
|
+
raise RuntimeError("No SPU device found in cluster_spec")
|
|
262
|
+
|
|
263
|
+
# Create HttpCommunicator for the session
|
|
264
|
+
communicator = HttpCommunicator(
|
|
265
|
+
session_name=name,
|
|
266
|
+
rank=rank,
|
|
267
|
+
endpoints=spec.endpoints,
|
|
268
|
+
)
|
|
269
|
+
|
|
270
|
+
return Session(name=name, rank=rank, cluster_spec=spec, communicator=communicator)
|
|
@@ -18,25 +18,32 @@ import concurrent.futures
|
|
|
18
18
|
import faulthandler
|
|
19
19
|
import logging
|
|
20
20
|
import sys
|
|
21
|
+
import threading
|
|
21
22
|
import traceback
|
|
22
23
|
from collections.abc import Sequence
|
|
23
24
|
from typing import Any, cast
|
|
24
25
|
|
|
25
26
|
import spu.libspu as libspu
|
|
26
27
|
|
|
27
|
-
from mplang.core
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
28
|
+
from mplang.v1.core import (
|
|
29
|
+
ClusterSpec,
|
|
30
|
+
CollectiveMixin,
|
|
31
|
+
CommunicatorBase,
|
|
32
|
+
InterpContext,
|
|
33
|
+
InterpVar,
|
|
34
|
+
IrReader,
|
|
35
|
+
IrWriter,
|
|
36
|
+
Mask,
|
|
37
|
+
MPObject,
|
|
38
|
+
MPType,
|
|
39
|
+
PFunction, # for spu.seed_env kernel seeding
|
|
40
|
+
TensorLike,
|
|
41
|
+
)
|
|
42
|
+
from mplang.v1.core.expr.ast import Expr
|
|
43
|
+
from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator
|
|
44
|
+
from mplang.v1.kernels.context import RuntimeContext
|
|
45
|
+
from mplang.v1.runtime.link_comm import LinkCommunicator
|
|
46
|
+
from mplang.v1.utils.spu_utils import parse_field, parse_protocol
|
|
40
47
|
|
|
41
48
|
|
|
42
49
|
class ThreadCommunicator(CommunicatorBase, CollectiveMixin):
|
|
@@ -73,8 +80,8 @@ class SimVar(InterpVar):
|
|
|
73
80
|
|
|
74
81
|
@property
|
|
75
82
|
def values(self) -> list[Any]:
|
|
76
|
-
"""
|
|
77
|
-
return self._values
|
|
83
|
+
"""Converted values across all ranks for user inspection."""
|
|
84
|
+
return [v.to_numpy() if hasattr(v, "to_numpy") else v for v in self._values]
|
|
78
85
|
|
|
79
86
|
def __repr__(self) -> str:
|
|
80
87
|
return f"SimVar({self.mptype})"
|
|
@@ -123,16 +130,37 @@ class Simulator(InterpContext):
|
|
|
123
130
|
comm.set_peers(self._comms)
|
|
124
131
|
|
|
125
132
|
# Prepare link contexts for SPU parties (store for evaluator-time initialization)
|
|
126
|
-
|
|
133
|
+
# Use Channels mode to reuse ThreadCommunicator instead of separate mem_link
|
|
127
134
|
self._spu_link_ctxs: list[LinkCommunicator | None] = [None] * world_size
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
135
|
+
|
|
136
|
+
# Create LinkCommunicators in parallel to avoid deadlock
|
|
137
|
+
# (create_with_channels does handshake via TestSend/TestRecv)
|
|
138
|
+
exceptions: dict[int, Exception] = {}
|
|
139
|
+
|
|
140
|
+
def create_link(g_rank: int) -> None:
|
|
141
|
+
try:
|
|
142
|
+
self._spu_link_ctxs[g_rank] = LinkCommunicator(
|
|
143
|
+
rank=g_rank,
|
|
144
|
+
comm=self._comms[g_rank],
|
|
145
|
+
spu_mask=spu_mask,
|
|
146
|
+
)
|
|
147
|
+
except Exception as e:
|
|
148
|
+
exceptions[g_rank] = e
|
|
149
|
+
|
|
150
|
+
threads = [
|
|
151
|
+
threading.Thread(target=create_link, args=(g_rank,)) for g_rank in spu_mask
|
|
131
152
|
]
|
|
132
|
-
for
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
153
|
+
for t in threads:
|
|
154
|
+
t.start()
|
|
155
|
+
for t in threads:
|
|
156
|
+
t.join()
|
|
157
|
+
|
|
158
|
+
# Check for exceptions during link creation
|
|
159
|
+
if exceptions:
|
|
160
|
+
first_exc = next(iter(exceptions.values()))
|
|
161
|
+
raise RuntimeError(
|
|
162
|
+
f"Failed to create SPU link contexts for ranks {list(exceptions.keys())}"
|
|
163
|
+
) from first_exc
|
|
136
164
|
|
|
137
165
|
self._spu_runtime_cfg = libspu.RuntimeConfig(
|
|
138
166
|
protocol=spu_protocol, field=spu_field
|
|
@@ -187,10 +215,10 @@ class Simulator(InterpContext):
|
|
|
187
215
|
This exposes potential MPIR serialization bugs by forcing expressions
|
|
188
216
|
to go through the full serialize->deserialize cycle.
|
|
189
217
|
"""
|
|
190
|
-
writer =
|
|
218
|
+
writer = IrWriter()
|
|
191
219
|
graph_proto = writer.dumps(expr)
|
|
192
220
|
|
|
193
|
-
reader =
|
|
221
|
+
reader = IrReader()
|
|
194
222
|
deserialized_expr = reader.loads(graph_proto)
|
|
195
223
|
|
|
196
224
|
if deserialized_expr is None:
|
|
@@ -202,8 +230,7 @@ class Simulator(InterpContext):
|
|
|
202
230
|
def fetch(self, obj: MPObject) -> list[TensorLike]:
|
|
203
231
|
if not isinstance(obj, SimVar):
|
|
204
232
|
raise ValueError(f"Expected SimVar, got {type(obj)}")
|
|
205
|
-
|
|
206
|
-
return list(obj.values)
|
|
233
|
+
return [v.to_numpy() if hasattr(v, "to_numpy") else v for v in obj._values]
|
|
207
234
|
|
|
208
235
|
# override
|
|
209
236
|
def evaluate(self, expr: Expr, bindings: dict[str, MPObject]) -> Sequence[MPObject]:
|
|
@@ -213,7 +240,7 @@ class Simulator(InterpContext):
|
|
|
213
240
|
raise ValueError(f"Variable {name} not in this context, got {var.ctx}.")
|
|
214
241
|
|
|
215
242
|
pts_env = [
|
|
216
|
-
{name: cast(SimVar, var).
|
|
243
|
+
{name: cast(SimVar, var)._values[rank] for name, var in bindings.items()}
|
|
217
244
|
for rank in range(self.world_size())
|
|
218
245
|
]
|
|
219
246
|
|