mplang-nightly 0.1.dev254__py3-none-any.whl → 0.1.dev255__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -27,6 +27,7 @@ import spu.api as spu_api
27
27
  import spu.libspu as libspu
28
28
 
29
29
  from mplang.v2.backends.simp_worker import SimpWorker
30
+ from mplang.v2.backends.spu_state import SPUState
30
31
  from mplang.v2.backends.tensor_impl import TensorValue
31
32
  from mplang.v2.dialects import spu
32
33
  from mplang.v2.edsl import serde
@@ -107,84 +108,6 @@ def to_runtime_config(config: spu.SPUConfig) -> libspu.RuntimeConfig:
107
108
  return runtime_config
108
109
 
109
110
 
110
- # Global cache for SPU runtimes per (local_rank, world_size) pair
111
- # Key: (local_rank, spu_world_size, protocol, field, link_mode), Value: (Runtime, Io)
112
- _SPU_RUNTIMES: dict[
113
- tuple[int, int, str, str, str], tuple[spu_api.Runtime, spu_api.Io]
114
- ] = {}
115
-
116
-
117
- def _create_mem_link(local_rank: int, spu_world_size: int) -> libspu.link.Context:
118
- """Create in-memory link for simulation."""
119
- desc = libspu.link.Desc() # type: ignore
120
- desc.recv_timeout_ms = 30 * 1000
121
- for i in range(spu_world_size):
122
- desc.add_party(f"P{i}", f"mem:{i}")
123
- return libspu.link.create_mem(desc, local_rank)
124
-
125
-
126
- def _create_brpc_link(local_rank: int, spu_endpoints: list[str]) -> libspu.link.Context:
127
- """Create BRPC link for distributed execution.
128
-
129
- Args:
130
- local_rank: The local rank within the SPU device (0-indexed).
131
- spu_endpoints: List of BRPC endpoints for all SPU parties.
132
-
133
- Returns:
134
- A libspu.link.Context for BRPC communication.
135
- """
136
- desc = libspu.link.Desc() # type: ignore
137
- desc.recv_timeout_ms = 100 * 1000 # 100 seconds
138
- desc.http_max_payload_size = 32 * 1024 * 1024 # 32MB
139
-
140
- for i, endpoint in enumerate(spu_endpoints):
141
- desc.add_party(f"P{i}", endpoint)
142
-
143
- return libspu.link.create_brpc(desc, local_rank)
144
-
145
-
146
- def _get_spu_ctx(
147
- local_rank: int,
148
- spu_world_size: int,
149
- config: spu.SPUConfig,
150
- spu_endpoints: list[str] | None = None,
151
- ) -> tuple[spu_api.Runtime, spu_api.Io]:
152
- """Get or create SPU runtime and IO for the given local rank within SPU.
153
-
154
- Args:
155
- local_rank: The local rank within the SPU device (0-indexed).
156
- spu_world_size: The number of parties in the SPU device.
157
- config: SPU configuration including protocol settings.
158
- spu_endpoints: Optional list of BRPC endpoints. If None, use mem link.
159
-
160
- Returns:
161
- A tuple of (Runtime, Io) for this party.
162
- """
163
- # Determine link mode
164
- link_mode = "brpc" if spu_endpoints else "mem"
165
-
166
- # Include protocol, field, and link_mode in cache key
167
- cache_key = (local_rank, spu_world_size, config.protocol, config.field, link_mode)
168
- if cache_key in _SPU_RUNTIMES:
169
- return _SPU_RUNTIMES[cache_key]
170
-
171
- # Create Link
172
- if spu_endpoints:
173
- link = _create_brpc_link(local_rank, spu_endpoints)
174
- else:
175
- link = _create_mem_link(local_rank, spu_world_size)
176
-
177
- # Use config from SPUConfig
178
- runtime_config = to_runtime_config(config)
179
-
180
- # Create Runtime and Io
181
- runtime = spu_api.Runtime(link, runtime_config)
182
- io = spu_api.Io(spu_world_size, runtime_config)
183
-
184
- _SPU_RUNTIMES[cache_key] = (runtime, io)
185
- return runtime, io
186
-
187
-
188
111
  @spu.makeshares_p.def_impl
189
112
  def makeshares_impl(
190
113
  interpreter: Interpreter, op: Operation, data: TensorValue
@@ -293,7 +216,15 @@ def exec_impl(interpreter: Interpreter, op: Operation, *args: Any) -> Any:
293
216
  )
294
217
  spu_endpoints.append(spu_endpoints_map[party_rank])
295
218
 
296
- runtime, io = _get_spu_ctx(local_rank, spu_world_size, config, spu_endpoints)
219
+ # Get or create SPUState for caching Runtime/Io
220
+ spu_state = interpreter.get_dialect_state(SPUState.dialect_name)
221
+ if not isinstance(spu_state, SPUState):
222
+ spu_state = SPUState()
223
+ interpreter.set_dialect_state(SPUState.dialect_name, spu_state)
224
+
225
+ runtime, io = spu_state.get_or_create(
226
+ local_rank, spu_world_size, config, spu_endpoints
227
+ )
297
228
 
298
229
  executable_code = op.attrs["executable"]
299
230
  input_names = op.attrs["input_names"]
@@ -0,0 +1,124 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """SPU Dialect State.
16
+
17
+ Manages SPU Runtime lifecycle as a dialect state, enabling reuse across
18
+ multiple executions while binding to the Interpreter's lifecycle.
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ from typing import TYPE_CHECKING
24
+
25
+ import spu.api as spu_api
26
+ import spu.libspu as libspu
27
+
28
+ from mplang.v2.runtime.dialect_state import DialectState
29
+
30
+ if TYPE_CHECKING:
31
+ from mplang.v2.dialects import spu
32
+
33
+
34
class SPUState(DialectState):
    """SPU Runtime cache as dialect state.

    Caches SPU ``Runtime`` and ``Io`` objects so that repeated SPU kernel
    executions on the same interpreter reuse an already-created link instead
    of building a new one each time.

    Cache key: ``(local_rank, world_size, protocol, field, link_mode,
    endpoints)``. The endpoint list is part of the key so that two SPU
    devices which share a local rank and world size but talk to different
    peers (BRPC mode) never share a link.

    This replaces the previous global ``_SPU_RUNTIMES`` cache with a
    lifecycle-managed dialect state.
    """

    dialect_name: str = "spu"

    # Link tuning knobs. They are class-level so a subclass can override them
    # without touching the link-creation logic; the values match the
    # previously hard-coded ones.
    MEM_RECV_TIMEOUT_MS: int = 30 * 1000  # 30 seconds
    BRPC_RECV_TIMEOUT_MS: int = 100 * 1000  # 100 seconds
    BRPC_MAX_PAYLOAD_SIZE: int = 32 * 1024 * 1024  # 32MB

    def __init__(self) -> None:
        # Key: (local_rank, world_size, protocol, field, link_mode, endpoints)
        # Value: (Runtime, Io)
        self._runtimes: dict[
            tuple[int, int, str, str, str, tuple[str, ...]],
            tuple[spu_api.Runtime, spu_api.Io],
        ] = {}

    def get_or_create(
        self,
        local_rank: int,
        spu_world_size: int,
        config: spu.SPUConfig,
        spu_endpoints: list[str] | None = None,
    ) -> tuple[spu_api.Runtime, spu_api.Io]:
        """Get or create SPU Runtime and Io for the given configuration.

        Args:
            local_rank: The local rank within the SPU device (0-indexed).
            spu_world_size: The number of parties in the SPU device.
            config: SPU configuration including protocol settings.
            spu_endpoints: Optional list of BRPC endpoints. If None or empty,
                an in-memory link is used.

        Returns:
            A tuple of (Runtime, Io) for this party. Repeated calls with the
            same rank, world size, protocol, field and endpoints return the
            same cached objects.
        """
        # Imported lazily: spu_impl imports this module at load time, so a
        # top-level import here would be circular.
        from mplang.v2.backends.spu_impl import to_runtime_config

        link_mode = "brpc" if spu_endpoints else "mem"
        # Endpoints identify the peers the link connects to; without them in
        # the key, a different device with the same rank/world size would
        # reuse a link wired to the wrong parties. A tuple keeps the key
        # hashable; mem mode has no endpoints.
        endpoints_key = tuple(spu_endpoints) if spu_endpoints else ()
        # NOTE(review): only ``protocol`` and ``field`` of ``config`` take part
        # in the key. If to_runtime_config reads further SPUConfig fields,
        # configs differing only in those fields will share a Runtime; confirm.
        cache_key = (
            local_rank,
            spu_world_size,
            config.protocol,
            config.field,
            link_mode,
            endpoints_key,
        )

        cached = self._runtimes.get(cache_key)
        if cached is not None:
            return cached

        # Create Link
        if spu_endpoints:
            link = self._create_brpc_link(local_rank, spu_endpoints)
        else:
            link = self._create_mem_link(local_rank, spu_world_size)

        # Create Runtime and Io from the same runtime config.
        runtime_config = to_runtime_config(config)
        runtime = spu_api.Runtime(link, runtime_config)
        io = spu_api.Io(spu_world_size, runtime_config)

        self._runtimes[cache_key] = (runtime, io)
        return runtime, io

    def _create_mem_link(
        self, local_rank: int, spu_world_size: int
    ) -> libspu.link.Context:
        """Create in-memory link for simulation.

        Parties are registered as ``P{i}`` at address ``mem:{i}`` for every
        rank in ``range(spu_world_size)``.
        """
        desc = libspu.link.Desc()  # type: ignore
        desc.recv_timeout_ms = self.MEM_RECV_TIMEOUT_MS
        for i in range(spu_world_size):
            desc.add_party(f"P{i}", f"mem:{i}")
        return libspu.link.create_mem(desc, local_rank)

    def _create_brpc_link(
        self, local_rank: int, spu_endpoints: list[str]
    ) -> libspu.link.Context:
        """Create BRPC link for distributed execution.

        Party ``P{i}`` is bound to ``spu_endpoints[i]``; the list order
        therefore defines the party ranks.
        """
        desc = libspu.link.Desc()  # type: ignore
        desc.recv_timeout_ms = self.BRPC_RECV_TIMEOUT_MS
        desc.http_max_payload_size = self.BRPC_MAX_PAYLOAD_SIZE

        for i, endpoint in enumerate(spu_endpoints):
            desc.add_party(f"P{i}", endpoint)

        return libspu.link.create_brpc(desc, local_rank)

    def shutdown(self) -> None:
        """Clear all cached runtimes.

        Only drops the cache's references. Links are not closed explicitly
        here (NOTE(review): confirm libspu releases them when the Runtime
        objects are destroyed).
        """
        self._runtimes.clear()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mplang-nightly
3
- Version: 0.1.dev254
3
+ Version: 0.1.dev255
4
4
  Summary: Multi-Party Programming Language
5
5
  Author-email: SecretFlow Team <secretflow-contact@service.alipay.com>
6
6
  License: Apache License
@@ -87,7 +87,8 @@ mplang/v2/backends/field_impl.py,sha256=50sKGOlkUiaTj_IAola86uQeoi-fxV0o7G91BdTC
87
87
  mplang/v2/backends/func_impl.py,sha256=R0662cC0gSSfkjuLyevJ_g4bJDJirY76LTFYqEimCkE,3585
88
88
  mplang/v2/backends/phe_impl.py,sha256=r836e_qBHGrHhfnFail5IaUDzvS7bABjdEQmJmAtBVI,4127
89
89
  mplang/v2/backends/simp_design.md,sha256=CXvfxrvV1TmKlFm8IbKTbcHHwLl6AhwlY_cNqMdff_Y,5250
90
- mplang/v2/backends/spu_impl.py,sha256=pMjdD8_wMs1scSoJqsnZnKRrPbkfcCu-U-hxyz-EN_0,11757
90
+ mplang/v2/backends/spu_impl.py,sha256=gKyueQZQXRQhJ_7q3EQ74ItJntzeFdgTnPtU2mJRqF8,9466
91
+ mplang/v2/backends/spu_state.py,sha256=wj876IvNPhKyWISN6WwKBYoaDQFFJ8jemdJUVeH5IfA,4144
91
92
  mplang/v2/backends/store_impl.py,sha256=RyhADTNsnnNnwsatAMr7eeewXkVXtfNWA1oFiLXg8H0,2222
92
93
  mplang/v2/backends/table_impl.py,sha256=c36gyBCWLQbV3g0hkJeTnMXUqT0nxgu74k2sLondTio,8784
93
94
  mplang/v2/backends/tee_impl.py,sha256=5PzzQ6mibd6-Wyvvt_8DD6G-CzA4VAmqGk6H1Z9risI,6986
@@ -169,8 +170,8 @@ mplang/v2/runtime/dialect_state.py,sha256=HxO1i4kSOujS2tQzAF9-WmI3nChSaGgupf2_07
169
170
  mplang/v2/runtime/interpreter.py,sha256=UzrM5oepka6H0YKRZncNXhsuwKVm4pliG5J92fFRZMI,32300
170
171
  mplang/v2/runtime/object_store.py,sha256=yT6jtKG2GUEJVmpq3gnQ8mCMvUFYzgBciC5A-J5KRdk,5998
171
172
  mplang/v2/runtime/value.py,sha256=CMOxElJP78v7pjasPhEpbxWbSgB2KsLbpPmzz0mQX0E,4317
172
- mplang_nightly-0.1.dev254.dist-info/METADATA,sha256=ZcoLTH8VISCxJm5lA7S8eZBNSF5vEYWphlx8Ni9fO38,16768
173
- mplang_nightly-0.1.dev254.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
174
- mplang_nightly-0.1.dev254.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
175
- mplang_nightly-0.1.dev254.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
176
- mplang_nightly-0.1.dev254.dist-info/RECORD,,
173
+ mplang_nightly-0.1.dev255.dist-info/METADATA,sha256=k-_Pe_IksZD0UsPf-oPxMv_FlGNEKKfWwz4BjCVqC00,16768
174
+ mplang_nightly-0.1.dev255.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
175
+ mplang_nightly-0.1.dev255.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
176
+ mplang_nightly-0.1.dev255.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
177
+ mplang_nightly-0.1.dev255.dist-info/RECORD,,