mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -21,20 +21,33 @@ import base64
|
|
|
21
21
|
import logging
|
|
22
22
|
from typing import Any
|
|
23
23
|
|
|
24
|
-
import cloudpickle as pickle
|
|
25
24
|
import httpx
|
|
26
25
|
|
|
27
|
-
from mplang.core.comm import CommunicatorBase
|
|
26
|
+
from mplang.v1.core.comm import CommunicatorBase
|
|
27
|
+
from mplang.v1.kernels.value import Value, decode_value, encode_value
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
class HttpCommunicator(CommunicatorBase):
|
|
31
31
|
def __init__(self, session_name: str, rank: int, endpoints: list[str]):
    """Create an HTTP communicator for one party of a session.

    Args:
        session_name: Session identifier; embedded in every request URL.
        rank: Rank of this party.
        endpoints: One endpoint per party, indexed by rank. Entries that do
            not start with ``http://`` or ``https://`` get an ``http://``
            prefix.

    Raises:
        ValueError: If ``endpoints`` is empty or any entry is empty.
    """
    # Validate before touching the base class so a bad config fails fast.
    if not endpoints:
        raise ValueError("endpoints cannot be empty")
    if not all(endpoints):
        raise ValueError("endpoints cannot contain empty elements")

    super().__init__(rank, len(endpoints))
    self.session_name = session_name

    # Normalize every endpoint so URL building can assume a scheme is present.
    known_schemes = ("http://", "https://")
    normalized: list[str] = []
    for ep in endpoints:
        if ep.startswith(known_schemes):
            normalized.append(ep)
        else:
            normalized.append(f"http://{ep}")
    self.endpoints = normalized

    self._counter = 0
    logging.info(
        f"HttpCommunicator initialized: session={session_name}, rank={rank}, endpoints={self.endpoints}"
    )
|
|
39
52
|
|
|
40
53
|
# override
|
|
@@ -44,7 +57,12 @@ class HttpCommunicator(CommunicatorBase):
|
|
|
44
57
|
return str(res)
|
|
45
58
|
|
|
46
59
|
def send(self, to: int, key: str, data: Any) -> None:
|
|
47
|
-
"""Sends data to a peer party by PUTing to its /comm/{key}/from/{from_rank} endpoint.
|
|
60
|
+
"""Sends data to a peer party by PUTing to its /comm/{key}/from/{from_rank} endpoint.
|
|
61
|
+
|
|
62
|
+
Supports two modes:
|
|
63
|
+
- SPU channel (key starts with "spu:"): sends raw bytes directly
|
|
64
|
+
- Normal channel: wraps data in Value envelope
|
|
65
|
+
"""
|
|
48
66
|
target_endpoint = self.endpoints[to]
|
|
49
67
|
url = f"{target_endpoint}/sessions/{self.session_name}/comm/{key}/from/{self._rank}"
|
|
50
68
|
logging.debug(
|
|
@@ -52,13 +70,20 @@ class HttpCommunicator(CommunicatorBase):
|
|
|
52
70
|
)
|
|
53
71
|
|
|
54
72
|
try:
|
|
55
|
-
#
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
73
|
+
# SPU channel mode: send raw bytes directly
|
|
74
|
+
if key.startswith("spu:") and isinstance(data, bytes):
|
|
75
|
+
data_b64 = base64.b64encode(data).decode("utf-8")
|
|
76
|
+
request_data = {"data": data_b64, "is_raw_bytes": True}
|
|
77
|
+
# Normal mode: serialize using Value envelope
|
|
78
|
+
elif isinstance(data, Value):
|
|
79
|
+
data_bytes = encode_value(data)
|
|
80
|
+
data_b64 = base64.b64encode(data_bytes).decode("utf-8")
|
|
81
|
+
request_data = {"data": data_b64}
|
|
82
|
+
else:
|
|
83
|
+
raise TypeError(
|
|
84
|
+
f"Communicator requires Value instance, got {type(data).__name__}. "
|
|
85
|
+
"Wrap data in TensorValue or custom Value subclass."
|
|
86
|
+
)
|
|
62
87
|
|
|
63
88
|
response = httpx.put(url, json=request_data, timeout=60)
|
|
64
89
|
logging.debug(f"Send response: status={response.status_code}")
|
|
@@ -72,14 +97,33 @@ class HttpCommunicator(CommunicatorBase):
|
|
|
72
97
|
raise OSError(f"Failed to send data to rank {to}") from e
|
|
73
98
|
|
|
74
99
|
def recv(self, frm: int, key: str) -> Any:
|
|
75
|
-
"""Wait until the key is set, returns the value.
|
|
100
|
+
"""Wait until the key is set, returns the value.
|
|
101
|
+
|
|
102
|
+
Supports two modes:
|
|
103
|
+
- SPU channel (key starts with "spu:"): returns raw bytes
|
|
104
|
+
- Normal channel: returns deserialized Value
|
|
105
|
+
"""
|
|
76
106
|
logging.debug(
|
|
77
107
|
f"Waiting to receive: from_rank={frm}, to_rank={self._rank}, key={key}"
|
|
78
108
|
)
|
|
79
|
-
|
|
109
|
+
received_data = super().recv(frm, key)
|
|
110
|
+
|
|
111
|
+
# Check if this is raw bytes (SPU channel)
|
|
112
|
+
if isinstance(received_data, dict) and received_data.get("is_raw_bytes"):
|
|
113
|
+
data_bytes = base64.b64decode(received_data["data"])
|
|
114
|
+
logging.debug(
|
|
115
|
+
f"Received raw bytes: from_rank={frm}, to_rank={self._rank}, key={key}, size={len(data_bytes)}"
|
|
116
|
+
)
|
|
117
|
+
return data_bytes
|
|
80
118
|
|
|
119
|
+
# Normal mode: deserialize Value envelope
|
|
120
|
+
data_b64 = (
|
|
121
|
+
received_data
|
|
122
|
+
if isinstance(received_data, str)
|
|
123
|
+
else received_data.get("data")
|
|
124
|
+
)
|
|
81
125
|
data_bytes = base64.b64decode(data_b64)
|
|
82
|
-
result =
|
|
126
|
+
result = decode_value(data_bytes)
|
|
83
127
|
|
|
84
128
|
logging.debug(
|
|
85
129
|
f"Received data: from_rank={frm}, to_rank={self._rank}, key={key}"
|
|
@@ -14,17 +14,24 @@
|
|
|
14
14
|
|
|
15
15
|
from __future__ import annotations
|
|
16
16
|
|
|
17
|
+
import pathlib
|
|
18
|
+
import struct
|
|
17
19
|
from dataclasses import dataclass
|
|
18
20
|
from typing import Any
|
|
19
21
|
from urllib.parse import ParseResult, urlparse
|
|
20
22
|
|
|
21
23
|
import numpy as np
|
|
22
|
-
import pandas as pd
|
|
23
24
|
|
|
24
|
-
from mplang.core
|
|
25
|
-
from mplang.
|
|
26
|
-
from mplang.kernels.
|
|
27
|
-
|
|
25
|
+
from mplang.v1.core import TableLike, TableType, TensorType
|
|
26
|
+
from mplang.v1.kernels.base import KernelContext
|
|
27
|
+
from mplang.v1.kernels.value import (
|
|
28
|
+
TableValue,
|
|
29
|
+
TensorValue,
|
|
30
|
+
Value,
|
|
31
|
+
decode_value,
|
|
32
|
+
encode_value,
|
|
33
|
+
)
|
|
34
|
+
from mplang.v1.utils import table_utils
|
|
28
35
|
|
|
29
36
|
|
|
30
37
|
@dataclass(frozen=True)
|
|
@@ -136,6 +143,11 @@ def get_provider(scheme: str) -> DataProvider | None:
|
|
|
136
143
|
|
|
137
144
|
|
|
138
145
|
# ---------------- Default Providers ----------------
|
|
146
|
+
# Leading "magic" byte sequences FileProvider uses to recognise file formats.
# MPLang container layout: 4-byte MAGIC_MPLANG, 1-byte VERSION, then the
# encode_value() payload.
MAGIC_MPLANG = b"MPLG"
MAGIC_PARQUET = b"PAR1"
MAGIC_ORC = b"ORC"
MAGIC_NUMPY = b"\x93NUMPY"  # prefix of every .npy file
# Container version byte; readers reject any other value.
VERSION = 1
|
|
139
151
|
|
|
140
152
|
|
|
141
153
|
class FileProvider(DataProvider):
|
|
@@ -148,13 +160,52 @@ class FileProvider(DataProvider):
|
|
|
148
160
|
def read(
    self, uri: ResolvedURI, out_spec: TensorType | TableType, *, ctx: KernelContext
) -> Any:
    """Read a local file, choosing the decoder from its leading magic bytes.

    Formats are probed in this order:

    * ``MAGIC_MPLANG``: MPLang container (4-byte magic, 1-byte version that
      must equal ``VERSION``, then an ``encode_value`` payload).
    * ``MAGIC_PARQUET``: Parquet table; ``out_spec`` must be a ``TableType``.
    * ``MAGIC_ORC``: ORC table; ``out_spec`` must be a ``TableType``.
    * ``MAGIC_NUMPY``: ``.npy`` array; ``out_spec`` must be a ``TensorType``.
    * Anything else: CSV if ``out_spec`` is a ``TableType``; otherwise the
      stream is handed to ``np.load``.

    Args:
        uri: Resolved URI. ``local_path`` is preferred and ``raw`` is the
            fallback.
        out_spec: Expected output type. For table formats its column names
            select the columns to read; it is also checked against the
            detected format.
        ctx: Kernel context (not used by this provider).

    Returns:
        The result of ``decode_value`` for MPLang files, of
        ``table_utils.read_table`` for table formats, or of ``np.load``
        otherwise.

    Raises:
        ValueError: If the MPLang header is truncated or carries an
            unsupported version, or if ``out_spec`` does not match the
            detected format.
    """
    path = pathlib.Path(uri.local_path or uri.raw)
    # Must stay in sync with the header written by ``write``.
    mplang_header_fmt = ">4sB"

    with path.open("rb") as f:
        # Peek at the longest magic we recognise, then rewind so every reader
        # sees the file from its first byte. Derived from the constants so a
        # new, longer magic cannot be silently truncated.
        probe_len = max(
            len(MAGIC_MPLANG), len(MAGIC_PARQUET), len(MAGIC_ORC), len(MAGIC_NUMPY)
        )
        magic = f.read(probe_len)
        f.seek(0)

        if magic.startswith(MAGIC_MPLANG):
            header_len = struct.calcsize(mplang_header_fmt)
            header = f.read(header_len)
            if len(header) != header_len:
                # Without this check, a file holding only the magic would
                # surface as an opaque struct.error.
                raise ValueError(
                    f"truncated mplang header: expected {header_len} bytes, "
                    f"got {len(header)}"
                )
            _, version = struct.unpack(mplang_header_fmt, header)
            if version != VERSION:
                raise ValueError(f"unsupported mplang version {version}")
            return decode_value(f.read())

        if magic.startswith(MAGIC_PARQUET):
            if not isinstance(out_spec, TableType):
                raise ValueError(
                    f"PARQUET files require TableType output spec, got {type(out_spec).__name__}"
                )
            return table_utils.read_table(
                f, format="parquet", columns=list(out_spec.column_names())
            )

        if magic.startswith(MAGIC_ORC):
            if not isinstance(out_spec, TableType):
                raise ValueError(
                    f"ORC files require TableType output spec, got {type(out_spec).__name__}"
                )
            return table_utils.read_table(
                f, format="orc", columns=list(out_spec.column_names())
            )

        if magic.startswith(MAGIC_NUMPY):
            if not isinstance(out_spec, TensorType):
                raise ValueError(
                    f"NumPy files require TensorType output spec, got {type(out_spec).__name__}"
                )
            return np.load(f)

        # No known magic: decide by the requested output type, reading from
        # the already-open (rewound) handle.
        if isinstance(out_spec, TableType):
            return table_utils.read_table(
                f, format="csv", columns=list(out_spec.column_names())
            )
        return np.load(f)
|
|
158
209
|
|
|
159
210
|
def write(self, uri: ResolvedURI, value: Any, *, ctx: KernelContext) -> None:
|
|
160
211
|
import os
|
|
@@ -163,14 +214,24 @@ class FileProvider(DataProvider):
|
|
|
163
214
|
dir_name = os.path.dirname(path)
|
|
164
215
|
if dir_name:
|
|
165
216
|
os.makedirs(dir_name, exist_ok=True)
|
|
166
|
-
|
|
167
|
-
if
|
|
168
|
-
|
|
217
|
+
|
|
218
|
+
if not isinstance(value, Value):
|
|
219
|
+
value = (
|
|
220
|
+
TableValue(value)
|
|
221
|
+
if isinstance(value, TableLike)
|
|
222
|
+
else TensorValue(value)
|
|
223
|
+
)
|
|
224
|
+
|
|
225
|
+
if isinstance(value, TableValue):
|
|
226
|
+
table_utils.write_table(value.to_arrow(), path, format="parquet")
|
|
227
|
+
elif isinstance(value, TensorValue):
|
|
228
|
+
with open(path, "wb") as f:
|
|
229
|
+
np.save(f, value.to_numpy())
|
|
230
|
+
else:
|
|
231
|
+
payload = encode_value(value)
|
|
169
232
|
with open(path, "wb") as f:
|
|
170
|
-
f.write(
|
|
171
|
-
|
|
172
|
-
# Tensor-like via numpy
|
|
173
|
-
np.save(path, np.asarray(value))
|
|
233
|
+
f.write(struct.pack(">4sB", MAGIC_MPLANG, VERSION))
|
|
234
|
+
f.write(payload)
|
|
174
235
|
|
|
175
236
|
|
|
176
237
|
class MemProvider(DataProvider):
|
|
@@ -15,8 +15,8 @@
|
|
|
15
15
|
"""
|
|
16
16
|
HTTP-based driver implementation for distributed execution.
|
|
17
17
|
|
|
18
|
-
This module provides an HTTP-based
|
|
19
|
-
|
|
18
|
+
This module provides an HTTP-based driver, using REST APIs
|
|
19
|
+
for distributed multi-party computation coordination.
|
|
20
20
|
"""
|
|
21
21
|
|
|
22
22
|
from __future__ import annotations
|
|
@@ -27,14 +27,20 @@ import uuid
|
|
|
27
27
|
from collections.abc import Sequence
|
|
28
28
|
from typing import Any
|
|
29
29
|
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
from mplang.core
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
30
|
+
import numpy as np
|
|
31
|
+
|
|
32
|
+
from mplang.v1.core import (
|
|
33
|
+
ClusterSpec,
|
|
34
|
+
InterpContext,
|
|
35
|
+
InterpVar,
|
|
36
|
+
IrWriter,
|
|
37
|
+
Mask,
|
|
38
|
+
MPObject,
|
|
39
|
+
MPType,
|
|
40
|
+
)
|
|
41
|
+
from mplang.v1.core.expr.ast import Expr
|
|
42
|
+
from mplang.v1.kernels.value import TableValue, TensorValue
|
|
43
|
+
from mplang.v1.runtime.client import HttpExecutorClient
|
|
38
44
|
|
|
39
45
|
|
|
40
46
|
def new_uuid() -> str:
|
|
@@ -195,7 +201,7 @@ class Driver(InterpContext):
|
|
|
195
201
|
|
|
196
202
|
var_name_mapping = dict(zip(var_names, party_symbol_names, strict=True))
|
|
197
203
|
|
|
198
|
-
writer =
|
|
204
|
+
writer = IrWriter(var_name_mapping)
|
|
199
205
|
program_proto = writer.dumps(expr)
|
|
200
206
|
|
|
201
207
|
output_symbols = [self.new_name() for _ in range(expr.num_outputs)]
|
|
@@ -257,7 +263,19 @@ class Driver(InterpContext):
|
|
|
257
263
|
try:
|
|
258
264
|
# The results will be in the same order as the clients (ranks)
|
|
259
265
|
results = await asyncio.gather(*tasks)
|
|
260
|
-
|
|
266
|
+
converted: list[Any] = []
|
|
267
|
+
for value in results:
|
|
268
|
+
if isinstance(value, TensorValue):
|
|
269
|
+
arr = value.to_numpy()
|
|
270
|
+
if isinstance(arr, np.ndarray) and arr.size == 1:
|
|
271
|
+
converted.append(arr.item())
|
|
272
|
+
else:
|
|
273
|
+
converted.append(arr)
|
|
274
|
+
elif isinstance(value, TableValue):
|
|
275
|
+
converted.append(value.to_pandas())
|
|
276
|
+
else:
|
|
277
|
+
converted.append(value)
|
|
278
|
+
return converted
|
|
261
279
|
except RuntimeError as e:
|
|
262
280
|
raise RuntimeError(
|
|
263
281
|
f"Failed to fetch symbol from one or more parties: {e}"
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from typing import TYPE_CHECKING
|
|
19
|
+
|
|
20
|
+
import spu.libspu as libspu
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from mplang.v1.core.comm import CommunicatorBase
|
|
24
|
+
from mplang.v1.core.mask import Mask
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class LinkCommunicator:
    """Minimal wrapper around a libspu link context.

    Three modes are supported. The first matching rule wins:

    1. Channels (new): ``comm`` is given. SPU traffic is carried over the
       existing MPLang communicator through ``BaseChannel`` bridges.
    2. Mem (legacy, for testing): ``mem_link`` is True. Uses in-memory links.
    3. BRPC (legacy, production): otherwise. Uses dedicated BRPC addresses.
    """

    # Receive timeout applied to every link descriptor: 100 seconds.
    _RECV_TIMEOUT_MS = 100 * 1000
    # Max HTTP payload for address-based (Mem/BRPC) descriptors: 32 MiB.
    _HTTP_MAX_PAYLOAD_SIZE = 32 * 1024 * 1024

    def __init__(
        self,
        rank: int,
        addrs: list[str] | None = None,
        *,
        mem_link: bool = False,
        comm: CommunicatorBase | None = None,
        spu_mask: Mask | None = None,
    ):
        """Initialize the link communicator for SPU.

        Args:
            rank: Global rank of this party.
            addrs: Addresses of all SPU parties, registered as ``P0``, ``P1``,
                ... in list order. Required for Mem/BRPC modes and ignored in
                Channels mode.
            mem_link: If True (and ``comm`` is None), use in-memory links.
            comm: MPLang communicator to reuse; selects Channels mode.
            spu_mask: Mask of SPU parties; required in Channels mode.

        Raises:
            ValueError: If the arguments required by the selected mode are
                missing or inconsistent.
        """
        self._rank = rank

        if comm is not None:
            self._init_channels_mode(rank, comm, spu_mask)
        elif mem_link:
            self._init_mem_mode(rank, addrs)
        else:
            self._init_brpc_mode(rank, addrs)

    @classmethod
    def _new_desc(cls) -> libspu.link.Desc:
        """Create a link descriptor with the common receive timeout."""
        desc = libspu.link.Desc()  # type: ignore
        desc.recv_timeout_ms = cls._RECV_TIMEOUT_MS
        return desc

    @classmethod
    def _address_desc(cls, addrs: list[str]) -> libspu.link.Desc:
        """Create a Mem/BRPC descriptor registering ``addrs`` as P0..Pn-1."""
        desc = cls._new_desc()
        desc.http_max_payload_size = cls._HTTP_MAX_PAYLOAD_SIZE
        for party_idx, addr in enumerate(addrs):
            desc.add_party(f"P{party_idx}", addr)
        return desc

    def _init_channels_mode(
        self, rank: int, comm: CommunicatorBase, spu_mask: Mask | None
    ) -> None:
        """Initialize Channels mode (reuse the MPLang communicator).

        Args:
            rank: Global rank of this party.
            comm: MPLang communicator to reuse.
            spu_mask: Mask of SPU parties.

        Raises:
            ValueError: If ``spu_mask`` is None or does not contain ``rank``.
        """
        if spu_mask is None:
            raise ValueError("spu_mask required when using comm")
        if rank not in spu_mask:
            raise ValueError(f"rank {rank} not in spu_mask {spu_mask}")

        # Lazy import to avoid circular dependency
        from mplang.v1.runtime.channel import BaseChannel

        rel_rank = spu_mask.global_to_relative_rank(rank)

        # libspu expects one channel per SPU party, in mask order. The slot
        # for this party is None; it won't be accessed by SPU.
        channels = [
            None if peer_rank == rank else BaseChannel(comm, rank, peer_rank)
            for peer_rank in spu_mask
        ]

        desc = self._new_desc()
        # Party entries are required for libspu's world_size inference. The
        # addresses are placeholders because traffic flows through `channels`.
        for party_idx, peer_rank in enumerate(spu_mask):
            desc.add_party(f"P{party_idx}", f"dummy_{peer_rank}")

        self.lctx = libspu.link.create_with_channels(desc, rel_rank, channels)
        self._world_size = spu_mask.num_parties()

        logging.info(
            f"LinkCommunicator initialized with BaseChannel: "
            f"rank={rank}, rel_rank={rel_rank}, spu_mask={spu_mask}, "
            f"world_size={self._world_size}"
        )

    def _init_mem_mode(self, rank: int, addrs: list[str] | None) -> None:
        """Initialize Mem mode (in-memory links for testing).

        Args:
            rank: Rank of this party, passed to ``libspu.link.create_mem``.
            addrs: Addresses of all SPU parties.

        Raises:
            ValueError: If ``addrs`` is None or empty.
        """
        # An empty list would otherwise yield a zero-party descriptor and an
        # obscure failure inside libspu.
        if not addrs:
            raise ValueError("addrs required for Mem mode")

        self._world_size = len(addrs)
        self.lctx = libspu.link.create_mem(self._address_desc(addrs), rank)
        logging.info(
            f"LinkCommunicator initialized with Mem: "
            f"rank={rank}, world_size={self._world_size}, addrs={addrs}"
        )

    def _init_brpc_mode(self, rank: int, addrs: list[str] | None) -> None:
        """Initialize BRPC mode (production mode with separate BRPC ports).

        Args:
            rank: Rank of this party, passed to ``libspu.link.create_brpc``.
            addrs: Addresses of all SPU parties.

        Raises:
            ValueError: If ``addrs`` is None or empty.
        """
        # An empty list would otherwise yield a zero-party descriptor and an
        # obscure failure inside libspu.
        if not addrs:
            raise ValueError("addrs required for BRPC mode")

        self._world_size = len(addrs)
        self.lctx = libspu.link.create_brpc(self._address_desc(addrs), rank)
        logging.info(
            f"LinkCommunicator initialized with BRPC: "
            f"rank={rank}, world_size={self._world_size}, addrs={addrs}"
        )

    @property
    def rank(self) -> int:
        """Get rank from underlying link context.

        NOTE: in Channels mode the context is created with the SPU-relative
        rank (``spu_mask.global_to_relative_rank``), so this presumably
        differs from the global ``rank`` given to the constructor. Verify
        before relying on it.
        """
        return self.lctx.rank  # type: ignore[no-any-return]

    @property
    def world_size(self) -> int:
        """Get world size from underlying link context."""
        return self.lctx.world_size  # type: ignore[no-any-return]

    def get_lctx(self) -> libspu.link.Context:
        """Get the underlying libspu link context.

        This is the primary interface: SPU runtime uses this context directly,
        and libspu handles all communication and serialization internally.

        Returns:
            The underlying libspu.link.Context instance.
        """
        return self.lctx
|