mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,352 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ This module provides the formal data structures and parsing logic for the
17
+ MPLang cluster configuration.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ from dataclasses import dataclass, field
23
+ from functools import cached_property
24
+ from typing import Any
25
+
26
+
27
@dataclass(frozen=True)
class RuntimeInfo:
    """Per-physical-node runtime configuration.

    ``op_bindings`` is a per-node override map (logical_op -> kernel_id) merged
    into that node's ``RuntimeContext``. Unknown future / auxiliary fields are
    preserved in ``extra``.
    """

    version: str
    platform: str
    # Per-node partial override dispatch table (merged over project defaults).
    op_bindings: dict[str, str] = field(default_factory=dict)

    # A catch-all for any other custom or future properties (must not collide
    # with reserved keys: version, platform, op_bindings).
    extra: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Convert RuntimeInfo to a flat dictionary (stable field names).

        The reserved keys ``version``, ``platform`` and ``op_bindings`` always
        reflect the dataclass fields. An ``extra`` entry that reuses one of
        those names is dropped rather than silently overwriting the field.
        All other ``extra`` entries follow the reserved keys. ``op_bindings``
        is copied, so mutating the returned dict cannot alter this frozen
        instance.

        Returns:
            A new dict containing the reserved fields plus ``extra`` entries.
        """
        result: dict[str, Any] = {
            "version": self.version,
            "platform": self.platform,
            "op_bindings": dict(self.op_bindings),
        }
        for key, value in self.extra.items():
            # setdefault keeps the reserved value when a key collides.
            result.setdefault(key, value)
        return result
54
+
55
+
56
@dataclass(frozen=True)
class Node:
    """A single physical node (PN) of the cluster.

    Instances are immutable descriptions of one compute resource: its unique
    ``name``, its party ``rank``, the ``endpoint`` it is reachable at, and its
    per-node ``runtime_info``.
    """

    name: str
    rank: int
    endpoint: str
    runtime_info: RuntimeInfo

    def to_dict(self) -> dict[str, Any]:
        """Serialize this node to a plain dictionary.

        The output has the keys ``name``, ``endpoint`` and ``runtime_info``
        (the latter produced by ``runtime_info.to_dict()``). ``rank`` is not
        part of the output.
        """
        serialized: dict[str, Any] = dict(name=self.name, endpoint=self.endpoint)
        serialized["runtime_info"] = self.runtime_info.to_dict()
        return serialized
75
+
76
+
77
@dataclass(frozen=True)
class Device:
    """A logical device (LD): the user-facing unit of computation.

    A device has a ``kind`` string, is composed of one or more physical nodes
    (``members``), and carries a free-form ``config`` mapping.
    """

    name: str
    kind: str
    members: list[Node]
    config: dict[str, Any] = field(default_factory=dict)

    @property
    def member_ranks(self) -> tuple[int, ...]:
        """Ranks of the member nodes, in ascending order."""
        ranks = [member.rank for member in self.members]
        ranks.sort()
        return tuple(ranks)

    def to_dict(self) -> dict[str, Any]:
        """Serialize this device; members are referenced by node name.

        The returned ``config`` is the same object held by this device.
        """
        member_names = [member.name for member in self.members]
        return dict(kind=self.kind, members=member_names, config=self.config)
101
+
102
+
103
@dataclass(frozen=True)
class ClusterSpec:
    """
    The formal, validated representation of the entire cluster.
    This object is the "first-class citizen" representing the cluster topology.

    ``__post_init__`` enforces that:

    * every key of ``nodes`` / ``devices`` equals the value's ``name``;
    * no two nodes share a rank;
    * every device member is a node named in ``nodes``;
    * every device whose ``kind`` is ``"ppu"`` (case-insensitive) has exactly
      one member.
    """

    cluster_id: str
    nodes: dict[str, Node]
    devices: dict[str, Device]

    @property
    def world_size(self) -> int:
        """Total number of physical nodes (parties)."""
        return len(self.nodes)

    def __post_init__(self) -> None:
        """Validate internal consistency; raises ``ValueError`` on violation."""
        for key, node in self.nodes.items():
            if key != node.name:
                raise ValueError(
                    f"Node key '{key}' does not match node.name '{node.name}'"
                )

        # Ranks identify parties. ``endpoints`` orders by rank and
        # ``get_node_by_rank`` returns the first match, so a duplicate rank
        # would silently alias two nodes.
        rank_owner: dict[int, str] = {}
        for node in self.nodes.values():
            if node.rank in rank_owner:
                raise ValueError(
                    f"Duplicate rank {node.rank} for nodes "
                    f"'{rank_owner[node.rank]}' and '{node.name}'"
                )
            rank_owner[node.rank] = node.name

        for key, device in self.devices.items():
            if key != device.name:
                raise ValueError(
                    f"Device key '{key}' does not match device.name '{device.name}'"
                )

        # check all device members are valid nodes
        node_names = set(self.nodes.keys())
        for device in self.devices.values():
            for member in device.members:
                if member.name not in node_names:
                    raise ValueError(
                        f"Device '{device.name}' has member '{member.name}' "
                        "which is not defined in nodes"
                    )

        # ensure ppu devices have exactly one member
        for device in self.devices.values():
            if device.kind.lower() == "ppu" and len(device.members) != 1:
                raise ValueError(
                    f"PPU device '{device.name}' must have exactly one member"
                )

    def get_node(self, name: str) -> Node:
        """Get a Physical Node by its unique name (``KeyError`` if absent)."""
        return self.nodes[name]

    def get_device(self, name: str) -> Device:
        """Get a Logical Device by its unique name (``KeyError`` if absent)."""
        return self.devices[name]

    def get_devices_by_kind(self, kind: str) -> list[Device]:
        """Get all Logical Devices of a specific kind (case-insensitive)."""
        lowered = kind.lower()
        return [dev for dev in self.devices.values() if dev.kind.lower() == lowered]

    def get_node_by_rank(self, rank: int) -> Node:
        """Get a Physical Node by its unique rank.

        Raises:
            KeyError: if no node has the given rank.
        """
        # Linear scan; an internal rank -> node mapping would help if this
        # becomes hot.
        for node in self.nodes.values():
            if node.rank == rank:
                return node
        raise KeyError(f"No Physical Node found with rank {rank}")

    def to_dict(self) -> dict[str, Any]:
        """Convert ClusterSpec to a dictionary.

        Nodes are emitted as a list sorted by rank. ``from_dict`` assigns ranks
        from list position, so this keeps ``from_dict(to_dict())``
        rank-preserving for ranks ``0..n-1`` even when ``nodes`` was not built
        in rank order.
        """
        return {
            "cluster_id": self.cluster_id,
            "nodes": [
                node.to_dict()
                for node in sorted(self.nodes.values(), key=lambda n: n.rank)
            ],
            "devices": {
                name: device.to_dict() for name, device in self.devices.items()
            },
        }

    @cached_property
    def endpoints(self) -> list[str]:
        """Node endpoints ordered by ascending rank (computed once per instance)."""
        return [n.endpoint for n in sorted(self.nodes.values(), key=lambda n: n.rank)]

    @classmethod
    def from_dict(cls, config: dict[str, Any]) -> ClusterSpec:
        """Parses a raw config dictionary and returns a validated ClusterSpec.

        Parameters
        ----------
        config:
            Mapping with a ``nodes`` list and a ``devices`` mapping.
            - Each node config needs ``name`` and ``endpoint``, and may carry
              a ``runtime_info`` mapping.
            - Each device config (keyed by device name) needs ``kind`` and
              ``members`` (node names), and may carry ``config``.
            - ``cluster_id`` is optional; it defaults to
              ``cluster_<number of nodes>``.

        Returns
        -------
        ClusterSpec
            Node ranks are assigned from list position; any explicit ``rank``
            key in a node config is ignored.

        Raises
        ------
        ValueError
            On missing top-level sections, duplicate node names, unknown
            device members, or any ``__post_init__`` violation.
        KeyError
            If a required per-node or per-device key is missing.
        """
        # 1. Validate top-level keys
        if "nodes" not in config or "devices" not in config:
            raise ValueError(
                "Cluster config must contain 'nodes' and 'devices' sections."
            )

        # 2. Parse Physical Nodes, using the list index as the rank
        nodes_map: dict[str, Node] = {}
        # Reserved runtime info keys we recognize explicitly.
        known_runtime_fields = {"version", "platform", "op_bindings"}
        for i, node_cfg in enumerate(config["nodes"]):
            # An explicit 'rank' in node_cfg is ignored: the list index is the
            # rank (a warning could be logged here).

            # ``or {}`` also covers a key that is present with a null value.
            runtime_info_cfg = node_cfg.get("runtime_info") or {}
            extra_runtime_info = {
                k: v
                for k, v in runtime_info_cfg.items()
                if k not in known_runtime_fields
            }
            runtime_info = RuntimeInfo(
                version=runtime_info_cfg.get("version", "N/A"),
                platform=runtime_info_cfg.get("platform", "N/A"),
                # Copy so the spec does not alias the caller's config dict.
                op_bindings=dict(runtime_info_cfg.get("op_bindings") or {}),
                extra=extra_runtime_info,
            )

            node = Node(
                name=node_cfg["name"],
                rank=i,  # Implicit rank assignment
                endpoint=node_cfg["endpoint"],
                runtime_info=runtime_info,
            )

            if node.name in nodes_map:
                raise ValueError(f"Duplicate node name found: {node.name}")
            nodes_map[node.name] = node

        # 3. Parse Logical Devices
        devices_map: dict[str, Device] = {}
        for dev_name, dev_cfg in config["devices"].items():
            member_nodes = []
            for member_name in dev_cfg["members"]:
                if member_name not in nodes_map:
                    raise ValueError(
                        f"Node '{member_name}' in device '{dev_name}' not defined in 'nodes' section."
                    )
                member_nodes.append(nodes_map[member_name])

            devices_map[dev_name] = Device(
                name=dev_name,
                kind=dev_cfg["kind"],
                members=member_nodes,
                # Copy (and normalize a null value) to avoid aliasing input.
                config=dict(dev_cfg.get("config") or {}),
            )

        # Use the configured cluster_id, else derive one from the node count.
        cluster_id = config.get("cluster_id", f"cluster_{len(nodes_map)}")
        return cls(cluster_id=cluster_id, nodes=nodes_map, devices=devices_map)

    @classmethod
    def simple(
        cls,
        world_size: int,
        *,
        endpoints: list[str] | None = None,
        spu_world_size: int | None = None,
        spu_protocol: str = "SEMI2K",
        spu_field: str = "FM128",
        runtime_version: str = "simulated",
        runtime_platform: str = "simulated",
        op_bindings: list[dict[str, str]] | None = None,
        enable_ppu_device: bool = True,
        enable_spu_device: bool = True,
    ) -> ClusterSpec:
        """Convenience constructor used heavily in tests.

        Nodes are named ``node{i}`` with rank ``i``. The resulting
        ``cluster_id`` is ``__sim_{world_size}``.

        Parameters
        ----------
        world_size:
            Number of parties (physical nodes).
        endpoints:
            Optional explicit endpoint list of length ``world_size``. Each
            element may include a scheme (``http://``) or not; it is stored
            verbatim. If not provided, ``localhost:{5000 + i}`` is synthesized
            (5000 is a fixed default; pass explicit endpoints for control).
        spu_world_size:
            Number of SPU members, taken from the lowest ranks
            (``node0`` .. ``node{spu_world_size - 1}``). Defaults to
            ``world_size``. When given and the SPU device is enabled, it must
            satisfy ``1 <= spu_world_size <= world_size``.
        spu_protocol / spu_field:
            SPU device config values.
        runtime_version / runtime_platform:
            Populated into each node's ``RuntimeInfo``.
        op_bindings:
            Optional list of length ``world_size`` supplying per-node
            op_bindings override dicts (defaults to empty dicts). Each dict is
            copied.
        enable_ppu_device:
            If True (default), create one ``P{rank}`` PPU device per node.
        enable_spu_device:
            If True (default), create a shared SPU device named ``SP0``.

        Raises
        ------
        ValueError
            If ``endpoints`` or ``op_bindings`` has the wrong length, both
            device kinds are disabled, or ``spu_world_size`` is out of range.
        """
        base_port = 5000

        if endpoints is not None and len(endpoints) != world_size:
            raise ValueError(
                "len(endpoints) must equal world_size when provided: "
                f"{len(endpoints)} != {world_size}"
            )

        if op_bindings is not None and len(op_bindings) != world_size:
            raise ValueError(
                "len(op_bindings) must equal world_size when provided: "
                f"{len(op_bindings)} != {world_size}"
            )

        if not enable_ppu_device and not enable_spu_device:
            raise ValueError(
                "At least one of enable_ppu_device or enable_spu_device must be True"
            )

        # Reject this up front. Otherwise an oversized value fails later as a
        # bare KeyError('node{i}'), and a value < 1 yields a memberless SPU.
        if (
            enable_spu_device
            and spu_world_size is not None
            and not 1 <= spu_world_size <= world_size
        ):
            raise ValueError(
                "spu_world_size must be between 1 and world_size: "
                f"{spu_world_size} not in [1, {world_size}]"
            )

        nodes: dict[str, Node] = {}
        for i in range(world_size):
            ep = endpoints[i] if endpoints is not None else f"localhost:{base_port + i}"
            # Copy so nodes do not share the caller's dict objects.
            node_op_bindings = dict(op_bindings[i]) if op_bindings is not None else {}
            nodes[f"node{i}"] = Node(
                name=f"node{i}",
                rank=i,
                endpoint=ep,
                runtime_info=RuntimeInfo(
                    version=runtime_version,
                    platform=runtime_platform,
                    op_bindings=node_op_bindings,
                ),
            )

        devices: dict[str, Device] = {}
        # Optional per-node PPU devices
        if enable_ppu_device:
            for i in range(world_size):
                devices[f"P{i}"] = Device(
                    name=f"P{i}",
                    kind="ppu",
                    members=[nodes[f"node{i}"]],
                )

        # Shared SPU device
        if enable_spu_device:
            if spu_world_size is None:
                spu_world_size = world_size
            spu_members = [nodes[f"node{i}"] for i in range(spu_world_size)]

            devices["SP0"] = Device(
                name="SP0",
                kind="SPU",
                members=spu_members,
                config={
                    "protocol": spu_protocol,
                    "field": spu_field,
                },
            )

        return cls(cluster_id=f"__sim_{world_size}", nodes=nodes, devices=devices)
@@ -0,0 +1,23 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Machine Learning algorithms for secure multi-party computation."""
16
+
17
+ from mplang.v2.libs.ml.sgb import SecureBoost, Tree, TreeEnsemble
18
+
19
# Public names re-exported from mplang.v2.libs.ml.sgb.
__all__ = ["SecureBoost", "Tree", "TreeEnsemble"]
+ ]