mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
@@ -1,157 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- Core components for multi-party computation.
17
-
18
- This package provides the fundamental building blocks for multi-party computation,
19
- including type systems, tracing mechanisms, and interpreter contexts.
20
- """
21
-
22
- # Core type system
23
- # Communication interfaces & core symbols
24
- from mplang.v1.core.cluster import ClusterSpec, Device, Node, RuntimeInfo
25
- from mplang.v1.core.comm import (
26
- CollectiveMixin,
27
- CommunicatorBase,
28
- ICollective,
29
- ICommunicator,
30
- )
31
- from mplang.v1.core.context_mgr import cur_ctx, set_ctx, with_ctx
32
- from mplang.v1.core.dtypes import (
33
- BINARY,
34
- BOOL,
35
- COMPLEX64,
36
- COMPLEX128,
37
- DATE,
38
- DECIMAL,
39
- FLOAT16,
40
- FLOAT32,
41
- FLOAT64,
42
- INT8,
43
- INT16,
44
- INT32,
45
- INT64,
46
- INTERVAL,
47
- JSON,
48
- STRING,
49
- TIME,
50
- TIMESTAMP,
51
- UINT8,
52
- UINT16,
53
- UINT32,
54
- UINT64,
55
- UUID,
56
- DType,
57
- )
58
- from mplang.v1.core.interp import InterpContext, InterpVar
59
- from mplang.v1.core.mask import Mask
60
- from mplang.v1.core.mpir import IrReader, IrWriter
61
- from mplang.v1.core.mpobject import MPContext, MPObject
62
- from mplang.v1.core.mptype import MPType, Rank, Shape
63
- from mplang.v1.core.pfunc import PFunction, get_fn_name
64
-
65
- # Import primitive-dependent symbols at the end to avoid cycles when frontend ops
66
- # import from the core facade during package initialization.
67
- from mplang.v1.core.primitive import (
68
- builtin_function,
69
- function,
70
- pconv,
71
- peval,
72
- pmask,
73
- pshfl,
74
- pshfl_s,
75
- psize,
76
- uniform_cond,
77
- while_loop,
78
- )
79
- from mplang.v1.core.table import TableLike, TableType
80
- from mplang.v1.core.tensor import ScalarType, TensorLike, TensorType
81
- from mplang.v1.core.tracer import (
82
- TraceContext,
83
- TracedFunction,
84
- TraceVar,
85
- VarNamer,
86
- trace,
87
- )
88
-
89
- __all__ = [
90
- "BINARY",
91
- "BOOL",
92
- "COMPLEX64",
93
- "COMPLEX128",
94
- "DATE",
95
- "DECIMAL",
96
- "FLOAT16",
97
- "FLOAT32",
98
- "FLOAT64",
99
- "INT8",
100
- "INT16",
101
- "INT32",
102
- "INT64",
103
- "INTERVAL",
104
- "JSON",
105
- "STRING",
106
- "TIME",
107
- "TIMESTAMP",
108
- "UINT8",
109
- "UINT16",
110
- "UINT32",
111
- "UINT64",
112
- "UUID",
113
- "ClusterSpec",
114
- "CollectiveMixin",
115
- "CommunicatorBase",
116
- "DType",
117
- "Device",
118
- "ICollective",
119
- "ICommunicator",
120
- "InterpContext",
121
- "InterpVar",
122
- "IrReader",
123
- "IrWriter",
124
- "MPContext",
125
- "MPObject",
126
- "MPType",
127
- "Mask",
128
- "Node",
129
- "PFunction",
130
- "Rank",
131
- "RuntimeInfo",
132
- "ScalarType",
133
- "Shape",
134
- "TableLike",
135
- "TableType",
136
- "TensorLike",
137
- "TensorType",
138
- "TraceContext",
139
- "TraceVar",
140
- "TracedFunction",
141
- "VarNamer",
142
- "builtin_function",
143
- "cur_ctx",
144
- "function",
145
- "get_fn_name",
146
- "pconv",
147
- "peval",
148
- "pmask",
149
- "pshfl",
150
- "pshfl_s",
151
- "psize",
152
- "set_ctx",
153
- "trace",
154
- "uniform_cond",
155
- "while_loop",
156
- "with_ctx",
157
- ]
mplang/v1/core/cluster.py DELETED
@@ -1,343 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- This module provides the formal data structures and parsing logic for the
17
- MPLang cluster configuration.
18
- """
19
-
20
- from __future__ import annotations
21
-
22
- from dataclasses import dataclass, field
23
- from functools import cached_property
24
- from typing import Any
25
-
26
-
27
- @dataclass(frozen=True)
28
- class RuntimeInfo:
29
- """Per-physical-node runtime configuration.
30
-
31
- ``op_bindings`` is a per-node override map (logical_op -> kernel_id) merged
32
- into that node's ``RuntimeContext``. Unknown future / auxiliary fields are
33
- preserved in ``extra``.
34
- """
35
-
36
- version: str
37
- platform: str
38
- # Per-node partial override dispatch table (merged over project defaults).
39
- op_bindings: dict[str, str] = field(default_factory=dict)
40
-
41
- # A catch-all for any other custom or future properties (must not collide
42
- # with reserved keys: version, platform, op_bindings).
43
- extra: dict[str, Any] = field(default_factory=dict)
44
-
45
- def to_dict(self) -> dict[str, Any]:
46
- """Convert RuntimeInfo to a dictionary (stable field names)."""
47
- result = {
48
- "version": self.version,
49
- "platform": self.platform,
50
- "op_bindings": self.op_bindings,
51
- }
52
- result.update(self.extra)
53
- return result
54
-
55
-
56
- @dataclass(frozen=True)
57
- class Node:
58
- """
59
- Represents a single physical node (PN) in the cluster.
60
- This is an immutable description of a compute resource.
61
- """
62
-
63
- name: str
64
- rank: int
65
- endpoint: str
66
- runtime_info: RuntimeInfo
67
-
68
- def to_dict(self) -> dict[str, Any]:
69
- """Convert PhysicalNode to a dictionary."""
70
- return {
71
- "name": self.name,
72
- "endpoint": self.endpoint,
73
- "runtime_info": self.runtime_info.to_dict(),
74
- }
75
-
76
-
77
- @dataclass(frozen=True)
78
- class Device:
79
- """
80
- Represents a logical device (LD), which is a user-facing computational entity.
81
- It is composed of one or more Physical Nodes.
82
- """
83
-
84
- name: str
85
- kind: str
86
- members: list[Node]
87
- config: dict[str, Any] = field(default_factory=dict)
88
-
89
- @property
90
- def member_ranks(self) -> list[int]:
91
- """Returns the ranks of the member PNs."""
92
- return sorted([node.rank for node in self.members])
93
-
94
- def to_dict(self) -> dict[str, Any]:
95
- """Convert LogicalDevice to a dictionary."""
96
- return {
97
- "kind": self.kind,
98
- "members": [node.name for node in self.members],
99
- "config": self.config,
100
- }
101
-
102
-
103
- @dataclass(frozen=True)
104
- class ClusterSpec:
105
- """
106
- The formal, validated representation of the entire cluster.
107
- This object is the "first-class citizen" representing the cluster topology.
108
- """
109
-
110
- nodes: dict[str, Node]
111
- devices: dict[str, Device]
112
-
113
- def __post_init__(self) -> None:
114
- for key, node in self.nodes.items():
115
- if key != node.name:
116
- raise ValueError(
117
- f"Node key '{key}' does not match node.name '{node.name}'"
118
- )
119
-
120
- for key, device in self.devices.items():
121
- if key != device.name:
122
- raise ValueError(
123
- f"Device key '{key}' does not match device.name '{device.name}'"
124
- )
125
-
126
- # check all device members are valid nodes
127
- node_names = set(self.nodes.keys())
128
- for device in self.devices.values():
129
- for member in device.members:
130
- if member.name not in node_names:
131
- raise ValueError(
132
- f"Device '{device.name}' has member '{member.name}' "
133
- "which is not defined in nodes"
134
- )
135
-
136
- # ensure ppu devices have exactly one member
137
- for device in self.devices.values():
138
- if device.kind.lower() == "ppu" and len(device.members) != 1:
139
- raise ValueError(
140
- f"PPU device '{device.name}' must have exactly one member"
141
- )
142
-
143
- def get_node(self, name: str) -> Node:
144
- """Get a Physical Node by its unique name."""
145
- return self.nodes[name]
146
-
147
- def get_device(self, name: str) -> Device:
148
- """Get a Logical Device by its unique name."""
149
- return self.devices[name]
150
-
151
- def get_devices_by_kind(self, kind: str) -> list[Device]:
152
- """Get all Logical Devices of a specific kind."""
153
- lowered = kind.lower()
154
- return [dev for dev in self.devices.values() if dev.kind.lower() == lowered]
155
-
156
- def get_node_by_rank(self, rank: int) -> Node:
157
- """Get a Physical Node by its unique rank."""
158
- # This might require an internal mapping for efficiency if called often
159
- for node in self.nodes.values():
160
- if node.rank == rank:
161
- return node
162
- raise KeyError(f"No Physical Node found with rank {rank}")
163
-
164
- def to_dict(self) -> dict[str, Any]:
165
- """Convert ClusterSpec to a dictionary."""
166
- return {
167
- "nodes": [node.to_dict() for node in self.nodes.values()],
168
- "devices": {
169
- name: device.to_dict() for name, device in self.devices.items()
170
- },
171
- }
172
-
173
- @cached_property
174
- def endpoints(self) -> list[str]:
175
- eps: list[str] = []
176
- for n in sorted(
177
- self.nodes.values(),
178
- key=lambda x: x.rank, # type: ignore[attr-defined]
179
- ):
180
- eps.append(n.endpoint)
181
- return eps
182
-
183
- @classmethod
184
- def from_dict(cls, config: dict[str, Any]) -> ClusterSpec:
185
- """Parses a raw config dictionary and returns a validated ClusterSpec."""
186
- # 1. Validate top-level keys
187
- if "nodes" not in config or "devices" not in config:
188
- raise ValueError(
189
- "Cluster config must contain 'nodes' and 'devices' sections."
190
- )
191
-
192
- # 2. Parse Physical Nodes, using the list index as the rank
193
- nodes_map: dict[str, Node] = {}
194
- # Reserved runtime info keys we recognize explicitly.
195
- known_runtime_fields = {"version", "platform", "op_bindings"}
196
- for i, node_cfg in enumerate(config["nodes"]):
197
- if "rank" in node_cfg:
198
- # Optionally, we can log a warning that the explicit 'rank' is ignored.
199
- pass
200
-
201
- runtime_info_cfg = node_cfg.get("runtime_info", {})
202
- extra_runtime_info = {
203
- k: v
204
- for k, v in runtime_info_cfg.items()
205
- if k not in known_runtime_fields
206
- }
207
- runtime_info = RuntimeInfo(
208
- version=runtime_info_cfg.get("version", "N/A"),
209
- platform=runtime_info_cfg.get("platform", "N/A"),
210
- op_bindings=runtime_info_cfg.get("op_bindings", {}) or {},
211
- extra=extra_runtime_info,
212
- )
213
-
214
- node = Node(
215
- name=node_cfg["name"],
216
- rank=i, # Implicit rank assignment
217
- endpoint=node_cfg["endpoint"],
218
- runtime_info=runtime_info,
219
- )
220
-
221
- if node.name in nodes_map:
222
- raise ValueError(f"Duplicate node name found: {node.name}")
223
- nodes_map[node.name] = node
224
-
225
- # 3. Parse Logical Devices
226
- devices_map: dict[str, Device] = {}
227
- for dev_name, dev_cfg in config["devices"].items():
228
- member_nodes = []
229
- for member_name in dev_cfg["members"]:
230
- if member_name not in nodes_map:
231
- raise ValueError(
232
- f"Node '{member_name}' in device '{dev_name}' not defined in 'nodes' section."
233
- )
234
- member_nodes.append(nodes_map[member_name])
235
-
236
- devices_map[dev_name] = Device(
237
- name=dev_name,
238
- kind=dev_cfg["kind"],
239
- members=member_nodes,
240
- config=dev_cfg.get("config", {}),
241
- )
242
-
243
- return cls(nodes=nodes_map, devices=devices_map)
244
-
245
- @classmethod
246
- def simple(
247
- cls,
248
- world_size: int,
249
- *,
250
- endpoints: list[str] | None = None,
251
- spu_world_size: int | None = None,
252
- spu_protocol: str = "SEMI2K",
253
- spu_field: str = "FM128",
254
- runtime_version: str = "simulated",
255
- runtime_platform: str = "simulated",
256
- op_bindings: list[dict[str, str]] | None = None,
257
- enable_ppu_device: bool = True,
258
- enable_spu_device: bool = True,
259
- ) -> ClusterSpec:
260
- """Convenience constructor used heavily in tests.
261
-
262
- Parameters
263
- ----------
264
- world_size:
265
- Number of parties (physical nodes).
266
- endpoints:
267
- Optional explicit endpoint list of length ``world_size``. Each element may
268
- include scheme (``http://``) or not; stored verbatim. If not provided we
269
- synthesize ``localhost:{5000 + i}`` (5000 is a fixed default; pass explicit
270
- endpoints for control).
271
- spu_protocol / spu_field:
272
- SPU device config values.
273
- runtime_version / runtime_platform:
274
- Populated into each node's ``RuntimeInfo``.
275
- op_bindings:
276
- Optional list of length ``world_size`` supplying per-node op_bindings
277
- override dicts (defaults to empty dicts).
278
- enable_ppu_device:
279
- If True (default), create one ``P{rank}`` PPU device per node.
280
- enable_spu_device:
281
- If True (default) create a shared SPU device named ``SP0``.
282
- """
283
- base_port = 5000
284
-
285
- if endpoints is not None and len(endpoints) != world_size:
286
- raise ValueError(
287
- "len(endpoints) must equal world_size when provided: "
288
- f"{len(endpoints)} != {world_size}"
289
- )
290
-
291
- if op_bindings is not None and len(op_bindings) != world_size:
292
- raise ValueError(
293
- "len(op_bindings) must equal world_size when provided: "
294
- f"{len(op_bindings)} != {world_size}"
295
- )
296
-
297
- if not enable_ppu_device and not enable_spu_device:
298
- raise ValueError(
299
- "At least one of enable_ppu_device or enable_spu_device must be True"
300
- )
301
-
302
- nodes: dict[str, Node] = {}
303
- for i in range(world_size):
304
- ep = endpoints[i] if endpoints is not None else f"localhost:{base_port + i}"
305
- node_op_bindings = op_bindings[i] if op_bindings is not None else {}
306
- nodes[f"node{i}"] = Node(
307
- name=f"node{i}",
308
- rank=i,
309
- endpoint=ep,
310
- runtime_info=RuntimeInfo(
311
- version=runtime_version,
312
- platform=runtime_platform,
313
- op_bindings=node_op_bindings,
314
- ),
315
- )
316
-
317
- devices: dict[str, Device] = {}
318
- # Optional per-node PPU devices
319
- if enable_ppu_device:
320
- for i in range(world_size):
321
- devices[f"P{i}"] = Device(
322
- name=f"P{i}",
323
- kind="ppu",
324
- members=[nodes[f"node{i}"]],
325
- )
326
-
327
- # Shared SPU device
328
- if enable_spu_device:
329
- if spu_world_size is None:
330
- spu_world_size = world_size
331
- spu_members = [nodes[f"node{i}"] for i in range(spu_world_size)]
332
-
333
- devices["SP0"] = Device(
334
- name="SP0",
335
- kind="SPU",
336
- members=spu_members,
337
- config={
338
- "protocol": spu_protocol,
339
- "field": spu_field,
340
- },
341
- )
342
-
343
- return cls(nodes=nodes, devices=devices)