mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -14,16 +14,24 @@
|
|
|
14
14
|
|
|
15
15
|
from __future__ import annotations
|
|
16
16
|
|
|
17
|
+
import pathlib
|
|
18
|
+
import struct
|
|
17
19
|
from dataclasses import dataclass
|
|
18
20
|
from typing import Any
|
|
19
21
|
from urllib.parse import ParseResult, urlparse
|
|
20
22
|
|
|
21
23
|
import numpy as np
|
|
22
|
-
import pandas as pd
|
|
23
24
|
|
|
24
|
-
from mplang.core import TableType, TensorType
|
|
25
|
-
from mplang.kernels.base import KernelContext
|
|
26
|
-
from mplang.
|
|
25
|
+
from mplang.v1.core import TableLike, TableType, TensorType
|
|
26
|
+
from mplang.v1.kernels.base import KernelContext
|
|
27
|
+
from mplang.v1.kernels.value import (
|
|
28
|
+
TableValue,
|
|
29
|
+
TensorValue,
|
|
30
|
+
Value,
|
|
31
|
+
decode_value,
|
|
32
|
+
encode_value,
|
|
33
|
+
)
|
|
34
|
+
from mplang.v1.utils import table_utils
|
|
27
35
|
|
|
28
36
|
|
|
29
37
|
@dataclass(frozen=True)
|
|
@@ -135,6 +143,11 @@ def get_provider(scheme: str) -> DataProvider | None:
|
|
|
135
143
|
|
|
136
144
|
|
|
137
145
|
# ---------------- Default Providers ----------------
|
|
146
|
+
MAGIC_MPLANG = b"MPLG"
|
|
147
|
+
MAGIC_PARQUET = b"PAR1"
|
|
148
|
+
MAGIC_ORC = b"ORC"
|
|
149
|
+
MAGIC_NUMPY = b"\x93NUMPY"
|
|
150
|
+
VERSION = 0x01
|
|
138
151
|
|
|
139
152
|
|
|
140
153
|
class FileProvider(DataProvider):
|
|
@@ -147,14 +160,52 @@ class FileProvider(DataProvider):
|
|
|
147
160
|
def read(
|
|
148
161
|
self, uri: ResolvedURI, out_spec: TensorType | TableType, *, ctx: KernelContext
|
|
149
162
|
) -> Any:
|
|
150
|
-
path = uri.local_path or uri.raw
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
#
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
163
|
+
path = pathlib.Path(uri.local_path or uri.raw)
|
|
164
|
+
# try load by magic
|
|
165
|
+
with path.open("rb") as f:
|
|
166
|
+
# this is the maximum length needed to detect all supported formats
|
|
167
|
+
# (numpy requires 6 bytes: '\x93NUMPY').
|
|
168
|
+
MAGIC_BYTES_LEN_MAX = 6
|
|
169
|
+
magic = f.read(MAGIC_BYTES_LEN_MAX)
|
|
170
|
+
f.seek(0)
|
|
171
|
+
if magic.startswith(MAGIC_MPLANG):
|
|
172
|
+
MPLANG_HEADER_LEN = len(MAGIC_MPLANG) + 1
|
|
173
|
+
header = f.read(MPLANG_HEADER_LEN)
|
|
174
|
+
_, version = struct.unpack(">4sB", header)
|
|
175
|
+
if version != VERSION:
|
|
176
|
+
raise ValueError(f"unsupported mplang version {version}")
|
|
177
|
+
payload = f.read()
|
|
178
|
+
return decode_value(payload)
|
|
179
|
+
elif magic.startswith(MAGIC_PARQUET):
|
|
180
|
+
if not isinstance(out_spec, TableType):
|
|
181
|
+
raise ValueError(
|
|
182
|
+
f"PARQUET files require TableType output spec, got {type(out_spec).__name__}"
|
|
183
|
+
)
|
|
184
|
+
return table_utils.read_table(
|
|
185
|
+
f, format="parquet", columns=list(out_spec.column_names())
|
|
186
|
+
)
|
|
187
|
+
elif magic.startswith(MAGIC_ORC):
|
|
188
|
+
if not isinstance(out_spec, TableType):
|
|
189
|
+
raise ValueError(
|
|
190
|
+
f"ORC files require TableType output spec, got {type(out_spec).__name__}"
|
|
191
|
+
)
|
|
192
|
+
return table_utils.read_table(
|
|
193
|
+
f, format="orc", columns=list(out_spec.column_names())
|
|
194
|
+
)
|
|
195
|
+
elif magic.startswith(MAGIC_NUMPY):
|
|
196
|
+
if not isinstance(out_spec, TensorType):
|
|
197
|
+
raise ValueError(
|
|
198
|
+
f"NumPy files require TensorType output spec, got {type(out_spec).__name__}"
|
|
199
|
+
)
|
|
200
|
+
return np.load(f)
|
|
201
|
+
|
|
202
|
+
# Fallback: open the file for CSV or NumPy loading.
|
|
203
|
+
if isinstance(out_spec, TableType):
|
|
204
|
+
return table_utils.read_table(
|
|
205
|
+
f, format="csv", columns=list(out_spec.column_names())
|
|
206
|
+
)
|
|
207
|
+
else:
|
|
208
|
+
return np.load(f)
|
|
158
209
|
|
|
159
210
|
def write(self, uri: ResolvedURI, value: Any, *, ctx: KernelContext) -> None:
|
|
160
211
|
import os
|
|
@@ -163,14 +214,24 @@ class FileProvider(DataProvider):
|
|
|
163
214
|
dir_name = os.path.dirname(path)
|
|
164
215
|
if dir_name:
|
|
165
216
|
os.makedirs(dir_name, exist_ok=True)
|
|
166
|
-
|
|
167
|
-
if
|
|
168
|
-
|
|
217
|
+
|
|
218
|
+
if not isinstance(value, Value):
|
|
219
|
+
value = (
|
|
220
|
+
TableValue(value)
|
|
221
|
+
if isinstance(value, TableLike)
|
|
222
|
+
else TensorValue(value)
|
|
223
|
+
)
|
|
224
|
+
|
|
225
|
+
if isinstance(value, TableValue):
|
|
226
|
+
table_utils.write_table(value.to_arrow(), path, format="parquet")
|
|
227
|
+
elif isinstance(value, TensorValue):
|
|
228
|
+
with open(path, "wb") as f:
|
|
229
|
+
np.save(f, value.to_numpy())
|
|
230
|
+
else:
|
|
231
|
+
payload = encode_value(value)
|
|
169
232
|
with open(path, "wb") as f:
|
|
170
|
-
f.write(
|
|
171
|
-
|
|
172
|
-
# Tensor-like via numpy
|
|
173
|
-
np.save(path, np.asarray(value))
|
|
233
|
+
f.write(struct.pack(">4sB", MAGIC_MPLANG, VERSION))
|
|
234
|
+
f.write(payload)
|
|
174
235
|
|
|
175
236
|
|
|
176
237
|
class MemProvider(DataProvider):
|
|
@@ -29,7 +29,7 @@ from typing import Any
|
|
|
29
29
|
|
|
30
30
|
import numpy as np
|
|
31
31
|
|
|
32
|
-
from mplang.core import (
|
|
32
|
+
from mplang.v1.core import (
|
|
33
33
|
ClusterSpec,
|
|
34
34
|
InterpContext,
|
|
35
35
|
InterpVar,
|
|
@@ -38,9 +38,9 @@ from mplang.core import (
|
|
|
38
38
|
MPObject,
|
|
39
39
|
MPType,
|
|
40
40
|
)
|
|
41
|
-
from mplang.core.expr.ast import Expr
|
|
42
|
-
from mplang.kernels.value import TableValue, TensorValue
|
|
43
|
-
from mplang.runtime.client import HttpExecutorClient
|
|
41
|
+
from mplang.v1.core.expr.ast import Expr
|
|
42
|
+
from mplang.v1.kernels.value import TableValue, TensorValue
|
|
43
|
+
from mplang.v1.runtime.client import HttpExecutorClient
|
|
44
44
|
|
|
45
45
|
|
|
46
46
|
def new_uuid() -> str:
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from typing import TYPE_CHECKING
|
|
19
|
+
|
|
20
|
+
import spu.libspu as libspu
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from mplang.v1.core.comm import CommunicatorBase
|
|
24
|
+
from mplang.v1.core.mask import Mask
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class LinkCommunicator:
|
|
28
|
+
"""Minimal wrapper for libspu link context.
|
|
29
|
+
|
|
30
|
+
Supports three modes:
|
|
31
|
+
1. BRPC: Production mode with separate BRPC ports (legacy)
|
|
32
|
+
2. Mem: In-memory links for testing (legacy)
|
|
33
|
+
3. Channels: Reuse MPLang communicator via IChannel bridge (NEW)
|
|
34
|
+
|
|
35
|
+
The mode is selected based on constructor arguments:
|
|
36
|
+
- If `comm` is provided: Channels mode (NEW)
|
|
37
|
+
- Elif `mem_link` is True: Mem mode
|
|
38
|
+
- Else: BRPC mode
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
def __init__(
|
|
42
|
+
self,
|
|
43
|
+
rank: int,
|
|
44
|
+
addrs: list[str] | None = None,
|
|
45
|
+
*,
|
|
46
|
+
mem_link: bool = False,
|
|
47
|
+
comm: CommunicatorBase | None = None,
|
|
48
|
+
spu_mask: Mask | None = None,
|
|
49
|
+
):
|
|
50
|
+
"""Initialize link communicator for SPU.
|
|
51
|
+
|
|
52
|
+
Args:
|
|
53
|
+
rank: Global rank of this party
|
|
54
|
+
addrs: List of addresses for all SPU parties (required for BRPC/Mem mode)
|
|
55
|
+
mem_link: If True, use in-memory link (Mem mode)
|
|
56
|
+
comm: MPLang communicator to reuse (Channels mode, NEW)
|
|
57
|
+
spu_mask: SPU parties mask (required for Channels mode)
|
|
58
|
+
|
|
59
|
+
Raises:
|
|
60
|
+
ValueError: If arguments are invalid for the selected mode
|
|
61
|
+
"""
|
|
62
|
+
self._rank = rank
|
|
63
|
+
|
|
64
|
+
# Select initialization mode based on arguments
|
|
65
|
+
if comm is not None:
|
|
66
|
+
self._init_channels_mode(rank, comm, spu_mask)
|
|
67
|
+
elif mem_link:
|
|
68
|
+
self._init_mem_mode(rank, addrs)
|
|
69
|
+
else:
|
|
70
|
+
self._init_brpc_mode(rank, addrs)
|
|
71
|
+
|
|
72
|
+
def _init_channels_mode(
|
|
73
|
+
self, rank: int, comm: CommunicatorBase, spu_mask: Mask | None
|
|
74
|
+
) -> None:
|
|
75
|
+
"""Initialize Channels mode (reuse MPLang communicator).
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
rank: Global rank of this party
|
|
79
|
+
comm: MPLang communicator to reuse
|
|
80
|
+
spu_mask: SPU parties mask
|
|
81
|
+
|
|
82
|
+
Raises:
|
|
83
|
+
ValueError: If spu_mask is None or rank not in mask
|
|
84
|
+
"""
|
|
85
|
+
if spu_mask is None:
|
|
86
|
+
raise ValueError("spu_mask required when using comm")
|
|
87
|
+
if rank not in spu_mask:
|
|
88
|
+
raise ValueError(f"rank {rank} not in spu_mask {spu_mask}")
|
|
89
|
+
|
|
90
|
+
# Lazy import to avoid circular dependency
|
|
91
|
+
from mplang.v1.runtime.channel import BaseChannel
|
|
92
|
+
|
|
93
|
+
# Create channels to ALL SPU parties (including self)
|
|
94
|
+
# libspu expects world_size channels, with self channel being None
|
|
95
|
+
channels = []
|
|
96
|
+
rel_rank = spu_mask.global_to_relative_rank(rank)
|
|
97
|
+
|
|
98
|
+
for _, peer_rank in enumerate(spu_mask):
|
|
99
|
+
if peer_rank == rank:
|
|
100
|
+
# For self, use None (won't be accessed by SPU)
|
|
101
|
+
channel = None
|
|
102
|
+
else:
|
|
103
|
+
channel = BaseChannel(comm, rank, peer_rank)
|
|
104
|
+
channels.append(channel)
|
|
105
|
+
|
|
106
|
+
# Create link context with custom channels
|
|
107
|
+
desc = libspu.link.Desc() # type: ignore
|
|
108
|
+
desc.recv_timeout_ms = 100 * 1000 # 100 seconds
|
|
109
|
+
|
|
110
|
+
# Add party info to desc (required for world_size inference)
|
|
111
|
+
for idx, peer_rank in enumerate(spu_mask):
|
|
112
|
+
desc.add_party(f"P{idx}", f"dummy_{peer_rank}")
|
|
113
|
+
|
|
114
|
+
self.lctx = libspu.link.create_with_channels(desc, rel_rank, channels)
|
|
115
|
+
self._world_size = spu_mask.num_parties()
|
|
116
|
+
|
|
117
|
+
logging.info(
|
|
118
|
+
f"LinkCommunicator initialized with BaseChannel: "
|
|
119
|
+
f"rank={rank}, rel_rank={rel_rank}, spu_mask={spu_mask}, "
|
|
120
|
+
f"world_size={self._world_size}"
|
|
121
|
+
)
|
|
122
|
+
|
|
123
|
+
def _init_mem_mode(self, rank: int, addrs: list[str] | None) -> None:
|
|
124
|
+
"""Initialize Mem mode (in-memory links for testing).
|
|
125
|
+
|
|
126
|
+
Args:
|
|
127
|
+
rank: Global rank of this party
|
|
128
|
+
addrs: List of addresses for all SPU parties
|
|
129
|
+
|
|
130
|
+
Raises:
|
|
131
|
+
ValueError: If addrs is None
|
|
132
|
+
"""
|
|
133
|
+
if addrs is None:
|
|
134
|
+
raise ValueError("addrs required for Mem mode")
|
|
135
|
+
|
|
136
|
+
self._world_size = len(addrs)
|
|
137
|
+
|
|
138
|
+
desc = libspu.link.Desc() # type: ignore
|
|
139
|
+
desc.recv_timeout_ms = 100 * 1000 # 100 seconds
|
|
140
|
+
desc.http_max_payload_size = 32 * 1024 * 1024 # 32M
|
|
141
|
+
for rank_idx, addr in enumerate(addrs):
|
|
142
|
+
desc.add_party(f"P{rank_idx}", addr)
|
|
143
|
+
|
|
144
|
+
self.lctx = libspu.link.create_mem(desc, self._rank)
|
|
145
|
+
logging.info(
|
|
146
|
+
f"LinkCommunicator initialized with Mem: "
|
|
147
|
+
f"rank={self._rank}, world_size={self._world_size}, addrs={addrs}"
|
|
148
|
+
)
|
|
149
|
+
|
|
150
|
+
def _init_brpc_mode(self, rank: int, addrs: list[str] | None) -> None:
|
|
151
|
+
"""Initialize BRPC mode (production mode with separate BRPC ports).
|
|
152
|
+
|
|
153
|
+
Args:
|
|
154
|
+
rank: Global rank of this party
|
|
155
|
+
addrs: List of addresses for all SPU parties
|
|
156
|
+
|
|
157
|
+
Raises:
|
|
158
|
+
ValueError: If addrs is None
|
|
159
|
+
"""
|
|
160
|
+
if addrs is None:
|
|
161
|
+
raise ValueError("addrs required for BRPC mode")
|
|
162
|
+
|
|
163
|
+
self._world_size = len(addrs)
|
|
164
|
+
|
|
165
|
+
desc = libspu.link.Desc() # type: ignore
|
|
166
|
+
desc.recv_timeout_ms = 100 * 1000 # 100 seconds
|
|
167
|
+
desc.http_max_payload_size = 32 * 1024 * 1024 # 32M
|
|
168
|
+
for rank_idx, addr in enumerate(addrs):
|
|
169
|
+
desc.add_party(f"P{rank_idx}", addr)
|
|
170
|
+
|
|
171
|
+
self.lctx = libspu.link.create_brpc(desc, self._rank)
|
|
172
|
+
logging.info(
|
|
173
|
+
f"LinkCommunicator initialized with BRPC: "
|
|
174
|
+
f"rank={self._rank}, world_size={self._world_size}, addrs={addrs}"
|
|
175
|
+
)
|
|
176
|
+
|
|
177
|
+
@property
|
|
178
|
+
def rank(self) -> int:
|
|
179
|
+
"""Get rank from underlying link context."""
|
|
180
|
+
return self.lctx.rank # type: ignore[no-any-return]
|
|
181
|
+
|
|
182
|
+
@property
|
|
183
|
+
def world_size(self) -> int:
|
|
184
|
+
"""Get world size from underlying link context."""
|
|
185
|
+
return self.lctx.world_size # type: ignore[no-any-return]
|
|
186
|
+
|
|
187
|
+
def get_lctx(self) -> libspu.link.Context:
|
|
188
|
+
"""Get the underlying libspu link context.
|
|
189
|
+
|
|
190
|
+
This is the primary interface - SPU runtime uses this context directly.
|
|
191
|
+
All communication and serialization is handled by libspu internally.
|
|
192
|
+
|
|
193
|
+
Returns:
|
|
194
|
+
The underlying libspu.link.Context instance.
|
|
195
|
+
"""
|
|
196
|
+
return self.lctx
|
|
@@ -30,14 +30,18 @@ from fastapi import (
|
|
|
30
30
|
from fastapi.responses import JSONResponse
|
|
31
31
|
from pydantic import BaseModel
|
|
32
32
|
|
|
33
|
-
from mplang.core import IrReader, TableType, TensorType
|
|
34
|
-
from mplang.core.cluster import ClusterSpec
|
|
35
|
-
from mplang.kernels.base import KernelContext
|
|
36
|
-
from mplang.kernels.value import Value, decode_value, encode_value
|
|
37
|
-
from mplang.protos.v1alpha1 import mpir_pb2
|
|
38
|
-
from mplang.runtime.data_providers import
|
|
39
|
-
|
|
40
|
-
|
|
33
|
+
from mplang.v1.core import IrReader, TableType, TensorType
|
|
34
|
+
from mplang.v1.core.cluster import ClusterSpec
|
|
35
|
+
from mplang.v1.kernels.base import KernelContext
|
|
36
|
+
from mplang.v1.kernels.value import Value, decode_value, encode_value
|
|
37
|
+
from mplang.v1.protos.v1alpha1 import mpir_pb2
|
|
38
|
+
from mplang.v1.runtime.data_providers import (
|
|
39
|
+
DataProvider,
|
|
40
|
+
ResolvedURI,
|
|
41
|
+
register_provider,
|
|
42
|
+
)
|
|
43
|
+
from mplang.v1.runtime.exceptions import InvalidRequestError, ResourceNotFound
|
|
44
|
+
from mplang.v1.runtime.session import (
|
|
41
45
|
Computation,
|
|
42
46
|
Session,
|
|
43
47
|
Symbol,
|
|
@@ -215,6 +219,7 @@ class SymbolResponse(BaseModel):
|
|
|
215
219
|
|
|
216
220
|
class CommSendRequest(BaseModel):
|
|
217
221
|
data: str # Base64 encoded binary data
|
|
222
|
+
is_raw_bytes: bool = False # True for SPU channel raw bytes
|
|
218
223
|
|
|
219
224
|
|
|
220
225
|
# Response Models for enhanced status
|
|
@@ -483,6 +488,14 @@ def comm_send(
|
|
|
483
488
|
# The receiver rank should be the rank of the server hosting this endpoint
|
|
484
489
|
# We don't need to validate to_rank since the request is coming to this server
|
|
485
490
|
|
|
491
|
+
# For raw bytes (SPU channel), pass through as dict with flag
|
|
492
|
+
# For normal data, pass the base64 string directly
|
|
493
|
+
data_payload: str | dict[str, object]
|
|
494
|
+
if request.is_raw_bytes:
|
|
495
|
+
data_payload = {"data": request.data, "is_raw_bytes": True}
|
|
496
|
+
else:
|
|
497
|
+
data_payload = request.data
|
|
498
|
+
|
|
486
499
|
# Use the proper onSent mechanism from CommunicatorBase
|
|
487
|
-
sess.communicator.onSent(from_rank, key,
|
|
500
|
+
sess.communicator.onSent(from_rank, key, data_payload)
|
|
488
501
|
return {"status": "ok"}
|
|
@@ -25,51 +25,28 @@ Process-wide registries (sessions, global symbols) live in the server layer
|
|
|
25
25
|
|
|
26
26
|
from __future__ import annotations
|
|
27
27
|
|
|
28
|
-
import logging
|
|
29
28
|
import time
|
|
30
29
|
from dataclasses import dataclass, field
|
|
31
30
|
from functools import cached_property
|
|
32
31
|
from typing import TYPE_CHECKING, Any, cast
|
|
33
|
-
from urllib.parse import urlparse
|
|
34
32
|
|
|
35
33
|
import spu.libspu as libspu
|
|
36
34
|
|
|
37
|
-
from mplang.core.cluster import ClusterSpec
|
|
38
|
-
from mplang.core.comm import ICommunicator
|
|
39
|
-
from mplang.core.expr.ast import Expr
|
|
40
|
-
from mplang.core.expr.evaluator import IEvaluator, create_evaluator
|
|
41
|
-
from mplang.core.mask import Mask
|
|
42
|
-
from mplang.kernels.context import RuntimeContext
|
|
43
|
-
from mplang.kernels.spu import PFunction # type: ignore
|
|
44
|
-
from mplang.kernels.value import Value
|
|
45
|
-
from mplang.runtime.communicator import HttpCommunicator
|
|
46
|
-
from mplang.runtime.exceptions import ResourceNotFound
|
|
47
|
-
from mplang.runtime.link_comm import LinkCommunicator
|
|
48
|
-
from mplang.utils.spu_utils import parse_field, parse_protocol
|
|
35
|
+
from mplang.v1.core.cluster import ClusterSpec
|
|
36
|
+
from mplang.v1.core.comm import ICommunicator
|
|
37
|
+
from mplang.v1.core.expr.ast import Expr
|
|
38
|
+
from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator
|
|
39
|
+
from mplang.v1.core.mask import Mask
|
|
40
|
+
from mplang.v1.kernels.context import RuntimeContext
|
|
41
|
+
from mplang.v1.kernels.spu import PFunction # type: ignore
|
|
42
|
+
from mplang.v1.kernels.value import Value
|
|
43
|
+
from mplang.v1.runtime.communicator import HttpCommunicator
|
|
44
|
+
from mplang.v1.runtime.exceptions import ResourceNotFound
|
|
45
|
+
from mplang.v1.runtime.link_comm import LinkCommunicator
|
|
46
|
+
from mplang.v1.utils.spu_utils import parse_field, parse_protocol
|
|
49
47
|
|
|
50
48
|
if TYPE_CHECKING: # pragma: no cover - import only for type checking
|
|
51
|
-
from mplang.core.cluster import ClusterSpec, Node, RuntimeInfo
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
class LinkCommFactory:
|
|
55
|
-
"""Factory for creating and caching link communicators."""
|
|
56
|
-
|
|
57
|
-
def __init__(self) -> None:
|
|
58
|
-
self._cache: dict[tuple[int, tuple[str, ...]], LinkCommunicator] = {}
|
|
59
|
-
|
|
60
|
-
def create_link(self, rel_rank: int, addrs: list[str]) -> LinkCommunicator:
|
|
61
|
-
key = (rel_rank, tuple(addrs))
|
|
62
|
-
link = self._cache.get(key)
|
|
63
|
-
if link is not None:
|
|
64
|
-
return link
|
|
65
|
-
logging.info(f"LinkCommunicator created: rel_rank={rel_rank} addrs={addrs}")
|
|
66
|
-
link = LinkCommunicator(rel_rank, addrs)
|
|
67
|
-
self._cache[key] = link
|
|
68
|
-
return link
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
# Shared link factory (module-local, not global registry of sessions)
|
|
72
|
-
g_link_factory = LinkCommFactory()
|
|
49
|
+
from mplang.v1.core.cluster import ClusterSpec, Node, RuntimeInfo
|
|
73
50
|
|
|
74
51
|
|
|
75
52
|
@dataclass
|
|
@@ -184,23 +161,19 @@ class Session:
|
|
|
184
161
|
return
|
|
185
162
|
|
|
186
163
|
link_ctx = None
|
|
187
|
-
# TODO(jint): reuse same port for mplang and spu.
|
|
188
|
-
SPU_PORT_OFFSET = 100
|
|
189
164
|
|
|
190
165
|
if self.is_spu_party:
|
|
191
|
-
#
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
rel_index = sum(1 for r in range(self.rank) if r in self.spu_mask)
|
|
203
|
-
link_ctx = g_link_factory.create_link(rel_index, spu_addrs)
|
|
166
|
+
# Use Channels mode to reuse existing HttpCommunicator
|
|
167
|
+
# This eliminates the need for separate BRPC ports (SPU_PORT_OFFSET)
|
|
168
|
+
from mplang.v1.core.comm import CommunicatorBase
|
|
169
|
+
|
|
170
|
+
# Type assertion: ICommunicator is actually CommunicatorBase
|
|
171
|
+
comm = cast(CommunicatorBase, self.communicator)
|
|
172
|
+
link_ctx = LinkCommunicator(
|
|
173
|
+
rank=self.rank,
|
|
174
|
+
comm=comm,
|
|
175
|
+
spu_mask=self.spu_mask,
|
|
176
|
+
)
|
|
204
177
|
|
|
205
178
|
spu_config = libspu.RuntimeConfig(
|
|
206
179
|
protocol=parse_protocol(self.spu_protocol),
|
|
@@ -18,13 +18,14 @@ import concurrent.futures
|
|
|
18
18
|
import faulthandler
|
|
19
19
|
import logging
|
|
20
20
|
import sys
|
|
21
|
+
import threading
|
|
21
22
|
import traceback
|
|
22
23
|
from collections.abc import Sequence
|
|
23
24
|
from typing import Any, cast
|
|
24
25
|
|
|
25
26
|
import spu.libspu as libspu
|
|
26
27
|
|
|
27
|
-
from mplang.core import (
|
|
28
|
+
from mplang.v1.core import (
|
|
28
29
|
ClusterSpec,
|
|
29
30
|
CollectiveMixin,
|
|
30
31
|
CommunicatorBase,
|
|
@@ -38,11 +39,11 @@ from mplang.core import (
|
|
|
38
39
|
PFunction, # for spu.seed_env kernel seeding
|
|
39
40
|
TensorLike,
|
|
40
41
|
)
|
|
41
|
-
from mplang.core.expr.ast import Expr
|
|
42
|
-
from mplang.core.expr.evaluator import IEvaluator, create_evaluator
|
|
43
|
-
from mplang.kernels.context import RuntimeContext
|
|
44
|
-
from mplang.runtime.link_comm import LinkCommunicator
|
|
45
|
-
from mplang.utils.spu_utils import parse_field, parse_protocol
|
|
42
|
+
from mplang.v1.core.expr.ast import Expr
|
|
43
|
+
from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator
|
|
44
|
+
from mplang.v1.kernels.context import RuntimeContext
|
|
45
|
+
from mplang.v1.runtime.link_comm import LinkCommunicator
|
|
46
|
+
from mplang.v1.utils.spu_utils import parse_field, parse_protocol
|
|
46
47
|
|
|
47
48
|
|
|
48
49
|
class ThreadCommunicator(CommunicatorBase, CollectiveMixin):
|
|
@@ -129,16 +130,37 @@ class Simulator(InterpContext):
|
|
|
129
130
|
comm.set_peers(self._comms)
|
|
130
131
|
|
|
131
132
|
# Prepare link contexts for SPU parties (store for evaluator-time initialization)
|
|
132
|
-
|
|
133
|
+
# Use Channels mode to reuse ThreadCommunicator instead of separate mem_link
|
|
133
134
|
self._spu_link_ctxs: list[LinkCommunicator | None] = [None] * world_size
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
135
|
+
|
|
136
|
+
# Create LinkCommunicators in parallel to avoid deadlock
|
|
137
|
+
# (create_with_channels does handshake via TestSend/TestRecv)
|
|
138
|
+
exceptions: dict[int, Exception] = {}
|
|
139
|
+
|
|
140
|
+
def create_link(g_rank: int) -> None:
|
|
141
|
+
try:
|
|
142
|
+
self._spu_link_ctxs[g_rank] = LinkCommunicator(
|
|
143
|
+
rank=g_rank,
|
|
144
|
+
comm=self._comms[g_rank],
|
|
145
|
+
spu_mask=spu_mask,
|
|
146
|
+
)
|
|
147
|
+
except Exception as e:
|
|
148
|
+
exceptions[g_rank] = e
|
|
149
|
+
|
|
150
|
+
threads = [
|
|
151
|
+
threading.Thread(target=create_link, args=(g_rank,)) for g_rank in spu_mask
|
|
137
152
|
]
|
|
138
|
-
for
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
153
|
+
for t in threads:
|
|
154
|
+
t.start()
|
|
155
|
+
for t in threads:
|
|
156
|
+
t.join()
|
|
157
|
+
|
|
158
|
+
# Check for exceptions during link creation
|
|
159
|
+
if exceptions:
|
|
160
|
+
first_exc = next(iter(exceptions.values()))
|
|
161
|
+
raise RuntimeError(
|
|
162
|
+
f"Failed to create SPU link contexts for ranks {list(exceptions.keys())}"
|
|
163
|
+
) from first_exc
|
|
142
164
|
|
|
143
165
|
self._spu_runtime_cfg = libspu.RuntimeConfig(
|
|
144
166
|
protocol=spu_protocol, field=spu_field
|