mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
# Copyright 2026 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""SPU IChannel implementation for MPLang v2.
|
|
16
|
+
|
|
17
|
+
Bridges v2's simp_worker communicators (ThreadCommunicator/HttpCommunicator)
|
|
18
|
+
to libspu's IChannel interface, enabling SPU to reuse existing communication
|
|
19
|
+
infrastructure instead of creating separate BRPC connections.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
import logging
|
|
25
|
+
from typing import Protocol
|
|
26
|
+
|
|
27
|
+
import spu.libspu as libspu
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class CommunicatorProtocol(Protocol):
    """Structural (duck-typed) interface expected from v2 communicators.

    ThreadCommunicator and HttpCommunicator both provide these two methods;
    they do not need to inherit from this class to satisfy it.
    """

    def send(self, to: int, key: str, data: bytes) -> None:
        """Deliver ``data`` to rank ``to`` under message key ``key``."""

    def recv(self, frm: int, key: str) -> bytes:
        """Wait for the message stored under ``key`` from rank ``frm`` and return it."""
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class BaseChannel(libspu.link.IChannel):
    """Bridge a v2 communicator to SPU's ``IChannel`` interface.

    Any communicator exposing ``send(to, key, data)`` / ``recv(frm, key)``
    (see ``CommunicatorProtocol``; e.g. ThreadCommunicator, HttpCommunicator)
    can be wrapped via duck typing. Each BaseChannel represents a channel to
    ONE peer rank.

    Communication Protocol:
        - SPU calls ``Send(tag, data)`` -> ``comm.send(peer, "spu:tag", data)``
        - SPU calls ``Recv(tag)`` -> bytes <- ``comm.recv(peer, "spu:tag")``

    Tag Namespace:
        Every tag is prefixed with ``tag_prefix`` (default ``"spu"``) so SPU
        traffic cannot collide with other traffic on the same communicator.
    """

    # Named logger instead of the module-level ``logging.debug``/``warning``
    # helpers: those call ``logging.basicConfig()`` on an unconfigured root
    # logger, which a library must not do (it silently turns the application's
    # own later ``basicConfig()`` call into a no-op).
    _logger = logging.getLogger(__name__)

    # Fixed tag and payload shared by the TestSend/TestRecv handshake.
    _TEST_TAG = "__test__"
    _TEST_PAYLOAD = b"\x00"

    def __init__(
        self,
        comm: CommunicatorProtocol,
        local_rank: int,
        peer_rank: int,
        tag_prefix: str = "spu",
    ):
        """Initialize a channel to a specific peer.

        Args:
            comm: v2 communicator (any object implementing send/recv).
            local_rank: Global rank of this party.
            peer_rank: Global rank of the peer party.
            tag_prefix: Prefix prepended to every tag (default: "spu").
        """
        # The pybind11 base class must be initialized before SPU uses us.
        super().__init__()
        self._comm = comm
        self._local_rank = local_rank
        self._peer_rank = peer_rank
        self._tag_prefix = tag_prefix

        self._logger.debug(
            "BaseChannel initialized: local_rank=%s, peer_rank=%s, tag_prefix=%s",
            local_rank,
            peer_rank,
            tag_prefix,
        )

    def _make_key(self, tag: str) -> str:
        """Build the communicator key for an SPU tag.

        Args:
            tag: SPU-provided tag (e.g., "send_0").

        Returns:
            Prefixed key (e.g., "spu:send_0").
        """
        return f"{self._tag_prefix}:{tag}"

    def Send(self, tag: str, data: bytes) -> None:
        """Send bytes to the peer.

        Args:
            tag: Message tag for matching send/recv pairs.
            data: Raw bytes to send.
        """
        key = self._make_key(tag)
        # Lazy %-args: nothing is formatted unless DEBUG is enabled.
        self._logger.debug(
            "BaseChannel.Send: %s -> %s, tag=%s, key=%s, size=%d",
            self._local_rank,
            self._peer_rank,
            tag,
            key,
            len(data),
        )

        # The raw bytes are handed to the communicator unchanged.
        self._comm.send(self._peer_rank, key, data)

    def Recv(self, tag: str) -> bytes:
        """Receive bytes from the peer (blocking).

        Args:
            tag: Message tag for matching send/recv pairs.

        Returns:
            Raw bytes received. A ``bytearray`` or ``memoryview`` payload from
            the communicator is copied into an immutable ``bytes`` object.

        Raises:
            TypeError: If the received payload is not bytes-like.
        """
        key = self._make_key(tag)
        self._logger.debug(
            "BaseChannel.Recv: %s <- %s, tag=%s, key=%s",
            self._local_rank,
            self._peer_rank,
            tag,
            key,
        )

        data = self._comm.recv(self._peer_rank, key)

        # Accept other bytes-like payloads but always return real ``bytes``.
        if isinstance(data, (bytearray, memoryview)):
            data = bytes(data)
        if not isinstance(data, bytes):
            raise TypeError(
                f"Expected bytes from communicator, got {type(data).__name__}. "
                "Communicator must support raw bytes transmission for SPU channels."
            )

        self._logger.debug(
            "BaseChannel.Recv complete: %s <- %s, tag=%s, size=%d",
            self._local_rank,
            self._peer_rank,
            tag,
            len(data),
        )
        return data

    def SendAsync(self, tag: str, data: bytes) -> None:
        """Async send.

        Delegates directly to :meth:`Send`. Whether the call returns before
        delivery depends on the underlying communicator's ``send``.

        Args:
            tag: Message tag.
            data: Raw bytes to send.
        """
        self.Send(tag, data)

    def SendAsyncThrottled(self, tag: str, data: bytes) -> None:
        """Throttled async send.

        No throttling is implemented; this delegates to :meth:`SendAsync`.

        Args:
            tag: Message tag.
            data: Raw bytes to send.
        """
        self.SendAsync(tag, data)

    def TestSend(self, timeout: int) -> None:
        """Send the 1-byte handshake payload to the peer.

        Always uses the fixed tag ``"__test__"`` so it pairs with the peer's
        :meth:`TestRecv`.

        Args:
            timeout: Timeout in milliseconds. Accepted for interface
                compatibility; this implementation does not use it.
        """
        self.Send(self._TEST_TAG, self._TEST_PAYLOAD)

    def TestRecv(self) -> None:
        """Wait for the 1-byte handshake payload from the peer.

        This method applies no timeout of its own; blocking behavior is that
        of the communicator's ``recv``. An unexpected payload is logged as a
        warning rather than raised. ``TypeError`` from :meth:`Recv`
        propagates.
        """
        payload = self.Recv(self._TEST_TAG)
        if payload != self._TEST_PAYLOAD:
            self._logger.warning(
                "TestRecv: unexpected handshake from %s, expected b'\\x00', got %r",
                self._peer_rank,
                payload,
            )

    def WaitLinkTaskFinish(self) -> None:
        """Wait for all pending async tasks.

        No-op: :meth:`SendAsync` delegates to :meth:`Send`, so this class
        tracks no pending tasks of its own.
        """

    def Abort(self) -> None:
        """Abort communication.

        Only logs a warning; no resources are released here.
        """
        self._logger.warning(
            "BaseChannel.Abort: %s <-> %s", self._local_rank, self._peer_rank
        )

    def SetThrottleWindowSize(self, size: int) -> None:
        """Set throttle window size (no-op).

        Args:
            size: Window size (ignored).
        """

    def SetChunkParallelSendSize(self, size: int) -> None:
        """Set chunk parallel send size (no-op).

        Args:
            size: Chunk size (ignored).
        """
|