mplang-nightly 0.1.dev266__py3-none-any.whl → 0.1.dev267__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in the public registry. It is provided for informational purposes only.
@@ -0,0 +1,217 @@
1
+ # Copyright 2026 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """SPU IChannel implementation for MPLang v2.
16
+
17
+ Bridges v2's simp_worker communicators (ThreadCommunicator/HttpCommunicator)
18
+ to libspu's IChannel interface, enabling SPU to reuse existing communication
19
+ infrastructure instead of creating separate BRPC connections.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import logging
25
+ from typing import Protocol
26
+
27
+ import spu.libspu as libspu
28
+
29
+
30
+ class CommunicatorProtocol(Protocol):
31
+ """Protocol for v2 communicators (duck typing).
32
+
33
+ Both ThreadCommunicator and HttpCommunicator implement this interface.
34
+ """
35
+
36
+ def send(self, to: int, key: str, data: bytes) -> None: ...
37
+ def recv(self, frm: int, key: str) -> bytes: ...
38
+
39
+
40
+ class BaseChannel(libspu.link.IChannel):
41
+ """Bridge v2 communicator to SPU IChannel interface.
42
+
43
+ Supports both ThreadCommunicator and HttpCommunicator via duck typing.
44
+ Each BaseChannel represents a channel to ONE peer rank.
45
+
46
+ Communication Protocol:
47
+ - SPU calls send(tag, bytes) -> comm.send(peer, "spu:tag", bytes)
48
+ - SPU calls recv(tag) -> bytes <- comm.recv(peer, "spu:tag")
49
+
50
+ Tag Namespace:
51
+ All tags are prefixed with "spu:" to avoid collision with other
52
+ traffic on the same communicator.
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ comm: CommunicatorProtocol,
58
+ local_rank: int,
59
+ peer_rank: int,
60
+ tag_prefix: str = "spu",
61
+ ):
62
+ """Initialize channel to a specific peer.
63
+
64
+ Args:
65
+ comm: v2 communicator (any object implementing send/recv)
66
+ local_rank: Global rank of this party
67
+ peer_rank: Global rank of the peer party
68
+ tag_prefix: Prefix for all tags (default: "spu")
69
+ """
70
+ super().__init__()
71
+ self._comm = comm
72
+ self._local_rank = local_rank
73
+ self._peer_rank = peer_rank
74
+ self._tag_prefix = tag_prefix
75
+
76
+ logging.debug(
77
+ f"BaseChannel initialized: local_rank={local_rank}, "
78
+ f"peer_rank={peer_rank}, tag_prefix={tag_prefix}"
79
+ )
80
+
81
+ def _make_key(self, tag: str) -> str:
82
+ """Create unique key for communicator.
83
+
84
+ Args:
85
+ tag: SPU-provided tag (e.g., "send_0")
86
+
87
+ Returns:
88
+ Prefixed key (e.g., "spu:send_0")
89
+ """
90
+ return f"{self._tag_prefix}:{tag}"
91
+
92
+ def Send(self, tag: str, data: bytes) -> None:
93
+ """Send bytes to peer.
94
+
95
+ Args:
96
+ tag: Message tag for matching send/recv pairs
97
+ data: Raw bytes to send
98
+ """
99
+ key = self._make_key(tag)
100
+ logging.debug(
101
+ f"BaseChannel.Send: {self._local_rank} -> {self._peer_rank}, "
102
+ f"tag={tag}, key={key}, size={len(data)}"
103
+ )
104
+
105
+ # Send raw bytes directly
106
+ # v2 communicators accept Any, so raw bytes are a valid payload
107
+ self._comm.send(self._peer_rank, key, data)
108
+
109
+ def Recv(self, tag: str) -> bytes:
110
+ """Receive bytes from peer (blocking).
111
+
112
+ Args:
113
+ tag: Message tag for matching send/recv pairs
114
+
115
+ Returns:
116
+ Raw bytes received
117
+
118
+ Raises:
119
+ TypeError: If received data is not bytes
120
+ """
121
+ key = self._make_key(tag)
122
+ logging.debug(
123
+ f"BaseChannel.Recv: {self._local_rank} <- {self._peer_rank}, "
124
+ f"tag={tag}, key={key}"
125
+ )
126
+
127
+ # Receive data (should be bytes)
128
+ data = self._comm.recv(self._peer_rank, key)
129
+
130
+ # Validate data type
131
+ if not isinstance(data, bytes):
132
+ raise TypeError(
133
+ f"Expected bytes from communicator, got {type(data).__name__}. "
134
+ f"Communicator must support raw bytes transmission for SPU channels."
135
+ )
136
+
137
+ logging.debug(
138
+ f"BaseChannel.Recv complete: {self._local_rank} <- {self._peer_rank}, "
139
+ f"tag={tag}, size={len(data)}"
140
+ )
141
+ return data
142
+
143
+ def SendAsync(self, tag: str, data: bytes) -> None:
144
+ """Async send.
145
+
146
+ For HttpCommunicator, the underlying HTTP client is non-blocking.
147
+ For ThreadCommunicator, send is instant (memory transfer).
148
+
149
+ Args:
150
+ tag: Message tag
151
+ data: Raw bytes to send
152
+ """
153
+ self.Send(tag, data)
154
+
155
+ def SendAsyncThrottled(self, tag: str, data: bytes) -> None:
156
+ """Throttled async send.
157
+
158
+ Currently maps to regular SendAsync.
159
+
160
+ Args:
161
+ tag: Message tag
162
+ data: Raw bytes to send
163
+ """
164
+ self.SendAsync(tag, data)
165
+
166
+ def TestSend(self, timeout: int) -> None:
167
+ """Test if channel can send a dummy message to peer.
168
+
169
+ Uses the fixed tag "__test__" so the peer's TestRecv can match it.
170
+
171
+ Args:
172
+ timeout: Timeout in milliseconds (informational)
173
+ """
174
+ test_data = b"\x00" # Minimal 1-byte handshake
175
+ self.Send("__test__", test_data)
176
+
177
+ def TestRecv(self) -> None:
178
+ """Wait for dummy message from peer.
179
+
180
+ Timeout controlled by recv_timeout_ms in link descriptor.
181
+
182
+ Note:
183
+ Logs a warning if unexpected handshake data is received (no exception is raised).
184
+ """
185
+ test_data = self.Recv("__test__")
186
+ if test_data != b"\x00":
187
+ logging.warning(
188
+ f"TestRecv: unexpected handshake from {self._peer_rank}, "
189
+ f"expected b'\\x00', got {test_data!r}"
190
+ )
191
+
192
+ def WaitLinkTaskFinish(self) -> None:
193
+ """Wait for all pending async tasks.
194
+
195
+ No-op for v2 communicators (handled automatically).
196
+ """
197
+
198
+ def Abort(self) -> None:
199
+ """Abort communication (cleanup).
200
+
201
+ Currently only logs a warning; could be extended to release resources.
202
+ """
203
+ logging.warning(f"BaseChannel.Abort: {self._local_rank} <-> {self._peer_rank}")
204
+
205
+ def SetThrottleWindowSize(self, size: int) -> None:
206
+ """Set throttle window size (no-op).
207
+
208
+ Args:
209
+ size: Window size (ignored)
210
+ """
211
+
212
+ def SetChunkParallelSendSize(self, size: int) -> None:
213
+ """Set chunk parallel send size (no-op).
214
+
215
+ Args:
216
+ size: Chunk size (ignored)
217
+ """
@@ -83,7 +83,7 @@ class HttpCommunicator:
83
83
  self.world_size = world_size
84
84
  self.endpoints = endpoints
85
85
  self.tracer = tracer
86
- self._mailbox: dict[str, Any] = {}
86
+ self._mailbox: dict[tuple[int, str], Any] = {}
87
87
  self._cond = threading.Condition()
88
88
  self._send_executor = concurrent.futures.ThreadPoolExecutor(
89
89
  max_workers=world_size, thread_name_prefix=f"comm_send_{rank}"
@@ -100,8 +100,19 @@ class HttpCommunicator:
100
100
  """Perform the HTTP send."""
101
101
  url = f"{self.endpoints[to]}/comm/{key}"
102
102
  logger.debug(f"Rank {self.rank} sending to {to} key={key}")
103
- # Use secure JSON serialization
104
- payload = serde.dumps_b64(data)
103
+
104
+ # Detect SPU channel (tag prefix "spu:") and handle bytes
105
+ if key.startswith("spu:") and isinstance(data, bytes):
106
+ # Send raw bytes for SPU channels
107
+ import base64
108
+
109
+ payload = base64.b64encode(data).decode("ascii")
110
+ is_raw_bytes = True
111
+ else:
112
+ # Use secure JSON serialization
113
+ payload = serde.dumps_b64(data)
114
+ is_raw_bytes = False
115
+
105
116
  size_bytes = len(payload)
106
117
 
107
118
  # Log to profiler
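
The new raw-bytes branch above is a plain base64 round trip, mirrored by the /comm endpoint further down when is_raw_bytes is set. A standalone illustration (example payload only):

    import base64

    share = b"\x00\x01secret-share"                     # raw SPU bytes on the sender
    payload = base64.b64encode(share).decode("ascii")   # what goes into the JSON body
    assert base64.b64decode(payload) == share           # what the receiver recovers
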
@@ -116,7 +127,14 @@ class HttpCommunicator:
116
127
 
117
128
  try:
118
129
  t0 = time.time()
119
- resp = self.client.put(url, json={"data": payload, "from_rank": self.rank})
130
+ resp = self.client.put(
131
+ url,
132
+ json={
133
+ "data": payload,
134
+ "from_rank": self.rank,
135
+ "is_raw_bytes": is_raw_bytes,
136
+ },
137
+ )
120
138
  resp.raise_for_status()
121
139
  duration = time.time() - t0
122
140
  if self.tracer:
@@ -134,17 +152,21 @@ class HttpCommunicator:
134
152
  def recv(self, frm: int, key: str) -> Any:
135
153
  """Receive data from another rank (blocking)."""
136
154
  logger.debug(f"Rank {self.rank} waiting recv from {frm} key={key}")
155
+ mailbox_key = (frm, key)
137
156
  with self._cond:
138
- while key not in self._mailbox:
157
+ while mailbox_key not in self._mailbox:
139
158
  self._cond.wait(timeout=1.0)
140
- return self._mailbox.pop(key)
159
+ return self._mailbox.pop(mailbox_key)
141
160
 
142
- def on_receive(self, key: str, data: Any) -> None:
161
+ def on_receive(self, from_rank: int, key: str, data: Any) -> None:
143
162
  """Called when data is received from the HTTP endpoint."""
163
+ mailbox_key = (from_rank, key)
144
164
  with self._cond:
145
- if key in self._mailbox:
146
- logger.warning(f"Rank {self.rank} overwriting key={key}")
147
- self._mailbox[key] = data
165
+ if mailbox_key in self._mailbox:
166
+ raise RuntimeError(
167
+ f"Mailbox overflow: key {mailbox_key} already exists"
168
+ )
169
+ self._mailbox[mailbox_key] = data
148
170
  self._cond.notify_all()
149
171
 
150
172
  def wait_pending_sends(self) -> None:
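
Keying the mailbox by (from_rank, key) instead of key alone matters once SPU traffic shares the communicator: two different peers can legitimately send under the same tag. A standalone sketch of the collision the old scheme allowed:

    # Key-only mailbox: the second sender silently clobbers the first.
    old_style: dict[str, bytes] = {}
    old_style["spu:send_0"] = b"from rank 0"
    old_style["spu:send_0"] = b"from rank 2"      # overwrites rank 0's message

    # (from_rank, key) mailbox: both messages coexist and are popped independently.
    new_style: dict[tuple[int, str], bytes] = {}
    new_style[(0, "spu:send_0")] = b"from rank 0"
    new_style[(2, "spu:send_0")] = b"from rank 2"
    assert len(new_style) == 2
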
@@ -176,6 +198,7 @@ class CommRequest(BaseModel):
176
198
 
177
199
  data: str
178
200
  from_rank: int
201
+ is_raw_bytes: bool = False # NEW: indicates raw bytes (not serde)
179
202
 
180
203
 
181
204
  class FetchRequest(BaseModel):
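
With the added field, the body of the PUT to {endpoint}/comm/{key} looks roughly like the dict below (values are made up; the field names come from CommRequest):

    body = {
        "data": "AAE=",        # base64 of raw bytes when is_raw_bytes, else serde.dumps_b64 output
        "from_rank": 0,
        "is_raw_bytes": True,
    }
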
@@ -279,9 +302,17 @@ def create_worker_app(
279
302
  """Receive communication data from another worker."""
280
303
  logger.debug(f"Worker {rank} received comm key={key} from {req.from_rank}")
281
304
  try:
282
- # Use secure JSON deserialization
283
- data = serde.loads_b64(req.data)
284
- comm.on_receive(key, data)
305
+ # Handle raw bytes (SPU channels) vs serde data
306
+ if req.is_raw_bytes:
307
+ # Decode base64 to raw bytes
308
+ import base64
309
+
310
+ data = base64.b64decode(req.data)
311
+ else:
312
+ # Use secure JSON deserialization
313
+ data = serde.loads_b64(req.data)
314
+
315
+ comm.on_receive(req.from_rank, key, data)
285
316
  return {"status": "ok"}
286
317
  except Exception as e:
287
318
  logger.error(f"Worker {rank} comm failed: {e}")
@@ -35,7 +35,8 @@ class ThreadCommunicator:
35
35
  self.world_size = world_size
36
36
  self.use_serde = use_serde
37
37
  self.peers: list[ThreadCommunicator] = []
38
- self._mailbox: dict[str, Any] = {}
38
+ # Mailbox keyed by (from_rank, tag): each key has exactly one message
39
+ self._mailbox: dict[tuple[int, str], Any] = {}
39
40
  self._cond = threading.Condition()
40
41
  self._sent_events: dict[str, threading.Event] = {}
41
42
  self._shutdown = False
@@ -58,20 +59,22 @@ class ThreadCommunicator:
58
59
  self.peers[to]._on_receive(self.rank, key, data)
59
60
 
60
61
  def recv(self, frm: int, key: str) -> Any:
62
+ mailbox_key = (frm, key)
61
63
  with self._cond:
62
- while key not in self._mailbox and not self._shutdown:
64
+ while mailbox_key not in self._mailbox and not self._shutdown:
63
65
  self._cond.wait()
64
66
  if self._shutdown:
65
67
  raise RuntimeError("Communicator shut down")
66
- return self._mailbox.pop(key)
68
+ return self._mailbox.pop(mailbox_key)
67
69
 
68
70
  def _on_receive(self, frm: int, key: str, data: Any) -> None:
71
+ mailbox_key = (frm, key)
69
72
  with self._cond:
70
- if key in self._mailbox:
73
+ if mailbox_key in self._mailbox:
71
74
  raise RuntimeError(
72
- f"Mailbox overflow for key {key} at rank {self.rank}"
75
+ f"Mailbox overflow: key {mailbox_key} already exists"
73
76
  )
74
- self._mailbox[key] = data
77
+ self._mailbox[mailbox_key] = data
75
78
  self._cond.notify_all()
76
79
 
77
80
 
@@ -26,7 +26,6 @@ import numpy as np
26
26
  import spu.api as spu_api
27
27
  import spu.libspu as libspu
28
28
 
29
- from mplang.v2.backends.simp_worker import SimpWorker
30
29
  from mplang.v2.backends.spu_state import SPUState
31
30
  from mplang.v2.backends.tensor_impl import TensorValue
32
31
  from mplang.v2.dialects import spu
@@ -161,6 +160,8 @@ def exec_impl(interpreter: Interpreter, op: Operation, *args: Any) -> Any:
161
160
  The SPU config must contain parties info to correctly map global rank
162
161
  to local SPU rank and determine SPU world size.
163
162
  """
163
+ from mplang.v2.backends.simp_worker.state import SimpWorker
164
+
164
165
  # Get SPU config from attrs (passed through from run_jax)
165
166
  config: spu.SPUConfig = op.attrs["config"]
166
167
 
@@ -193,9 +194,8 @@ def exec_impl(interpreter: Interpreter, op: Operation, *args: Any) -> Any:
193
194
  interpreter, "spu_endpoints", None
194
195
  )
195
196
  if spu_endpoints_map is None:
196
- context = interpreter.get_dialect_state("simp")
197
- if context is not None:
198
- spu_endpoints_map = getattr(context, "spu_endpoints", None)
197
+ # Try getting from SimpWorker context (context is already SimpWorker)
198
+ spu_endpoints_map = getattr(context, "spu_endpoints", None)
199
199
 
200
200
  # Build ordered list of endpoints for SPU parties
201
201
  spu_endpoints: list[str] | None = None
@@ -209,6 +209,14 @@ def exec_impl(interpreter: Interpreter, op: Operation, *args: Any) -> Any:
209
209
  )
210
210
  spu_endpoints.append(spu_endpoints_map[party_rank])
211
211
 
212
+ # Get communicator for Channels mode (reuse existing communication)
213
+ # If no BRPC endpoints configured, use Channels mode
214
+ communicator = None
215
+ if spu_endpoints is None:
216
+ # Use worker's communicator for channel reuse
217
+ # (SimpWorker already imported at function start)
218
+ communicator = context.communicator
219
+
212
220
  # Get or create SPUState for caching Runtime/Io
213
221
  spu_state = interpreter.get_dialect_state(SPUState.dialect_name)
214
222
  if not isinstance(spu_state, SPUState):
@@ -216,7 +224,12 @@ def exec_impl(interpreter: Interpreter, op: Operation, *args: Any) -> Any:
216
224
  interpreter.set_dialect_state(SPUState.dialect_name, spu_state)
217
225
 
218
226
  runtime, io = spu_state.get_or_create(
219
- local_rank, spu_world_size, config, spu_endpoints
227
+ local_rank,
228
+ spu_world_size,
229
+ config,
230
+ spu_endpoints,
231
+ communicator=communicator,
232
+ parties=list(parties),
220
233
  )
221
234
 
222
235
  executable_code = op.attrs["executable"]
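
As a worked example of the global-to-local rank mapping this op relies on (ranks are hypothetical): with SPU parties at global ranks (1, 2, 3), the worker at global rank 2 executes as SPU local rank 1 in a 3-party device.

    parties = (1, 2, 3)                           # global ranks in this SPU device
    my_global_rank = 2
    local_rank = parties.index(my_global_rank)    # -> 1
    spu_world_size = len(parties)                 # -> 3
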
@@ -20,7 +20,7 @@ multiple executions while binding to the Interpreter's lifecycle.
20
20
 
21
21
  from __future__ import annotations
22
22
 
23
- from typing import TYPE_CHECKING
23
+ from typing import TYPE_CHECKING, Any
24
24
 
25
25
  import spu.api as spu_api
26
26
  import spu.libspu as libspu
@@ -56,6 +56,8 @@ class SPUState(DialectState):
56
56
  spu_world_size: int,
57
57
  config: spu.SPUConfig,
58
58
  spu_endpoints: list[str] | None = None,
59
+ communicator: object | None = None,
60
+ parties: list[int] | None = None,
59
61
  ) -> tuple[spu_api.Runtime, spu_api.Io]:
60
62
  """Get or create SPU Runtime and Io for the given configuration.
61
63
 
@@ -64,13 +66,24 @@ class SPUState(DialectState):
64
66
  spu_world_size: The number of parties in the SPU device.
65
67
  config: SPU configuration including protocol settings.
66
68
  spu_endpoints: Optional list of BRPC endpoints. If None, use mem link.
69
+ communicator: Optional v2 communicator (ThreadCommunicator/HttpCommunicator).
70
+ If provided, use Channels mode to reuse existing communication.
71
+ parties: Optional list of global ranks for SPU parties.
72
+ Required when communicator is provided.
67
73
 
68
74
  Returns:
69
75
  A tuple of (Runtime, Io) for this party.
70
76
  """
71
77
  from mplang.v2.backends.spu_impl import to_runtime_config
72
78
 
73
- link_mode = "brpc" if spu_endpoints else "mem"
79
+ # Determine link mode
80
+ if communicator is not None:
81
+ link_mode = "channels"
82
+ elif spu_endpoints:
83
+ link_mode = "brpc"
84
+ else:
85
+ link_mode = "mem"
86
+
74
87
  cache_key = (
75
88
  local_rank,
76
89
  spu_world_size,
@@ -83,7 +96,13 @@ class SPUState(DialectState):
83
96
  return self._runtimes[cache_key]
84
97
 
85
98
  # Create Link
86
- if spu_endpoints:
99
+ if communicator is not None:
100
+ if parties is None:
101
+ raise ValueError("parties required when using communicator")
102
+ link = self._create_channels_link(
103
+ local_rank, spu_world_size, communicator, parties
104
+ )
105
+ elif spu_endpoints:
87
106
  link = self._create_brpc_link(local_rank, spu_endpoints)
88
107
  else:
89
108
  link = self._create_mem_link(local_rank, spu_world_size)
@@ -106,6 +125,50 @@ class SPUState(DialectState):
106
125
  desc.add_party(f"P{i}", f"mem:{i}")
107
126
  return libspu.link.create_mem(desc, local_rank)
108
127
 
128
+ def _create_channels_link(
129
+ self,
130
+ local_rank: int,
131
+ spu_world_size: int,
132
+ communicator: Any,
133
+ parties: list[int],
134
+ ) -> libspu.link.Context:
135
+ """Create link using custom channels (reuse v2 communicator).
136
+
137
+ Args:
138
+ local_rank: SPU local rank (0-indexed, already converted from global)
139
+ spu_world_size: Number of SPU parties
140
+ communicator: v2 communicator (ThreadCommunicator/HttpCommunicator)
141
+ parties: List of global ranks for SPU parties (ordered by local rank)
142
+
143
+ Returns:
144
+ libspu link context using BaseChannel adapters
145
+ """
146
+ from mplang.v2.backends.channel import BaseChannel
147
+
148
+ # Get this worker's global rank
149
+ global_rank = parties[local_rank]
150
+
151
+ # Create channels list (world_size elements, self = None)
152
+ channels = []
153
+ for idx, peer_global_rank in enumerate(parties):
154
+ if idx == local_rank:
155
+ # Self channel must be None
156
+ channel = None
157
+ else:
158
+ # Create channel to peer
159
+ channel = BaseChannel(communicator, global_rank, peer_global_rank)
160
+ channels.append(channel)
161
+
162
+ # Create link descriptor
163
+ desc = libspu.link.Desc() # type: ignore
164
+ desc.recv_timeout_ms = 100 * 1000 # 100 seconds
165
+
166
+ # Add party info (required for world_size inference)
167
+ for idx in range(spu_world_size):
168
+ desc.add_party(f"P{idx}", f"dummy_{parties[idx]}")
169
+
170
+ return libspu.link.create_with_channels(desc, local_rank, channels)
171
+
109
172
  def _create_brpc_link(
110
173
  self, local_rank: int, spu_endpoints: list[str]
111
174
  ) -> libspu.link.Context:
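
To make the channel layout concrete (ranks hypothetical, tuples standing in for real BaseChannel instances): for parties=[0, 2, 3] and local_rank=1, the list handed to create_with_channels has one slot per SPU party, with the self slot left as None.

    parties, local_rank = [0, 2, 3], 1
    global_rank = parties[local_rank]             # -> 2
    channels = [
        None if idx == local_rank else ("BaseChannel", global_rank, peer)
        for idx, peer in enumerate(parties)
    ]
    assert channels == [("BaseChannel", 2, 0), None, ("BaseChannel", 2, 3)]
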
mplang/v2/libs/ml/sgb.py CHANGED
@@ -1097,11 +1097,9 @@ def _update_tree_state(
1097
1097
 
1098
1098
  all_feats[party_idx] = simp.pcall_static(
1099
1099
  (party_rank,),
1100
- lambda pf=all_feats[party_idx],
1101
- bf=all_feats_level[party_idx],
1102
- ci=cur_indices_party,
1103
- op=owned_party_party,
1104
- il=is_leaf_party: tensor.run_jax(update_party_feats, pf, bf, ci, op, il),
1100
+ lambda pf=all_feats[party_idx], bf=all_feats_level[party_idx], ci=cur_indices_party, op=owned_party_party, il=is_leaf_party: (
1101
+ tensor.run_jax(update_party_feats, pf, bf, ci, op, il)
1102
+ ),
1105
1103
  )
1106
1104
 
1107
1105
  def update_party_thresholds(
@@ -1123,21 +1121,17 @@ def _update_tree_state(
1123
1121
 
1124
1122
  all_thresholds[party_idx] = simp.pcall_static(
1125
1123
  (party_rank,),
1126
- lambda pt=all_thresholds[party_idx],
1127
- b=all_bins[party_idx],
1128
- bf=all_feats_level[party_idx],
1129
- bt_idx=all_threshs_level[party_idx],
1130
- ci=cur_indices_party,
1131
- op=owned_party_party,
1132
- il=is_leaf_party: tensor.run_jax(
1133
- update_party_thresholds,
1134
- pt,
1135
- b,
1136
- bf,
1137
- bt_idx,
1138
- ci,
1139
- op,
1140
- il,
1124
+ lambda pt=all_thresholds[party_idx], b=all_bins[party_idx], bf=all_feats_level[party_idx], bt_idx=all_threshs_level[party_idx], ci=cur_indices_party, op=owned_party_party, il=is_leaf_party: (
1125
+ tensor.run_jax(
1126
+ update_party_thresholds,
1127
+ pt,
1128
+ b,
1129
+ bf,
1130
+ bt_idx,
1131
+ ci,
1132
+ op,
1133
+ il,
1134
+ )
1141
1135
  ),
1142
1136
  )
1143
1137
 
@@ -1152,13 +1146,8 @@ def _update_tree_state(
1152
1146
 
1153
1147
  tmp_bt = simp.pcall_static(
1154
1148
  (party_rank,),
1155
- lambda bi=all_bin_indices[party_idx],
1156
- bf=all_feats_level[party_idx],
1157
- bt_idx=all_threshs_level[party_idx],
1158
- bt_arr=bt_party,
1159
- bt_lv=bt_level_party,
1160
- il=is_leaf_party: tensor.run_jax(
1161
- update_bt, bt_arr, bt_lv, il, bi, bf, bt_idx
1149
+ lambda bi=all_bin_indices[party_idx], bf=all_feats_level[party_idx], bt_idx=all_threshs_level[party_idx], bt_arr=bt_party, bt_lv=bt_level_party, il=is_leaf_party: (
1150
+ tensor.run_jax(update_bt, bt_arr, bt_lv, il, bi, bf, bt_idx)
1162
1151
  ),
1163
1152
  )
1164
1153
 
@@ -1498,11 +1487,10 @@ def predict_tree(
1498
1487
  for i, rank in enumerate(all_ranks):
1499
1488
  mask = simp.pcall_static(
1500
1489
  (rank,),
1501
- lambda d=all_datas[i],
1502
- f=tree.feature[i],
1503
- t=tree.threshold[i],
1504
- idx=i: predict_tree_single_party(
1505
- d, f, t, tree.is_leaf, tree.owned_party_id, idx, n_nodes
1490
+ lambda d=all_datas[i], f=tree.feature[i], t=tree.threshold[i], idx=i: (
1491
+ predict_tree_single_party(
1492
+ d, f, t, tree.is_leaf, tree.owned_party_id, idx, n_nodes
1493
+ )
1506
1494
  ),
1507
1495
  )
1508
1496
  # Transfer to AP
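
The pf=..., ci=... default arguments kept in these reflowed lambdas are the standard Python idiom for freezing per-iteration values; a standalone illustration of why that binding is needed:

    late  = [lambda: i for i in range(3)]       # every lambda sees the final i
    bound = [lambda i=i: i for i in range(3)]   # default argument freezes each i
    assert [f() for f in late]  == [2, 2, 2]
    assert [f() for f in bound] == [0, 1, 2]
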
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mplang-nightly
3
- Version: 0.1.dev266
3
+ Version: 0.1.dev267
4
4
  Summary: Multi-Party Programming Language
5
5
  Author-email: SecretFlow Team <secretflow-contact@service.alipay.com>
6
6
  License: Apache License
@@ -219,7 +219,7 @@ Requires-Dist: pandas>=2.0.0
219
219
  Requires-Dist: protobuf<6.0,>=5.0
220
220
  Requires-Dist: pyarrow>=14.0.0
221
221
  Requires-Dist: pyyaml>=6.0
222
- Requires-Dist: spu>=0.10.0.dev20251208
222
+ Requires-Dist: spu>=0.10.0.dev20251211
223
223
  Requires-Dist: sqlglot>=23.0.0
224
224
  Requires-Dist: tenseal==0.3.16
225
225
  Requires-Dist: typing-extensions