mplang-nightly 0.1.dev268__py3-none-any.whl → 0.1.dev270__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181) hide show
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +46 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +19 -13
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/libs/mpc/psi/rr22.py +303 -0
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang/v2/libs/mpc/psi/rr22.py +0 -344
  162. mplang_nightly-0.1.dev268.dist-info/RECORD +0 -180
  163. /mplang/{v2/backends → backends}/channel.py +0 -0
  164. /mplang/{v2/edsl → edsl}/README.md +0 -0
  165. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  166. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  167. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  168. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  169. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  171. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  172. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  175. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  177. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  178. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  179. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
  180. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
  181. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
@@ -1,270 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """Core Session model (pure, no global registries).
16
-
17
- Contents:
18
- * SessionState dataclass
19
- * LinkCommFactory (SPU link reuse cache)
20
- * Session (topology derivation, runtime init, SPU env seeding, local symbol/computation storage)
21
-
22
- Process-wide registries (sessions, global symbols) live in the server layer
23
- (`server.py`) so this module remains portable and easy to unit test.
24
- """
25
-
26
- from __future__ import annotations
27
-
28
- import time
29
- from dataclasses import dataclass, field
30
- from functools import cached_property
31
- from typing import TYPE_CHECKING, Any, cast
32
-
33
- import spu.libspu as libspu
34
-
35
- from mplang.v1.core.cluster import ClusterSpec
36
- from mplang.v1.core.comm import ICommunicator
37
- from mplang.v1.core.expr.ast import Expr
38
- from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator
39
- from mplang.v1.core.mask import Mask
40
- from mplang.v1.kernels.context import RuntimeContext
41
- from mplang.v1.kernels.spu import PFunction # type: ignore
42
- from mplang.v1.kernels.value import Value
43
- from mplang.v1.runtime.communicator import HttpCommunicator
44
- from mplang.v1.runtime.exceptions import ResourceNotFound
45
- from mplang.v1.runtime.link_comm import LinkCommunicator
46
- from mplang.v1.utils.spu_utils import parse_field, parse_protocol
47
-
48
- if TYPE_CHECKING: # pragma: no cover - import only for type checking
49
- from mplang.v1.core.cluster import ClusterSpec, Node, RuntimeInfo
50
-
51
-
52
- @dataclass
53
- class Symbol:
54
- name: str
55
- mptype: Any
56
- data: Any
57
-
58
-
59
- @dataclass
60
- class Computation:
61
- name: str
62
- expr: Expr
63
-
64
-
65
- @dataclass
66
- class SessionState:
67
- runtime: RuntimeContext | None = None
68
- computations: dict[str, Computation] = field(default_factory=dict)
69
- symbols: dict[str, Symbol] = field(default_factory=dict)
70
- spu_seeded: bool = False
71
- created_ts: float = field(default_factory=time.time)
72
- last_access_ts: float = field(default_factory=time.time)
73
-
74
-
75
- class Session:
76
- """Represents the per-rank execution context.
77
-
78
- Immutable config: name, rank, cluster_spec, communicator.
79
- Derived: node, runtime_info, endpoints, spu_device, spu_mask, protocol/field, is_spu_party.
80
- Mutable: state (runtime object, symbols, computations, seeded flag).
81
-
82
- Note: communicator is assumed to be initialized with cluster spec info (e.g. endpoints).
83
- """
84
-
85
- def __init__(
86
- self,
87
- name: str,
88
- rank: int,
89
- cluster_spec: ClusterSpec,
90
- communicator: ICommunicator,
91
- ):
92
- self.name = name
93
- self.rank = rank
94
- self.cluster_spec = cluster_spec
95
- self.state = SessionState()
96
- self.communicator = communicator
97
-
98
- # --- Derived topology ---
99
- @cached_property
100
- def node(self) -> Node:
101
- return self.cluster_spec.get_node_by_rank(self.rank)
102
-
103
- @property
104
- def runtime_info(self) -> RuntimeInfo:
105
- return self.node.runtime_info
106
-
107
- @property
108
- def endpoints(self) -> list[str]:
109
- return self.cluster_spec.endpoints
110
-
111
- @cached_property
112
- def spu_device(self): # type: ignore
113
- devs = self.cluster_spec.get_devices_by_kind("SPU")
114
- if len(devs) != 1:
115
- raise RuntimeError(
116
- f"Expected exactly one SPU device, got {len(devs)} (session={self.name})"
117
- )
118
- return devs[0]
119
-
120
- @cached_property
121
- def spu_mask(self) -> Mask:
122
- return Mask.from_ranks([m.rank for m in self.spu_device.members])
123
-
124
- @property
125
- def spu_protocol(self) -> str:
126
- return cast(str, self.spu_device.config.get("protocol", "SEMI2K"))
127
-
128
- @property
129
- def spu_field(self) -> str:
130
- return cast(str, self.spu_device.config.get("field", "FM64"))
131
-
132
- @property
133
- def is_spu_party(self) -> bool:
134
- return self.rank in self.spu_mask
135
-
136
- # --- Runtime helpers ---
137
- def ensure_runtime(self) -> RuntimeContext:
138
- if self.state.runtime is None:
139
- self.state.runtime = RuntimeContext(
140
- rank=self.rank,
141
- world_size=len(self.cluster_spec.nodes), # type: ignore[attr-defined]
142
- initial_bindings=(
143
- self.runtime_info.op_bindings if self.runtime_info else {}
144
- ),
145
- )
146
- return self.state.runtime
147
-
148
- def ensure_spu_env(self) -> None:
149
- """Ensure SPU kernel env (config/world[/link]) registered on this runtime.
150
-
151
- Previous logic only seeded SPU parties; non-participating ranks then raised
152
- a hard error when the evaluator encountered SPU ops in the global program,
153
- because the kernel pocket lacked config/world. For now we register the
154
- config/world on ALL parties (idempotent) and only attach a link context for
155
- participating SPU ranks. Non-parties will still error later if they try to
156
- execute a link-dependent SPU kernel (which should be guarded by masks in the
157
- IR), but they will no longer fail early with a misleading
158
- "SPU kernel state not initialized" message.
159
- """
160
- if self.state.spu_seeded:
161
- return
162
-
163
- link_ctx = None
164
-
165
- if self.is_spu_party:
166
- # Use Channels mode to reuse existing HttpCommunicator
167
- # This eliminates the need for separate BRPC ports (SPU_PORT_OFFSET)
168
- from mplang.v1.core.comm import CommunicatorBase
169
-
170
- # Type assertion: ICommunicator is actually CommunicatorBase
171
- comm = cast(CommunicatorBase, self.communicator)
172
- link_ctx = LinkCommunicator(
173
- rank=self.rank,
174
- comm=comm,
175
- spu_mask=self.spu_mask,
176
- )
177
-
178
- spu_config = libspu.RuntimeConfig(
179
- protocol=parse_protocol(self.spu_protocol),
180
- field=parse_field(self.spu_field),
181
- fxp_fraction_bits=18,
182
- )
183
- seed_pfunc = PFunction(
184
- fn_type="spu.seed_env",
185
- ins_info=(),
186
- outs_info=(),
187
- config=spu_config,
188
- world=self.spu_mask.num_parties(),
189
- link=link_ctx,
190
- )
191
- self.ensure_runtime().run_kernel(seed_pfunc, [])
192
- self.state.spu_seeded = True
193
-
194
- # --- Computations & Symbols (instance-local) ---
195
- def add_computation(self, computation: Computation) -> None:
196
- self.state.computations[computation.name] = computation
197
-
198
- def get_computation(self, name: str) -> Computation | None:
199
- return self.state.computations.get(name)
200
-
201
- def add_symbol(self, symbol: Symbol) -> None:
202
- self.state.symbols[symbol.name] = symbol
203
-
204
- def get_symbol(self, name: str) -> Symbol | None:
205
- return self.state.symbols.get(name)
206
-
207
- def list_symbols(self) -> list[str]: # pragma: no cover - trivial
208
- return list(self.state.symbols.keys())
209
-
210
- def delete_symbol(self, name: str) -> bool:
211
- if name in self.state.symbols:
212
- del self.state.symbols[name]
213
- return True
214
- return False
215
-
216
- def list_computations(self) -> list[str]: # pragma: no cover - trivial
217
- return list(self.state.computations.keys())
218
-
219
- def delete_computation(self, name: str) -> bool:
220
- if name in self.state.computations:
221
- del self.state.computations[name]
222
- return True
223
- return False
224
-
225
- # --- Execution ---
226
- def execute(
227
- self, computation: Computation, input_names: list[str], output_names: list[str]
228
- ) -> None:
229
- env: dict[str, Any] = {}
230
- for in_name in input_names:
231
- sym = self.get_symbol(in_name)
232
- if sym is None:
233
- raise ResourceNotFound(
234
- f"Input symbol '{in_name}' not found in session '{self.name}'"
235
- )
236
- env[in_name] = sym.data
237
- rt = self.ensure_runtime()
238
- self.ensure_spu_env()
239
- evaluator: IEvaluator = create_evaluator(
240
- rank=self.rank, env=env, comm=self.communicator, runtime=rt
241
- )
242
- results = evaluator.evaluate(computation.expr)
243
- if results and len(results) != len(output_names):
244
- raise RuntimeError(
245
- f"Expected {len(output_names)} results, got {len(results)}"
246
- )
247
- for name, val in zip(output_names, results, strict=True):
248
- # In pure SIMP model, all nodes should have the same symbol table.
249
- # Non-participating nodes get None values.
250
- if val is not None and not isinstance(val, Value):
251
- raise TypeError(
252
- "Session executions must produce kernel Value outputs; "
253
- f"got {type(val).__name__} for symbol '{name}'"
254
- )
255
- self.add_symbol(Symbol(name=name, mptype={}, data=val))
256
-
257
-
258
- # --- Convenience constructor use HttpCommunicator---
259
- def create_session_from_spec(name: str, rank: int, spec: ClusterSpec) -> Session:
260
- if len(spec.get_devices_by_kind("SPU")) == 0:
261
- raise RuntimeError("No SPU device found in cluster_spec")
262
-
263
- # Create HttpCommunicator for the session
264
- communicator = HttpCommunicator(
265
- session_name=name,
266
- rank=rank,
267
- endpoints=spec.endpoints,
268
- )
269
-
270
- return Session(name=name, rank=rank, cluster_spec=spec, communicator=communicator)
@@ -1,324 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- import concurrent.futures
18
- import faulthandler
19
- import logging
20
- import sys
21
- import threading
22
- import traceback
23
- from collections.abc import Sequence
24
- from typing import Any, cast
25
-
26
- import spu.libspu as libspu
27
-
28
- from mplang.v1.core import (
29
- ClusterSpec,
30
- CollectiveMixin,
31
- CommunicatorBase,
32
- InterpContext,
33
- InterpVar,
34
- IrReader,
35
- IrWriter,
36
- Mask,
37
- MPObject,
38
- MPType,
39
- PFunction, # for spu.seed_env kernel seeding
40
- TensorLike,
41
- )
42
- from mplang.v1.core.expr.ast import Expr
43
- from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator
44
- from mplang.v1.kernels.context import RuntimeContext
45
- from mplang.v1.runtime.link_comm import LinkCommunicator
46
- from mplang.v1.utils.spu_utils import parse_field, parse_protocol
47
-
48
-
49
- class ThreadCommunicator(CommunicatorBase, CollectiveMixin):
50
- """Thread-based communicator for in-memory communication between threads"""
51
-
52
- def __init__(self, rank: int, world_size: int):
53
- super().__init__(rank, world_size)
54
- self.peers: list[ThreadCommunicator] = []
55
- logging.debug(
56
- f"ThreadCommunicator initialized with rank={self.rank}, world_size={self.world_size}"
57
- )
58
-
59
- def set_peers(self, peers: list[ThreadCommunicator]) -> None:
60
- assert self.world_size == len(peers)
61
- self.peers = peers
62
-
63
- def send(self, to: int, key: str, data: Any) -> None:
64
- assert 0 <= to < self.world_size
65
- # print(f"send {key}: {self.rank} -> {to_rank}")
66
- self.peers[to].onSent(self.rank, key, data)
67
-
68
-
69
- class SimVar(InterpVar):
70
- """A variable that references a value in an interpreter.
71
-
72
- SimVar represents a value that has been computed and exists
73
- in the interpreter's variable store.
74
- """
75
-
76
- def __init__(self, ctx: Simulator, mptype: MPType, values: list[Any]):
77
- # Initialize the parent InterpVar with a generated name
78
- super().__init__(ctx, mptype)
79
- self._values = values
80
-
81
- @property
82
- def values(self) -> list[Any]:
83
- """Converted values across all ranks for user inspection."""
84
- return [v.to_numpy() if hasattr(v, "to_numpy") else v for v in self._values]
85
-
86
- def __repr__(self) -> str:
87
- return f"SimVar({self.mptype})"
88
-
89
-
90
- class Simulator(InterpContext):
91
- def __init__(
92
- self,
93
- cluster_spec: ClusterSpec,
94
- *,
95
- trace_ranks: list[int] | None = None,
96
- ) -> None:
97
- """Initialize a simulator with the given cluster specification.
98
-
99
- Args:
100
- cluster_spec: The cluster specification defining the simulation environment.
101
- trace_ranks: List of ranks to trace execution for debugging.
102
- Per-node op binding overrides should now be provided via
103
- each node's `runtime_info.op_bindings` in the supplied
104
- `cluster_spec`.
105
- """
106
- super().__init__(cluster_spec)
107
- self._trace_ranks = trace_ranks or []
108
-
109
- spu_devices = cluster_spec.get_devices_by_kind("SPU")
110
- if not spu_devices:
111
- raise ValueError("No SPU device found in the cluster specification")
112
- if len(spu_devices) > 1:
113
- raise ValueError("Multiple SPU devices found in the cluster specification")
114
- spu_device = spu_devices[0]
115
-
116
- # compute spu_mask from spu_device members
117
- spu_mask = Mask.from_ranks([member.rank for member in spu_device.members])
118
-
119
- # Convert protocol and field from config using utility functions
120
- spu_protocol = parse_protocol(spu_device.config["protocol"])
121
- spu_field = parse_field(spu_device.config["field"])
122
-
123
- world_size = self.world_size()
124
-
125
- # Setup communicators
126
- self._comms = [
127
- ThreadCommunicator(rank, world_size) for rank in range(world_size)
128
- ]
129
- for comm in self._comms:
130
- comm.set_peers(self._comms)
131
-
132
- # Prepare link contexts for SPU parties (store for evaluator-time initialization)
133
- # Use Channels mode to reuse ThreadCommunicator instead of separate mem_link
134
- self._spu_link_ctxs: list[LinkCommunicator | None] = [None] * world_size
135
-
136
- # Create LinkCommunicators in parallel to avoid deadlock
137
- # (create_with_channels does handshake via TestSend/TestRecv)
138
- exceptions: dict[int, Exception] = {}
139
-
140
- def create_link(g_rank: int) -> None:
141
- try:
142
- self._spu_link_ctxs[g_rank] = LinkCommunicator(
143
- rank=g_rank,
144
- comm=self._comms[g_rank],
145
- spu_mask=spu_mask,
146
- )
147
- except Exception as e:
148
- exceptions[g_rank] = e
149
-
150
- threads = [
151
- threading.Thread(target=create_link, args=(g_rank,)) for g_rank in spu_mask
152
- ]
153
- for t in threads:
154
- t.start()
155
- for t in threads:
156
- t.join()
157
-
158
- # Check for exceptions during link creation
159
- if exceptions:
160
- first_exc = next(iter(exceptions.values()))
161
- raise RuntimeError(
162
- f"Failed to create SPU link contexts for ranks {list(exceptions.keys())}"
163
- ) from first_exc
164
-
165
- self._spu_runtime_cfg = libspu.RuntimeConfig(
166
- protocol=spu_protocol, field=spu_field
167
- )
168
- self._spu_world = spu_mask.num_parties()
169
- self._spu_mask = spu_mask
170
-
171
- # Persistent per-rank RuntimeContext instances (reused across evaluates).
172
- # We no longer pre-create evaluators since each evaluate has different env bindings.
173
- # Build per-rank runtime contexts.
174
- self._runtimes: list[RuntimeContext] = []
175
- for rank in range(self.world_size()):
176
- node = self.cluster_spec.get_node_by_rank(rank)
177
- rt = RuntimeContext(
178
- rank=rank,
179
- world_size=self.world_size(),
180
- initial_bindings=node.runtime_info.op_bindings,
181
- )
182
- self._runtimes.append(rt)
183
-
184
- @classmethod
185
- def simple(
186
- cls,
187
- world_size: int,
188
- op_bindings: dict[str, str] | None = None,
189
- **kwargs: Any,
190
- ) -> Simulator:
191
- """Create a simple simulator with the given number of parties.
192
-
193
- This is a convenience method that creates a ClusterSpec.simple()
194
- configuration for quick testing and prototyping.
195
-
196
- Args:
197
- world_size: Number of simulated parties.
198
- **kwargs: Additional arguments passed to the Simulator constructor.
199
-
200
- Returns:
201
- A Simulator instance with a simple cluster configuration.
202
- """
203
- cluster_spec = ClusterSpec.simple(world_size)
204
- if op_bindings:
205
- # Apply the same op_bindings to every node's runtime_info for convenience
206
- for node in cluster_spec.nodes.values():
207
- node.runtime_info.op_bindings.update(op_bindings)
208
- return cls(cluster_spec, **kwargs)
209
-
210
- def _do_evaluate(self, expr: Expr, evaluator_engine: IEvaluator) -> Any:
211
- """
212
- Helper function to simulate real-world MPIR serialization/deserialization
213
- process instead of direct expr.accept execution.
214
-
215
- This exposes potential MPIR serialization bugs by forcing expressions
216
- to go through the full serialize->deserialize cycle.
217
- """
218
- writer = IrWriter()
219
- graph_proto = writer.dumps(expr)
220
-
221
- reader = IrReader()
222
- deserialized_expr = reader.loads(graph_proto)
223
-
224
- if deserialized_expr is None:
225
- raise ValueError("Failed to deserialize expression")
226
-
227
- return evaluator_engine.evaluate(deserialized_expr)
228
-
229
- # override
230
- def fetch(self, obj: MPObject) -> list[TensorLike]:
231
- if not isinstance(obj, SimVar):
232
- raise ValueError(f"Expected SimVar, got {type(obj)}")
233
- return [v.to_numpy() if hasattr(v, "to_numpy") else v for v in obj._values]
234
-
235
- # override
236
- def evaluate(self, expr: Expr, bindings: dict[str, MPObject]) -> Sequence[MPObject]:
237
- # sanity check for bindings.
238
- for name, var in bindings.items():
239
- if var.ctx is not self:
240
- raise ValueError(f"Variable {name} not in this context, got {var.ctx}.")
241
-
242
- pts_env = [
243
- {name: cast(SimVar, var)._values[rank] for name, var in bindings.items()}
244
- for rank in range(self.world_size())
245
- ]
246
-
247
- # Build per-rank evaluators with the per-party environment (runtime reused)
248
- pts_evaluators: list[IEvaluator] = []
249
- for rank in range(self.world_size()):
250
- runtime = self._runtimes[rank]
251
- ev = create_evaluator(
252
- rank,
253
- pts_env[rank],
254
- self._comms[rank],
255
- runtime,
256
- None,
257
- )
258
- # Seed SPU once per runtime (idempotent logical requirement)
259
- # Use setdefault to both retrieve and create metadata dict in one step.
260
- spu_meta = runtime.state.setdefault("_spu", {})
261
- if not spu_meta.get("inited", False):
262
- link_ctx = self._spu_link_ctxs[rank]
263
- seed_fn = PFunction(
264
- fn_type="spu.seed_env",
265
- ins_info=(),
266
- outs_info=(),
267
- config=self._spu_runtime_cfg,
268
- world=self._spu_world,
269
- link=link_ctx,
270
- )
271
- ev.runtime.run_kernel(seed_fn, []) # type: ignore[arg-type]
272
- spu_meta["inited"] = True
273
- pts_evaluators.append(ev)
274
-
275
- # Collect evaluation results from all parties
276
- pts_results: list[Any] = []
277
-
278
- with concurrent.futures.ThreadPoolExecutor() as executor:
279
- futures = [
280
- executor.submit(self._do_evaluate, expr, evaluator)
281
- for evaluator in pts_evaluators
282
- ]
283
-
284
- # Collect results with proper exception handling
285
- for i, future in enumerate(futures):
286
- try:
287
- result = future.result(100) # 100 second timeout
288
- pts_results.append(result)
289
- except concurrent.futures.TimeoutError:
290
- faulthandler.dump_traceback(file=sys.stderr, all_threads=True)
291
- raise
292
- except Exception as e:
293
- print(
294
- f"Exception in party {i}: {type(e).__name__}: {e}",
295
- file=sys.stderr,
296
- )
297
- traceback.print_exc(file=sys.stderr)
298
- executor.shutdown(wait=False, cancel_futures=True)
299
- raise
300
-
301
- # Convert results to SimVar objects
302
- # pts_results is a list of party results, where each party result is a list of values
303
- # We need to transpose this to get (n_outputs, n_parties) structure
304
- assert len(pts_results) == self.world_size()
305
-
306
- # Ensure all parties returned the same number of outputs (matrix validation)
307
- if pts_results and not all(
308
- len(row) == len(pts_results[0]) for row in pts_results
309
- ):
310
- raise ValueError("Inconsistent number of outputs across parties")
311
-
312
- # Transpose: (n_parties, n_outputs) -> (n_outputs, n_parties)
313
- output_values = list(zip(*pts_results, strict=False))
314
-
315
- # Get the output types from the expression
316
- output_types = expr.mptypes
317
-
318
- # Create SimVar objects for each output
319
- sim_vars = []
320
- for values, mptype in zip(output_values, output_types, strict=False):
321
- sim_var = SimVar(self, mptype, list(values))
322
- sim_vars.append(sim_var)
323
-
324
- return sim_vars
@@ -1,13 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.