mplang-nightly 0.1.dev268__py3-none-any.whl → 0.1.dev270__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +46 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +19 -13
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/libs/mpc/psi/rr22.py +303 -0
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang/v2/libs/mpc/psi/rr22.py +0 -344
- mplang_nightly-0.1.dev268.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,303 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
import pathlib
|
|
18
|
-
import struct
|
|
19
|
-
from dataclasses import dataclass
|
|
20
|
-
from typing import Any
|
|
21
|
-
from urllib.parse import ParseResult, urlparse
|
|
22
|
-
|
|
23
|
-
import numpy as np
|
|
24
|
-
|
|
25
|
-
from mplang.v1.core import TableLike, TableType, TensorType
|
|
26
|
-
from mplang.v1.kernels.base import KernelContext
|
|
27
|
-
from mplang.v1.kernels.value import (
|
|
28
|
-
TableValue,
|
|
29
|
-
TensorValue,
|
|
30
|
-
Value,
|
|
31
|
-
decode_value,
|
|
32
|
-
encode_value,
|
|
33
|
-
)
|
|
34
|
-
from mplang.v1.utils import table_utils
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
@dataclass(frozen=True)
class ResolvedURI:
    """Result of resolving a resource path into a normalized form.

    Attributes:
        scheme: The URI scheme (e.g., 'file', 's3', 'mem', 'var', 'secret').
        raw: The original path string as provided by the user.
        parsed: The ParseResult if a scheme was present; otherwise None.
        local_path: For file paths: concrete filesystem path (absolute or as given).
    """

    scheme: str
    raw: str
    parsed: ParseResult | None
    local_path: str | None


def resolve_uri(path: str) -> ResolvedURI:
    """Resolve a user-provided resource location into a normalized URI form.

    This helper accepts plain filesystem paths and RFC 3986 style URIs. A path
    is treated as ``file`` when ``urlparse(path).scheme`` is empty (or is a
    single letter, i.e. a Windows drive such as ``C:\\data``). Detection does
    not depend on the literal substring ``"://"``, so forms like ``mem:foo``
    (no slashes) are still recognized as a URI.

    Captured fields
    - ``scheme``: Lower-cased scheme (``file`` when absent)
    - ``raw``: Original input
    - ``parsed``: ``ParseResult`` when a scheme was provided, else ``None``
    - ``local_path``: Filesystem path for ``file`` scheme, else ``None``

    For ``file`` URIs an empty or ``localhost`` authority maps to a local
    path. Any other authority is kept as a UNC-style ``//host/path`` rather
    than being discarded.

    Supported (pluggable) schemes out-of-the-box:
    * ``file`` (default)
    * ``mem``
    * ``s3`` (stub)
    * ``secret`` (stub)
    * ``symbols`` (registered server-side)

    Examples
    >>> resolve_uri("data/train.npy").scheme
    'file'
    >>> resolve_uri("mem:dataset1").scheme
    'mem'
    >>> resolve_uri("mem://dataset1").scheme  # both forms acceptable
    'mem'
    >>> resolve_uri("symbols://shared_model").scheme
    'symbols'
    >>> resolve_uri("file:///tmp/x.npy").local_path
    '/tmp/x.npy'
    >>> resolve_uri("file://localhost/tmp/x.npy").local_path
    '/tmp/x.npy'
    >>> resolve_uri("file://server/share/x.npy").local_path
    '//server/share/x.npy'
    """
    pr = urlparse(path)
    # urlparse reads a Windows drive letter ("C:\\x", "C:/x") as a one-letter
    # scheme. No real URI scheme is a single character, so treat it as a path.
    if not pr.scheme or len(pr.scheme) == 1:
        return ResolvedURI("file", path, None, path)

    scheme = pr.scheme.lower()
    local_path: str | None = None
    if scheme == "file":
        netloc = pr.netloc
        if netloc and netloc.lower() != "localhost":
            # RFC 8089: a non-local authority names another host. Keep it as a
            # UNC-style path instead of silently dropping it.
            local_path = f"//{netloc}{pr.path}"
        else:
            # "file:///p", "file://localhost/p" and "file:p" are all local.
            local_path = pr.path
    return ResolvedURI(scheme, path, pr, local_path)
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
class DataProvider:
    """Base class for resource providers selected by URI scheme.

    A provider reads and writes resources addressed by a resolved URI. The
    type spec passed to ``read`` may be ignored, but implementations SHOULD
    validate against it when feasible. The base implementations of both
    methods raise ``NotImplementedError``.
    """

    def read(
        self, uri: ResolvedURI, out_spec: TensorType | TableType, *, ctx: KernelContext
    ) -> Any:
        """Load the resource at ``uri``; subclasses must override."""
        raise NotImplementedError()

    def write(self, uri: ResolvedURI, value: Any, *, ctx: KernelContext) -> None:
        """Store ``value`` at ``uri``; subclasses must override."""
        raise NotImplementedError()
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
# Scheme (lower-cased) -> provider instance.
_REGISTRY: dict[str, DataProvider] = {}


def register_provider(
    scheme: str, provider: DataProvider, *, replace: bool = False, quiet: bool = False
) -> None:
    """Register a provider implementation.

    Args:
        scheme: URI scheme handled (case-insensitive; stored lower-cased).
        provider: Implementation.
        replace: If False and the scheme is already registered, raise ValueError.
        quiet: If True, do not log when an existing provider is replaced.

    Raises:
        ValueError: If ``scheme`` is already registered and ``replace`` is False.
    """
    import logging

    key = scheme.lower()
    already_registered = key in _REGISTRY
    if already_registered and not replace:
        raise ValueError(f"provider already registered for scheme: {scheme}")
    if already_registered and not quiet:
        # Log via the module logger rather than the root logger so applications
        # can filter it. Lazy %-args skip formatting when INFO is disabled.
        logging.getLogger(__name__).info(
            "Replacing existing provider for scheme '%s'", scheme
        )
    _REGISTRY[key] = provider


def get_provider(scheme: str) -> DataProvider | None:
    """Return the provider registered for ``scheme`` (case-insensitive), or None."""
    return _REGISTRY.get(scheme.lower())
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
# ---------------- Default Providers ----------------
# Leading byte signatures that FileProvider.read uses to sniff the file format.
MAGIC_MPLANG = b"MPLG"
MAGIC_PARQUET = b"PAR1"
MAGIC_ORC = b"ORC"
MAGIC_NUMPY = b"\x93NUMPY"
# Version byte written after MAGIC_MPLANG; read() rejects any other version.
VERSION = 0x01


class FileProvider(DataProvider):
    """Local filesystem provider.

    Reading sniffs the leading bytes of the file to choose a format:

    * ``MPLG`` header -> mplang-encoded ``Value`` (via ``decode_value``)
    * Parquet / ORC   -> table (requires a ``TableType`` spec)
    * ``.npy``        -> NumPy array (requires a ``TensorType`` spec)
    * anything else   -> CSV for a ``TableType`` spec, otherwise ``np.load``

    Writing stores tables as Parquet, tensors as ``.npy`` and any other
    ``Value`` as an mplang-encoded blob. All three forms are recognized by
    ``read``'s format sniffing.
    """

    # Header preceding an mplang-encoded payload: 4-byte magic + 1 version byte.
    _MPLANG_HEADER = struct.Struct(">4sB")
    # Enough leading bytes to recognize every supported signature.
    _MAGIC_LEN_MAX = max(
        len(m) for m in (MAGIC_MPLANG, MAGIC_PARQUET, MAGIC_ORC, MAGIC_NUMPY)
    )

    def read(
        self, uri: ResolvedURI, out_spec: TensorType | TableType, *, ctx: KernelContext
    ) -> Any:
        """Load the file at ``uri`` according to its detected format.

        Raises:
            ValueError: If the detected format does not match ``out_spec``, the
                mplang header is truncated, or its version is unsupported.
        """
        import io

        path = pathlib.Path(uri.local_path or uri.raw)
        with path.open("rb") as f:
            magic = f.read(self._MAGIC_LEN_MAX)
            f.seek(0)

            if magic.startswith(MAGIC_MPLANG):
                header_size = self._MPLANG_HEADER.size
                header = f.read(header_size)
                if len(header) < header_size:
                    raise ValueError(f"truncated mplang header in {path}")
                _, version = self._MPLANG_HEADER.unpack(header)
                if version != VERSION:
                    raise ValueError(f"unsupported mplang version {version}")
                return decode_value(f.read())

            if magic.startswith(MAGIC_PARQUET):
                if not isinstance(out_spec, TableType):
                    raise ValueError(
                        f"PARQUET files require TableType output spec, got {type(out_spec).__name__}"
                    )
                return table_utils.read_table(
                    f, format="parquet", columns=list(out_spec.column_names())
                )

            if magic.startswith(MAGIC_ORC):
                if not isinstance(out_spec, TableType):
                    raise ValueError(
                        f"ORC files require TableType output spec, got {type(out_spec).__name__}"
                    )
                return table_utils.read_table(
                    f, format="orc", columns=list(out_spec.column_names())
                )

            if magic.startswith(MAGIC_NUMPY):
                if not isinstance(out_spec, TensorType):
                    raise ValueError(
                        f"NumPy files require TensorType output spec, got {type(out_spec).__name__}"
                    )
                return np.load(f)

            # Fallback: CSV for tables, NumPy for everything else.
            if isinstance(out_spec, TableType):
                return table_utils.read_table(
                    f, format="csv", columns=list(out_spec.column_names())
                )
            # The file is not a .npy, so with np.load's default allow_pickle=False
            # only an .npz archive can load here. NpzFile reads its members lazily
            # from the handle it was given, and ``f`` closes when this ``with``
            # block exits. Load from an in-memory copy so the result stays usable.
            return np.load(io.BytesIO(f.read()))

    def write(self, uri: ResolvedURI, value: Any, *, ctx: KernelContext) -> None:
        """Write ``value`` to the local path named by ``uri``.

        Missing parent directories are created. Inputs that are not already a
        ``Value`` are wrapped first: table-like objects as ``TableValue``,
        everything else as ``TensorValue``.
        """
        import os

        path = uri.local_path or uri.raw
        dir_name = os.path.dirname(path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        if not isinstance(value, Value):
            value = (
                TableValue(value)
                if isinstance(value, TableLike)
                else TensorValue(value)
            )

        if isinstance(value, TableValue):
            table_utils.write_table(value.to_arrow(), path, format="parquet")
        elif isinstance(value, TensorValue):
            with open(path, "wb") as f:
                np.save(f, value.to_numpy())
        else:
            # Any other Value kind: mplang header followed by the encoded payload.
            payload = encode_value(value)
            with open(path, "wb") as f:
                f.write(self._MPLANG_HEADER.pack(MAGIC_MPLANG, VERSION))
                f.write(payload)
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
class MemProvider(DataProvider):
    """In-memory key/value provider scoped to one runtime (per rank, per session).

    Entries live in a dict kept in the runtime state under ``STATE_KEY`` and
    are keyed by the exact URI string supplied (``uri.raw``).
    """

    STATE_KEY = "resource.providers.mem"

    @staticmethod
    def _store(ctx: KernelContext) -> dict[str, Any]:
        """Return the runtime's backing dict, creating it via ``ensure_state``."""
        state = ctx.runtime.ensure_state(MemProvider.STATE_KEY, dict)
        if isinstance(state, dict):
            return state  # type: ignore[return-value]
        # pragma: no cover - defensive: another component reused the key.
        raise TypeError(
            f"runtime state key '{MemProvider.STATE_KEY}' expected dict, got {type(state).__name__}"
        )

    def read(
        self, uri: ResolvedURI, out_spec: TensorType | TableType, *, ctx: KernelContext
    ) -> Any:
        """Return the value stored under ``uri.raw``; raise FileNotFoundError if absent."""
        entries = self._store(ctx)
        name = uri.raw
        if name in entries:
            return entries[name]
        raise FileNotFoundError(f"mem resource not found: {name}")

    def write(self, uri: ResolvedURI, value: Any, *, ctx: KernelContext) -> None:
        """Store ``value`` under ``uri.raw``, overwriting any previous entry."""
        self._store(ctx)[uri.raw] = value
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
class S3Provider(DataProvider):
    """Stub provider for the ``s3`` scheme.

    Both operations raise ``NotImplementedError`` until a real S3
    implementation is installed as a plugin through ``register_provider``.
    """

    def read(
        self, uri: ResolvedURI, out_spec: TensorType | TableType, *, ctx: KernelContext
    ) -> Any:
        """Always raise: no S3 backend is bundled."""
        message = "S3 provider not installed. Provide an external plugin via register_provider('s3', ...) ."
        raise NotImplementedError(message)

    def write(self, uri: ResolvedURI, value: Any, *, ctx: KernelContext) -> None:
        """Always raise: no S3 backend is bundled."""
        message = "S3 provider not installed. Provide an external plugin via register_provider('s3', ...) ."
        raise NotImplementedError(message)
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
class SecretProvider(DataProvider):
    """Stub provider for the ``secret`` scheme.

    Both operations raise ``NotImplementedError``. A KMS / secret-manager
    integration is expected to be installed as a plugin through
    ``register_provider``.
    """

    def read(
        self, uri: ResolvedURI, out_spec: TensorType | TableType, *, ctx: KernelContext
    ) -> Any:
        """Always raise: no secret backend is bundled."""
        message = "secret provider not installed. Provide an external plugin via register_provider('secret', ...) ."
        raise NotImplementedError(message)

    def write(self, uri: ResolvedURI, value: Any, *, ctx: KernelContext) -> None:
        """Always raise: no secret backend is bundled."""
        message = "secret provider not installed. Provide an external plugin via register_provider('secret', ...) ."
        raise NotImplementedError(message)
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
# Register default providers. The "s3" and "secret" entries are stubs that
# signal a missing provider explicitly. Because they are already registered,
# a plugin overriding them must call register_provider(..., replace=True).
for _scheme, _provider in (
    ("file", FileProvider()),
    ("mem", MemProvider()),
    ("s3", S3Provider()),
    ("secret", SecretProvider()),
):
    register_provider(_scheme, _provider)
# Do not leak the loop variables into the module namespace.
del _scheme, _provider
|
mplang/v1/runtime/driver.py
DELETED
|
@@ -1,324 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""
|
|
16
|
-
HTTP-based driver implementation for distributed execution.
|
|
17
|
-
|
|
18
|
-
This module provides an HTTP-based driver, using REST APIs
|
|
19
|
-
for distributed multi-party computation coordination.
|
|
20
|
-
"""
|
|
21
|
-
|
|
22
|
-
from __future__ import annotations
|
|
23
|
-
|
|
24
|
-
import asyncio
|
|
25
|
-
import base64
|
|
26
|
-
import uuid
|
|
27
|
-
from collections.abc import Sequence
|
|
28
|
-
from typing import Any
|
|
29
|
-
|
|
30
|
-
import numpy as np
|
|
31
|
-
|
|
32
|
-
from mplang.v1.core import (
|
|
33
|
-
ClusterSpec,
|
|
34
|
-
InterpContext,
|
|
35
|
-
InterpVar,
|
|
36
|
-
IrWriter,
|
|
37
|
-
Mask,
|
|
38
|
-
MPObject,
|
|
39
|
-
MPType,
|
|
40
|
-
)
|
|
41
|
-
from mplang.v1.core.expr.ast import Expr
|
|
42
|
-
from mplang.v1.kernels.value import TableValue, TensorValue
|
|
43
|
-
from mplang.v1.runtime.client import HttpExecutorClient
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
def new_uuid() -> str:
    """Return a short random identifier.

    The result is the first 8 characters of the URL-safe Base64 encoding of a
    random (version 4) UUID, so it only contains ``[A-Za-z0-9_-]``.
    """
    text = base64.urlsafe_b64encode(uuid.uuid4().bytes).decode("utf-8")
    # 16 bytes encode to 22 significant characters followed by "==" padding.
    # The padding is stripped before truncating.
    return text.rstrip("=")[:8]
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
class DriverVar(InterpVar):
    """A variable that references a value held by the distributed HTTP executor nodes.

    Only the remote symbol name is kept locally. The value stays on the servers
    and is retrieved through the owning ``Driver`` (``Driver.fetch``).
    """

    def __init__(
        self,
        ctx: Driver,
        symbol_name: str,
        mptype: MPType,
    ) -> None:
        super().__init__(ctx, mptype)
        # Name under which every node stores this value in the session.
        self.symbol_name = symbol_name

    @property
    def mptype(self) -> MPType:
        """The type of this variable."""
        return self._mptype

    def __repr__(self) -> str:
        # Report the real class name (subclasses included) instead of a
        # hard-coded literal that does not match this class.
        return (
            f"{type(self).__name__}(symbol_name={self.symbol_name}, mptype={self.mptype})"
        )
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
class Driver(InterpContext):
    """Driver for distributed execution using HTTP-based services.

    ``evaluate``, ``fetch`` and ``ping`` each run their own event loop via
    ``asyncio.run``. Each opens short-lived HTTP clients that are closed before
    the call returns. A single session id is created lazily and reused for the
    lifetime of the driver.

    Args:
        cluster_spec: The cluster specification defining the distributed environment.
        trace_ranks: List of ranks to trace execution for debugging.
        timeout: HTTP request timeout in seconds.
    """

    def __init__(
        self,
        cluster_spec: ClusterSpec,
        *,
        trace_ranks: list[int] | None = None,
        timeout: int = 120,
    ) -> None:
        """Initialize a driver with the given cluster specification.

        Args:
            cluster_spec: The cluster specification defining the distributed environment.
            trace_ranks: List of ranks to trace execution for debugging.
            timeout: HTTP request timeout in seconds.

        Raises:
            ValueError: If the cluster spec does not define exactly one SPU device.
        """
        super().__init__(cluster_spec)
        self._trace_ranks = trace_ranks or []
        self.timeout = timeout

        self._session_id: str | None = None
        self._counter = 0

        # node_id -> endpoint, in cluster_spec.nodes order. A node's rank is its
        # position in this mapping (see _get_or_create_session).
        self.node_addrs = {
            node_id: node.endpoint for node_id, node in cluster_spec.nodes.items()
        }

        # Exactly one SPU device is required.
        spu_devices = cluster_spec.get_devices_by_kind("SPU")
        if not spu_devices:
            raise ValueError("No SPU device found in the cluster specification")
        if len(spu_devices) > 1:
            raise ValueError("Multiple SPU devices found in the cluster specification")
        spu_device = spu_devices[0]

        # SPU configuration kept as strings for readability.
        self.spu_protocol_str = spu_device.config["protocol"]
        self.spu_field_str = spu_device.config["field"]

        # Integer bitmask of the ranks that are SPU members.
        spu_mask = Mask.from_ranks([member.rank for member in spu_device.members])
        self.spu_mask_int = int(spu_mask)

    def _create_clients(self) -> dict[str, HttpExecutorClient]:
        """Create one HTTP client per node, keyed by node id in ``node_addrs`` order."""
        return {
            node_id: HttpExecutorClient(endpoint, self.timeout)
            for node_id, endpoint in self.node_addrs.items()
        }

    async def _close_clients(self, clients: dict[str, HttpExecutorClient]) -> None:
        """Close all provided HTTP clients (best effort).

        This runs from ``finally`` blocks. Failures from ``close`` are logged
        rather than raised, so they cannot replace an error that is already
        propagating and one failing close does not abandon the others.
        """
        import logging

        results = await asyncio.gather(
            *(client.close() for client in clients.values()), return_exceptions=True
        )
        for node_id, result in zip(clients, results, strict=True):
            if isinstance(result, BaseException):
                logging.getLogger(__name__).warning(
                    "Failed to close HTTP client for node %s: %s", node_id, result
                )

    def new_name(self, prefix: str = "var") -> str:
        """Generate a unique execution name (``{prefix}_{counter}``)."""
        name = f"{prefix}_{self._counter}"
        self._counter += 1
        return name

    async def _get_or_create_session(self) -> str:
        """Get the existing session or create a new one across all HTTP servers.

        Raises:
            RuntimeError: If any node fails to create the session or reports a
                different session id than requested.
        """
        if self._session_id is not None:
            return self._session_id

        new_session_id = new_uuid()
        clients = self._create_clients()
        try:
            # ``clients`` follows ``node_addrs`` order, so the enumeration index
            # is the node's rank.
            tasks = [
                client.create_session(
                    name=new_session_id,
                    rank=rank,
                    cluster_spec=self.cluster_spec.to_dict(),
                )
                for rank, client in enumerate(clients.values())
            ]
            try:
                results = await asyncio.gather(*tasks)
            except RuntimeError as e:
                raise RuntimeError(
                    f"Failed to create session on one or more parties: {e}"
                ) from e
        finally:
            await self._close_clients(clients)

        # Raise explicitly; an ``assert`` would be stripped under ``python -O``.
        mismatched = [sid for sid in results if sid != new_session_id]
        if mismatched:
            raise RuntimeError(
                "Failed to create session on one or more parties: "
                f"expected session id {new_session_id!r}, got {mismatched!r}"
            )
        self._session_id = new_session_id
        return new_session_id

    async def _evaluate(
        self, expr: Expr, bindings: dict[str, MPObject]
    ) -> Sequence[MPObject]:
        """Run ``expr`` on every node and return handles to its outputs.

        Bound variables are renamed to their remote symbol names. The program is
        executed on every node under one computation id, and a ``DriverVar`` is
        returned for each output.

        Raises:
            ValueError: If a binding belongs to another context or is not a DriverVar.
            RuntimeError: If execution fails on any node.
        """
        session_id = await self._get_or_create_session()

        var_name_mapping: dict[str, str] = {}
        party_symbol_names: list[str] = []
        for name, var in bindings.items():
            if var.ctx is not self:
                raise ValueError(f"Variable {name} not in this context, got {var.ctx}.")
            if not isinstance(var, DriverVar):
                raise ValueError(f"Expected DriverVar, got {type(var)}")
            var_name_mapping[name] = var.symbol_name
            party_symbol_names.append(var.symbol_name)

        # Serialize once; every node receives the same program bytes.
        program_bytes = IrWriter(var_name_mapping).dumps(expr).SerializeToString()
        output_symbols = [self.new_name() for _ in range(expr.num_outputs)]
        computation_id = new_uuid()

        clients = self._create_clients()
        try:
            tasks = [
                client.create_and_execute_computation(
                    session_id,
                    computation_id,
                    program_bytes,
                    party_symbol_names,
                    output_symbols,
                )
                for client in clients.values()
            ]
            try:
                await asyncio.gather(*tasks)
            except RuntimeError as e:
                raise RuntimeError(
                    f"Failed to create and execute computation on one or more parties: {e}"
                ) from e
        finally:
            await self._close_clients(clients)

        return [
            DriverVar(self, symbol_name, mptype)
            for symbol_name, mptype in zip(output_symbols, expr.mptypes, strict=True)
        ]

    def evaluate(self, expr: Expr, bindings: dict[str, MPObject]) -> Sequence[MPObject]:
        """Evaluate an expression using distributed HTTP execution."""
        return asyncio.run(self._evaluate(expr, bindings))

    @staticmethod
    def _to_host_value(value: Any) -> Any:
        """Convert one fetched runtime value into a NumPy/pandas/Python object."""
        if isinstance(value, TensorValue):
            arr = value.to_numpy()
            # Single-element tensors are unwrapped to Python scalars.
            if isinstance(arr, np.ndarray) and arr.size == 1:
                return arr.item()
            return arr
        if isinstance(value, TableValue):
            return value.to_pandas()
        return value

    async def _fetch(self, obj: MPObject) -> list[Any]:
        """Fetch ``obj``'s value from every node, one entry per node in rank order.

        Raises:
            ValueError: If ``obj`` is not a DriverVar.
            RuntimeError: If fetching fails on any node.
        """
        if not isinstance(obj, DriverVar):
            raise ValueError(f"Expected DriverVar, got {type(obj)}")

        session_id = await self._get_or_create_session()
        symbol_full_name = obj.symbol_name

        clients = self._create_clients()
        try:
            tasks = [
                client.get_symbol(session_id, symbol_full_name)
                for client in clients.values()
            ]
            try:
                # gather preserves order, so results follow node (rank) order.
                results = await asyncio.gather(*tasks)
                return [self._to_host_value(value) for value in results]
            except RuntimeError as e:
                raise RuntimeError(
                    f"Failed to fetch symbol from one or more parties: {e}"
                ) from e
        finally:
            await self._close_clients(clients)

    def fetch(self, obj: MPObject) -> list[Any]:
        """Fetch results from the distributed HTTP execution."""
        return asyncio.run(self._fetch(obj))

    async def _ping(self, node_id: str) -> bool:
        """Async implementation to ping a node.

        Args:
            node_id: The ID of the node to ping

        Returns:
            True if the node is healthy, False otherwise (any error while
            checking counts as unhealthy).

        Raises:
            ValueError: If ``node_id`` is not part of the cluster.
        """
        if node_id not in self.node_addrs:
            raise ValueError(f"Node {node_id} not found in party addresses")

        client = HttpExecutorClient(self.node_addrs[node_id], self.timeout)
        try:
            return await client.health_check()
        except Exception:
            # Deliberate best effort: any failure means "not healthy".
            return False
        finally:
            await client.close()

    def ping(self, node_id: str) -> bool:
        """Ping a node to check if it's healthy.

        Args:
            node_id: The ID of the node to ping

        Returns:
            True if the node is healthy, False otherwise
        """
        return asyncio.run(self._ping(node_id))
|
mplang/v1/runtime/exceptions.py
DELETED
|
@@ -1,27 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""Custom exception types for the HTTP backend."""
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
class HttpBackendError(Exception):
    """Common base class of every HTTP backend error.

    Catching this handles all of the backend-specific exceptions below.
    """


class ResourceNotFound(HttpBackendError):
    """A requested resource (session, computation, etc.) does not exist."""


class InvalidRequestError(HttpBackendError, ValueError):
    """A request was rejected as invalid, e.g. bad parameters or invalid state.

    Because it also derives from ``ValueError``, generic ``except ValueError``
    handlers catch it too.
    """
|