mplang-nightly 0.1.dev266__py3-none-any.whl → 0.1.dev267__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/v1/kernels/phe.py +8 -4
- mplang/v1/runtime/channel.py +230 -0
- mplang/v1/runtime/communicator.py +37 -13
- mplang/v1/runtime/link_comm.py +135 -17
- mplang/v1/runtime/server.py +10 -1
- mplang/v1/runtime/session.py +11 -38
- mplang/v1/runtime/simulation.py +30 -8
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/simp_worker/http.py +44 -13
- mplang/v2/backends/simp_worker/mem.py +9 -6
- mplang/v2/backends/spu_impl.py +18 -5
- mplang/v2/backends/spu_state.py +66 -3
- mplang/v2/libs/ml/sgb.py +20 -32
- {mplang_nightly-0.1.dev266.dist-info → mplang_nightly-0.1.dev267.dist-info}/METADATA +2 -2
- {mplang_nightly-0.1.dev266.dist-info → mplang_nightly-0.1.dev267.dist-info}/RECORD +18 -16
- {mplang_nightly-0.1.dev266.dist-info → mplang_nightly-0.1.dev267.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev266.dist-info → mplang_nightly-0.1.dev267.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev266.dist-info → mplang_nightly-0.1.dev267.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
# Copyright 2026 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""SPU IChannel implementation for MPLang v2.
|
|
16
|
+
|
|
17
|
+
Bridges v2's simp_worker communicators (ThreadCommunicator/HttpCommunicator)
|
|
18
|
+
to libspu's IChannel interface, enabling SPU to reuse existing communication
|
|
19
|
+
infrastructure instead of creating separate BRPC connections.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
import logging
|
|
25
|
+
from typing import Protocol
|
|
26
|
+
|
|
27
|
+
import spu.libspu as libspu
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class CommunicatorProtocol(Protocol):
    """Structural (duck-typed) interface expected from a v2 communicator.

    Any object exposing matching ``send`` / ``recv`` methods satisfies it;
    ThreadCommunicator and HttpCommunicator are the implementations this
    module is written against.
    """

    def send(self, to: int, key: str, data: bytes) -> None:
        """Deliver ``data`` to rank ``to`` under the message key ``key``."""
        ...

    def recv(self, frm: int, key: str) -> bytes:
        """Return the message stored under ``key`` from rank ``frm`` (blocking)."""
        ...
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class BaseChannel(libspu.link.IChannel):
    """Bridge a v2 communicator to SPU's ``IChannel`` interface.

    Supports both ThreadCommunicator and HttpCommunicator via duck typing
    (any object with ``send(to, key, data)`` / ``recv(frm, key)``).
    Each BaseChannel represents a channel to ONE peer rank.

    Communication Protocol:
        - SPU calls Send(tag, bytes) -> comm.send(peer, "<prefix>:tag", bytes)
        - SPU calls Recv(tag) -> bytes <- comm.recv(peer, "<prefix>:tag")

    Tag Namespace:
        All tags are prefixed with ``tag_prefix`` (default "spu") and ":" to
        avoid collision with other traffic on the same communicator.
    """

    # A named logger instead of the root-level ``logging.debug`` helpers:
    # the module-level helpers implicitly call ``logging.basicConfig()`` when
    # the root logger has no handlers, and the previous f-string messages were
    # formatted on every Send/Recv even with DEBUG disabled. Lazy %-args only
    # format when the record is actually emitted.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        comm: CommunicatorProtocol,
        local_rank: int,
        peer_rank: int,
        tag_prefix: str = "spu",
    ):
        """Initialize channel to a specific peer.

        Args:
            comm: v2 communicator (any object implementing send/recv)
            local_rank: Global rank of this party
            peer_rank: Global rank of the peer party
            tag_prefix: Prefix for all tags (default: "spu")
        """
        super().__init__()
        self._comm = comm
        self._local_rank = local_rank
        self._peer_rank = peer_rank
        self._tag_prefix = tag_prefix

        self._logger.debug(
            "BaseChannel initialized: local_rank=%s, peer_rank=%s, tag_prefix=%s",
            local_rank,
            peer_rank,
            tag_prefix,
        )

    def _make_key(self, tag: str) -> str:
        """Create the namespaced communicator key for an SPU tag.

        Args:
            tag: SPU-provided tag (e.g., "send_0")

        Returns:
            Prefixed key (e.g., "spu:send_0")
        """
        return f"{self._tag_prefix}:{tag}"

    def Send(self, tag: str, data: bytes) -> None:
        """Send bytes to the peer.

        Args:
            tag: Message tag for matching send/recv pairs
            data: Raw bytes to send
        """
        key = self._make_key(tag)
        self._logger.debug(
            "BaseChannel.Send: %s -> %s, tag=%s, key=%s, size=%d",
            self._local_rank,
            self._peer_rank,
            tag,
            key,
            len(data),
        )

        # Raw bytes are forwarded as-is; the communicator is responsible for
        # transporting them (HttpCommunicator base64-encodes "spu:" keys).
        self._comm.send(self._peer_rank, key, data)

    def Recv(self, tag: str) -> bytes:
        """Receive bytes from the peer (blocking).

        Args:
            tag: Message tag for matching send/recv pairs

        Returns:
            Raw bytes received

        Raises:
            TypeError: If the communicator returns something other than bytes
        """
        key = self._make_key(tag)
        self._logger.debug(
            "BaseChannel.Recv: %s <- %s, tag=%s, key=%s",
            self._local_rank,
            self._peer_rank,
            tag,
            key,
        )

        data = self._comm.recv(self._peer_rank, key)

        # SPU needs raw bytes; any other payload means the peer (or the
        # communicator's serialization path) did not preserve them.
        if not isinstance(data, bytes):
            raise TypeError(
                f"Expected bytes from communicator, got {type(data).__name__}. "
                "Communicator must support raw bytes transmission for SPU channels."
            )

        self._logger.debug(
            "BaseChannel.Recv complete: %s <- %s, tag=%s, size=%d",
            self._local_rank,
            self._peer_rank,
            tag,
            len(data),
        )
        return data

    def SendAsync(self, tag: str, data: bytes) -> None:
        """Async send; delegates to ``Send``.

        Whether this actually returns before delivery depends on the
        communicator's ``send`` (NOTE(review): HttpCommunicator owns a send
        thread pool, so it is presumably non-blocking — confirm).

        Args:
            tag: Message tag
            data: Raw bytes to send
        """
        self.Send(tag, data)

    def SendAsyncThrottled(self, tag: str, data: bytes) -> None:
        """Throttled async send; currently identical to ``SendAsync``.

        Args:
            tag: Message tag
            data: Raw bytes to send
        """
        self.SendAsync(tag, data)

    def TestSend(self, timeout: int) -> None:
        """Send a 1-byte handshake (b"\\x00") to the peer under tag "__test__".

        NOTE(review): the tag is fixed, so this is NOT idempotent — if a second
        TestSend reaches the peer before its TestRecv consumed the first, the
        receiving communicator raises a "Mailbox overflow" RuntimeError for
        the duplicate (rank, key).

        Args:
            timeout: Timeout in milliseconds (currently unused)
        """
        test_data = b"\x00"  # Minimal 1-byte handshake
        self.Send("__test__", test_data)

    def TestRecv(self) -> None:
        """Wait for the peer's handshake message sent by ``TestSend``.

        This method does not enforce a timeout itself; it blocks for as long
        as the communicator's ``recv`` does (NOTE(review): confirm whether
        libspu applies ``recv_timeout_ms`` around the Python channel).

        An unexpected handshake payload is logged as a warning, not raised.
        """
        test_data = self.Recv("__test__")
        if test_data != b"\x00":
            self._logger.warning(
                "TestRecv: unexpected handshake from %s, expected b'\\x00', got %r",
                self._peer_rank,
                test_data,
            )

    def WaitLinkTaskFinish(self) -> None:
        """Wait for all pending async tasks.

        No-op: pending sends are left to the underlying communicator.
        """

    def Abort(self) -> None:
        """Abort communication.

        Currently only logs a warning; no resources are released.
        """
        self._logger.warning(
            "BaseChannel.Abort: %s <-> %s", self._local_rank, self._peer_rank
        )

    def SetThrottleWindowSize(self, size: int) -> None:
        """Set throttle window size (no-op).

        Args:
            size: Window size (ignored)
        """

    def SetChunkParallelSendSize(self, size: int) -> None:
        """Set chunk parallel send size (no-op).

        Args:
            size: Chunk size (ignored)
        """
|
|
@@ -83,7 +83,7 @@ class HttpCommunicator:
|
|
|
83
83
|
self.world_size = world_size
|
|
84
84
|
self.endpoints = endpoints
|
|
85
85
|
self.tracer = tracer
|
|
86
|
-
self._mailbox: dict[str, Any] = {}
|
|
86
|
+
self._mailbox: dict[tuple[int, str], Any] = {}
|
|
87
87
|
self._cond = threading.Condition()
|
|
88
88
|
self._send_executor = concurrent.futures.ThreadPoolExecutor(
|
|
89
89
|
max_workers=world_size, thread_name_prefix=f"comm_send_{rank}"
|
|
@@ -100,8 +100,19 @@ class HttpCommunicator:
|
|
|
100
100
|
"""Perform the HTTP send."""
|
|
101
101
|
url = f"{self.endpoints[to]}/comm/{key}"
|
|
102
102
|
logger.debug(f"Rank {self.rank} sending to {to} key={key}")
|
|
103
|
-
|
|
104
|
-
|
|
103
|
+
|
|
104
|
+
# Detect SPU channel (tag prefix "spu:") and handle bytes
|
|
105
|
+
if key.startswith("spu:") and isinstance(data, bytes):
|
|
106
|
+
# Send raw bytes for SPU channels
|
|
107
|
+
import base64
|
|
108
|
+
|
|
109
|
+
payload = base64.b64encode(data).decode("ascii")
|
|
110
|
+
is_raw_bytes = True
|
|
111
|
+
else:
|
|
112
|
+
# Use secure JSON serialization
|
|
113
|
+
payload = serde.dumps_b64(data)
|
|
114
|
+
is_raw_bytes = False
|
|
115
|
+
|
|
105
116
|
size_bytes = len(payload)
|
|
106
117
|
|
|
107
118
|
# Log to profiler
|
|
@@ -116,7 +127,14 @@ class HttpCommunicator:
|
|
|
116
127
|
|
|
117
128
|
try:
|
|
118
129
|
t0 = time.time()
|
|
119
|
-
resp = self.client.put(
|
|
130
|
+
resp = self.client.put(
|
|
131
|
+
url,
|
|
132
|
+
json={
|
|
133
|
+
"data": payload,
|
|
134
|
+
"from_rank": self.rank,
|
|
135
|
+
"is_raw_bytes": is_raw_bytes,
|
|
136
|
+
},
|
|
137
|
+
)
|
|
120
138
|
resp.raise_for_status()
|
|
121
139
|
duration = time.time() - t0
|
|
122
140
|
if self.tracer:
|
|
@@ -134,17 +152,21 @@ class HttpCommunicator:
|
|
|
134
152
|
def recv(self, frm: int, key: str) -> Any:
|
|
135
153
|
"""Receive data from another rank (blocking)."""
|
|
136
154
|
logger.debug(f"Rank {self.rank} waiting recv from {frm} key={key}")
|
|
155
|
+
mailbox_key = (frm, key)
|
|
137
156
|
with self._cond:
|
|
138
|
-
while
|
|
157
|
+
while mailbox_key not in self._mailbox:
|
|
139
158
|
self._cond.wait(timeout=1.0)
|
|
140
|
-
return self._mailbox.pop(
|
|
159
|
+
return self._mailbox.pop(mailbox_key)
|
|
141
160
|
|
|
142
|
-
def on_receive(self, key: str, data: Any) -> None:
|
|
161
|
+
def on_receive(self, from_rank: int, key: str, data: Any) -> None:
|
|
143
162
|
"""Called when data is received from the HTTP endpoint."""
|
|
163
|
+
mailbox_key = (from_rank, key)
|
|
144
164
|
with self._cond:
|
|
145
|
-
if
|
|
146
|
-
|
|
147
|
-
|
|
165
|
+
if mailbox_key in self._mailbox:
|
|
166
|
+
raise RuntimeError(
|
|
167
|
+
f"Mailbox overflow: key {mailbox_key} already exists"
|
|
168
|
+
)
|
|
169
|
+
self._mailbox[mailbox_key] = data
|
|
148
170
|
self._cond.notify_all()
|
|
149
171
|
|
|
150
172
|
def wait_pending_sends(self) -> None:
|
|
@@ -176,6 +198,7 @@ class CommRequest(BaseModel):
|
|
|
176
198
|
|
|
177
199
|
data: str
|
|
178
200
|
from_rank: int
|
|
201
|
+
is_raw_bytes: bool = False # NEW: indicates raw bytes (not serde)
|
|
179
202
|
|
|
180
203
|
|
|
181
204
|
class FetchRequest(BaseModel):
|
|
@@ -279,9 +302,17 @@ def create_worker_app(
|
|
|
279
302
|
"""Receive communication data from another worker."""
|
|
280
303
|
logger.debug(f"Worker {rank} received comm key={key} from {req.from_rank}")
|
|
281
304
|
try:
|
|
282
|
-
#
|
|
283
|
-
|
|
284
|
-
|
|
305
|
+
# Handle raw bytes (SPU channels) vs serde data
|
|
306
|
+
if req.is_raw_bytes:
|
|
307
|
+
# Decode base64 to raw bytes
|
|
308
|
+
import base64
|
|
309
|
+
|
|
310
|
+
data = base64.b64decode(req.data)
|
|
311
|
+
else:
|
|
312
|
+
# Use secure JSON deserialization
|
|
313
|
+
data = serde.loads_b64(req.data)
|
|
314
|
+
|
|
315
|
+
comm.on_receive(req.from_rank, key, data)
|
|
285
316
|
return {"status": "ok"}
|
|
286
317
|
except Exception as e:
|
|
287
318
|
logger.error(f"Worker {rank} comm failed: {e}")
|
|
@@ -35,7 +35,8 @@ class ThreadCommunicator:
|
|
|
35
35
|
self.world_size = world_size
|
|
36
36
|
self.use_serde = use_serde
|
|
37
37
|
self.peers: list[ThreadCommunicator] = []
|
|
38
|
-
|
|
38
|
+
# Mailbox keyed by (from_rank, tag): each key has exactly one message
|
|
39
|
+
self._mailbox: dict[tuple[int, str], Any] = {}
|
|
39
40
|
self._cond = threading.Condition()
|
|
40
41
|
self._sent_events: dict[str, threading.Event] = {}
|
|
41
42
|
self._shutdown = False
|
|
@@ -58,20 +59,22 @@ class ThreadCommunicator:
|
|
|
58
59
|
self.peers[to]._on_receive(self.rank, key, data)
|
|
59
60
|
|
|
60
61
|
def recv(self, frm: int, key: str) -> Any:
|
|
62
|
+
mailbox_key = (frm, key)
|
|
61
63
|
with self._cond:
|
|
62
|
-
while
|
|
64
|
+
while mailbox_key not in self._mailbox and not self._shutdown:
|
|
63
65
|
self._cond.wait()
|
|
64
66
|
if self._shutdown:
|
|
65
67
|
raise RuntimeError("Communicator shut down")
|
|
66
|
-
return self._mailbox.pop(
|
|
68
|
+
return self._mailbox.pop(mailbox_key)
|
|
67
69
|
|
|
68
70
|
def _on_receive(self, frm: int, key: str, data: Any) -> None:
|
|
71
|
+
mailbox_key = (frm, key)
|
|
69
72
|
with self._cond:
|
|
70
|
-
if
|
|
73
|
+
if mailbox_key in self._mailbox:
|
|
71
74
|
raise RuntimeError(
|
|
72
|
-
f"Mailbox overflow
|
|
75
|
+
f"Mailbox overflow: key {mailbox_key} already exists"
|
|
73
76
|
)
|
|
74
|
-
self._mailbox[
|
|
77
|
+
self._mailbox[mailbox_key] = data
|
|
75
78
|
self._cond.notify_all()
|
|
76
79
|
|
|
77
80
|
|
mplang/v2/backends/spu_impl.py
CHANGED
|
@@ -26,7 +26,6 @@ import numpy as np
|
|
|
26
26
|
import spu.api as spu_api
|
|
27
27
|
import spu.libspu as libspu
|
|
28
28
|
|
|
29
|
-
from mplang.v2.backends.simp_worker import SimpWorker
|
|
30
29
|
from mplang.v2.backends.spu_state import SPUState
|
|
31
30
|
from mplang.v2.backends.tensor_impl import TensorValue
|
|
32
31
|
from mplang.v2.dialects import spu
|
|
@@ -161,6 +160,8 @@ def exec_impl(interpreter: Interpreter, op: Operation, *args: Any) -> Any:
|
|
|
161
160
|
The SPU config must contain parties info to correctly map global rank
|
|
162
161
|
to local SPU rank and determine SPU world size.
|
|
163
162
|
"""
|
|
163
|
+
from mplang.v2.backends.simp_worker.state import SimpWorker
|
|
164
|
+
|
|
164
165
|
# Get SPU config from attrs (passed through from run_jax)
|
|
165
166
|
config: spu.SPUConfig = op.attrs["config"]
|
|
166
167
|
|
|
@@ -193,9 +194,8 @@ def exec_impl(interpreter: Interpreter, op: Operation, *args: Any) -> Any:
|
|
|
193
194
|
interpreter, "spu_endpoints", None
|
|
194
195
|
)
|
|
195
196
|
if spu_endpoints_map is None:
|
|
196
|
-
context
|
|
197
|
-
|
|
198
|
-
spu_endpoints_map = getattr(context, "spu_endpoints", None)
|
|
197
|
+
# Try getting from SimpWorker context (context is already SimpWorker)
|
|
198
|
+
spu_endpoints_map = getattr(context, "spu_endpoints", None)
|
|
199
199
|
|
|
200
200
|
# Build ordered list of endpoints for SPU parties
|
|
201
201
|
spu_endpoints: list[str] | None = None
|
|
@@ -209,6 +209,14 @@ def exec_impl(interpreter: Interpreter, op: Operation, *args: Any) -> Any:
|
|
|
209
209
|
)
|
|
210
210
|
spu_endpoints.append(spu_endpoints_map[party_rank])
|
|
211
211
|
|
|
212
|
+
# Get communicator for Channels mode (reuse existing communication)
|
|
213
|
+
# If no BRPC endpoints configured, use Channels mode
|
|
214
|
+
communicator = None
|
|
215
|
+
if spu_endpoints is None:
|
|
216
|
+
# Use worker's communicator for channel reuse
|
|
217
|
+
# (SimpWorker already imported at function start)
|
|
218
|
+
communicator = context.communicator
|
|
219
|
+
|
|
212
220
|
# Get or create SPUState for caching Runtime/Io
|
|
213
221
|
spu_state = interpreter.get_dialect_state(SPUState.dialect_name)
|
|
214
222
|
if not isinstance(spu_state, SPUState):
|
|
@@ -216,7 +224,12 @@ def exec_impl(interpreter: Interpreter, op: Operation, *args: Any) -> Any:
|
|
|
216
224
|
interpreter.set_dialect_state(SPUState.dialect_name, spu_state)
|
|
217
225
|
|
|
218
226
|
runtime, io = spu_state.get_or_create(
|
|
219
|
-
local_rank,
|
|
227
|
+
local_rank,
|
|
228
|
+
spu_world_size,
|
|
229
|
+
config,
|
|
230
|
+
spu_endpoints,
|
|
231
|
+
communicator=communicator,
|
|
232
|
+
parties=list(parties),
|
|
220
233
|
)
|
|
221
234
|
|
|
222
235
|
executable_code = op.attrs["executable"]
|
mplang/v2/backends/spu_state.py
CHANGED
|
@@ -20,7 +20,7 @@ multiple executions while binding to the Interpreter's lifecycle.
|
|
|
20
20
|
|
|
21
21
|
from __future__ import annotations
|
|
22
22
|
|
|
23
|
-
from typing import TYPE_CHECKING
|
|
23
|
+
from typing import TYPE_CHECKING, Any
|
|
24
24
|
|
|
25
25
|
import spu.api as spu_api
|
|
26
26
|
import spu.libspu as libspu
|
|
@@ -56,6 +56,8 @@ class SPUState(DialectState):
|
|
|
56
56
|
spu_world_size: int,
|
|
57
57
|
config: spu.SPUConfig,
|
|
58
58
|
spu_endpoints: list[str] | None = None,
|
|
59
|
+
communicator: object | None = None,
|
|
60
|
+
parties: list[int] | None = None,
|
|
59
61
|
) -> tuple[spu_api.Runtime, spu_api.Io]:
|
|
60
62
|
"""Get or create SPU Runtime and Io for the given configuration.
|
|
61
63
|
|
|
@@ -64,13 +66,24 @@ class SPUState(DialectState):
|
|
|
64
66
|
spu_world_size: The number of parties in the SPU device.
|
|
65
67
|
config: SPU configuration including protocol settings.
|
|
66
68
|
spu_endpoints: Optional list of BRPC endpoints. If None, use mem link.
|
|
69
|
+
communicator: Optional v2 communicator (ThreadCommunicator/HttpCommunicator).
|
|
70
|
+
If provided, use Channels mode to reuse existing communication.
|
|
71
|
+
parties: Optional list of global ranks for SPU parties.
|
|
72
|
+
Required when communicator is provided.
|
|
67
73
|
|
|
68
74
|
Returns:
|
|
69
75
|
A tuple of (Runtime, Io) for this party.
|
|
70
76
|
"""
|
|
71
77
|
from mplang.v2.backends.spu_impl import to_runtime_config
|
|
72
78
|
|
|
73
|
-
|
|
79
|
+
# Determine link mode
|
|
80
|
+
if communicator is not None:
|
|
81
|
+
link_mode = "channels"
|
|
82
|
+
elif spu_endpoints:
|
|
83
|
+
link_mode = "brpc"
|
|
84
|
+
else:
|
|
85
|
+
link_mode = "mem"
|
|
86
|
+
|
|
74
87
|
cache_key = (
|
|
75
88
|
local_rank,
|
|
76
89
|
spu_world_size,
|
|
@@ -83,7 +96,13 @@ class SPUState(DialectState):
|
|
|
83
96
|
return self._runtimes[cache_key]
|
|
84
97
|
|
|
85
98
|
# Create Link
|
|
86
|
-
if
|
|
99
|
+
if communicator is not None:
|
|
100
|
+
if parties is None:
|
|
101
|
+
raise ValueError("parties required when using communicator")
|
|
102
|
+
link = self._create_channels_link(
|
|
103
|
+
local_rank, spu_world_size, communicator, parties
|
|
104
|
+
)
|
|
105
|
+
elif spu_endpoints:
|
|
87
106
|
link = self._create_brpc_link(local_rank, spu_endpoints)
|
|
88
107
|
else:
|
|
89
108
|
link = self._create_mem_link(local_rank, spu_world_size)
|
|
@@ -106,6 +125,50 @@ class SPUState(DialectState):
|
|
|
106
125
|
desc.add_party(f"P{i}", f"mem:{i}")
|
|
107
126
|
return libspu.link.create_mem(desc, local_rank)
|
|
108
127
|
|
|
128
|
+
def _create_channels_link(
    self,
    local_rank: int,
    spu_world_size: int,
    communicator: Any,
    parties: list[int],
) -> libspu.link.Context:
    """Create an SPU link that reuses a v2 communicator via BaseChannel adapters.

    Args:
        local_rank: SPU local rank (0-indexed, already converted from global)
        spu_world_size: Number of SPU parties
        communicator: v2 communicator (ThreadCommunicator/HttpCommunicator)
        parties: List of global ranks for SPU parties (ordered by local rank)

    Returns:
        libspu link context using BaseChannel adapters

    Raises:
        ValueError: If ``parties`` does not have exactly ``spu_world_size``
            entries, or ``local_rank`` is outside ``[0, spu_world_size)``.
    """
    from mplang.v2.backends.channel import BaseChannel

    # The channel list is sized by ``parties`` while the descriptor is sized
    # by ``spu_world_size``; reject a mismatch up front instead of failing
    # with an IndexError or handing libspu an inconsistent link.
    if len(parties) != spu_world_size:
        raise ValueError(
            f"parties has {len(parties)} entries but spu_world_size is "
            f"{spu_world_size}; they must match"
        )
    if not 0 <= local_rank < spu_world_size:
        raise ValueError(
            f"local_rank {local_rank} out of range for spu_world_size "
            f"{spu_world_size}"
        )

    # This worker's global rank.
    global_rank = parties[local_rank]

    # One slot per SPU party, ordered by local rank; the self slot must be None.
    channels: list[BaseChannel | None] = [
        None
        if idx == local_rank
        else BaseChannel(communicator, global_rank, peer_global_rank)
        for idx, peer_global_rank in enumerate(parties)
    ]

    desc = libspu.link.Desc()  # type: ignore
    desc.recv_timeout_ms = 100 * 1000  # 100 seconds

    # Party entries are required for world_size inference; the host strings
    # are placeholders since traffic goes through the channels above.
    for idx, party_rank in enumerate(parties):
        desc.add_party(f"P{idx}", f"dummy_{party_rank}")

    return libspu.link.create_with_channels(desc, local_rank, channels)
|
|
171
|
+
|
|
109
172
|
def _create_brpc_link(
|
|
110
173
|
self, local_rank: int, spu_endpoints: list[str]
|
|
111
174
|
) -> libspu.link.Context:
|
mplang/v2/libs/ml/sgb.py
CHANGED
|
@@ -1097,11 +1097,9 @@ def _update_tree_state(
|
|
|
1097
1097
|
|
|
1098
1098
|
all_feats[party_idx] = simp.pcall_static(
|
|
1099
1099
|
(party_rank,),
|
|
1100
|
-
lambda pf=all_feats[party_idx],
|
|
1101
|
-
|
|
1102
|
-
|
|
1103
|
-
op=owned_party_party,
|
|
1104
|
-
il=is_leaf_party: tensor.run_jax(update_party_feats, pf, bf, ci, op, il),
|
|
1100
|
+
lambda pf=all_feats[party_idx], bf=all_feats_level[party_idx], ci=cur_indices_party, op=owned_party_party, il=is_leaf_party: (
|
|
1101
|
+
tensor.run_jax(update_party_feats, pf, bf, ci, op, il)
|
|
1102
|
+
),
|
|
1105
1103
|
)
|
|
1106
1104
|
|
|
1107
1105
|
def update_party_thresholds(
|
|
@@ -1123,21 +1121,17 @@ def _update_tree_state(
|
|
|
1123
1121
|
|
|
1124
1122
|
all_thresholds[party_idx] = simp.pcall_static(
|
|
1125
1123
|
(party_rank,),
|
|
1126
|
-
lambda pt=all_thresholds[party_idx],
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
|
|
1136
|
-
|
|
1137
|
-
bt_idx,
|
|
1138
|
-
ci,
|
|
1139
|
-
op,
|
|
1140
|
-
il,
|
|
1124
|
+
lambda pt=all_thresholds[party_idx], b=all_bins[party_idx], bf=all_feats_level[party_idx], bt_idx=all_threshs_level[party_idx], ci=cur_indices_party, op=owned_party_party, il=is_leaf_party: (
|
|
1125
|
+
tensor.run_jax(
|
|
1126
|
+
update_party_thresholds,
|
|
1127
|
+
pt,
|
|
1128
|
+
b,
|
|
1129
|
+
bf,
|
|
1130
|
+
bt_idx,
|
|
1131
|
+
ci,
|
|
1132
|
+
op,
|
|
1133
|
+
il,
|
|
1134
|
+
)
|
|
1141
1135
|
),
|
|
1142
1136
|
)
|
|
1143
1137
|
|
|
@@ -1152,13 +1146,8 @@ def _update_tree_state(
|
|
|
1152
1146
|
|
|
1153
1147
|
tmp_bt = simp.pcall_static(
|
|
1154
1148
|
(party_rank,),
|
|
1155
|
-
lambda bi=all_bin_indices[party_idx],
|
|
1156
|
-
|
|
1157
|
-
bt_idx=all_threshs_level[party_idx],
|
|
1158
|
-
bt_arr=bt_party,
|
|
1159
|
-
bt_lv=bt_level_party,
|
|
1160
|
-
il=is_leaf_party: tensor.run_jax(
|
|
1161
|
-
update_bt, bt_arr, bt_lv, il, bi, bf, bt_idx
|
|
1149
|
+
lambda bi=all_bin_indices[party_idx], bf=all_feats_level[party_idx], bt_idx=all_threshs_level[party_idx], bt_arr=bt_party, bt_lv=bt_level_party, il=is_leaf_party: (
|
|
1150
|
+
tensor.run_jax(update_bt, bt_arr, bt_lv, il, bi, bf, bt_idx)
|
|
1162
1151
|
),
|
|
1163
1152
|
)
|
|
1164
1153
|
|
|
@@ -1498,11 +1487,10 @@ def predict_tree(
|
|
|
1498
1487
|
for i, rank in enumerate(all_ranks):
|
|
1499
1488
|
mask = simp.pcall_static(
|
|
1500
1489
|
(rank,),
|
|
1501
|
-
lambda d=all_datas[i],
|
|
1502
|
-
|
|
1503
|
-
|
|
1504
|
-
|
|
1505
|
-
d, f, t, tree.is_leaf, tree.owned_party_id, idx, n_nodes
|
|
1490
|
+
lambda d=all_datas[i], f=tree.feature[i], t=tree.threshold[i], idx=i: (
|
|
1491
|
+
predict_tree_single_party(
|
|
1492
|
+
d, f, t, tree.is_leaf, tree.owned_party_id, idx, n_nodes
|
|
1493
|
+
)
|
|
1506
1494
|
),
|
|
1507
1495
|
)
|
|
1508
1496
|
# Transfer to AP
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: mplang-nightly
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.dev267
|
|
4
4
|
Summary: Multi-Party Programming Language
|
|
5
5
|
Author-email: SecretFlow Team <secretflow-contact@service.alipay.com>
|
|
6
6
|
License: Apache License
|
|
@@ -219,7 +219,7 @@ Requires-Dist: pandas>=2.0.0
|
|
|
219
219
|
Requires-Dist: protobuf<6.0,>=5.0
|
|
220
220
|
Requires-Dist: pyarrow>=14.0.0
|
|
221
221
|
Requires-Dist: pyyaml>=6.0
|
|
222
|
-
Requires-Dist: spu>=0.10.0.
|
|
222
|
+
Requires-Dist: spu>=0.10.0.dev20251211
|
|
223
223
|
Requires-Dist: sqlglot>=23.0.0
|
|
224
224
|
Requires-Dist: tenseal==0.3.16
|
|
225
225
|
Requires-Dist: typing-extensions
|