mplang-nightly 0.1.dev268__py3-none-any.whl → 0.1.dev270__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +46 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +19 -13
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/libs/mpc/psi/rr22.py +303 -0
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang/v2/libs/mpc/psi/rr22.py +0 -344
- mplang_nightly-0.1.dev268.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/core/__init__.py
DELETED
|
@@ -1,157 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""
|
|
16
|
-
Core components for multi-party computation.
|
|
17
|
-
|
|
18
|
-
This package provides the fundamental building blocks for multi-party computation,
|
|
19
|
-
including type systems, tracing mechanisms, and interpreter contexts.
|
|
20
|
-
"""
|
|
21
|
-
|
|
22
|
-
# Core type system
|
|
23
|
-
# Communication interfaces & core symbols
|
|
24
|
-
from mplang.v1.core.cluster import ClusterSpec, Device, Node, RuntimeInfo
|
|
25
|
-
from mplang.v1.core.comm import (
|
|
26
|
-
CollectiveMixin,
|
|
27
|
-
CommunicatorBase,
|
|
28
|
-
ICollective,
|
|
29
|
-
ICommunicator,
|
|
30
|
-
)
|
|
31
|
-
from mplang.v1.core.context_mgr import cur_ctx, set_ctx, with_ctx
|
|
32
|
-
from mplang.v1.core.dtypes import (
|
|
33
|
-
BINARY,
|
|
34
|
-
BOOL,
|
|
35
|
-
COMPLEX64,
|
|
36
|
-
COMPLEX128,
|
|
37
|
-
DATE,
|
|
38
|
-
DECIMAL,
|
|
39
|
-
FLOAT16,
|
|
40
|
-
FLOAT32,
|
|
41
|
-
FLOAT64,
|
|
42
|
-
INT8,
|
|
43
|
-
INT16,
|
|
44
|
-
INT32,
|
|
45
|
-
INT64,
|
|
46
|
-
INTERVAL,
|
|
47
|
-
JSON,
|
|
48
|
-
STRING,
|
|
49
|
-
TIME,
|
|
50
|
-
TIMESTAMP,
|
|
51
|
-
UINT8,
|
|
52
|
-
UINT16,
|
|
53
|
-
UINT32,
|
|
54
|
-
UINT64,
|
|
55
|
-
UUID,
|
|
56
|
-
DType,
|
|
57
|
-
)
|
|
58
|
-
from mplang.v1.core.interp import InterpContext, InterpVar
|
|
59
|
-
from mplang.v1.core.mask import Mask
|
|
60
|
-
from mplang.v1.core.mpir import IrReader, IrWriter
|
|
61
|
-
from mplang.v1.core.mpobject import MPContext, MPObject
|
|
62
|
-
from mplang.v1.core.mptype import MPType, Rank, Shape
|
|
63
|
-
from mplang.v1.core.pfunc import PFunction, get_fn_name
|
|
64
|
-
|
|
65
|
-
# Import primitive-dependent symbols at the end to avoid cycles when frontend ops
|
|
66
|
-
# import from the core facade during package initialization.
|
|
67
|
-
from mplang.v1.core.primitive import (
|
|
68
|
-
builtin_function,
|
|
69
|
-
function,
|
|
70
|
-
pconv,
|
|
71
|
-
peval,
|
|
72
|
-
pmask,
|
|
73
|
-
pshfl,
|
|
74
|
-
pshfl_s,
|
|
75
|
-
psize,
|
|
76
|
-
uniform_cond,
|
|
77
|
-
while_loop,
|
|
78
|
-
)
|
|
79
|
-
from mplang.v1.core.table import TableLike, TableType
|
|
80
|
-
from mplang.v1.core.tensor import ScalarType, TensorLike, TensorType
|
|
81
|
-
from mplang.v1.core.tracer import (
|
|
82
|
-
TraceContext,
|
|
83
|
-
TracedFunction,
|
|
84
|
-
TraceVar,
|
|
85
|
-
VarNamer,
|
|
86
|
-
trace,
|
|
87
|
-
)
|
|
88
|
-
|
|
89
|
-
# Public API of ``mplang.v1.core``: every name below is re-exported from the
# ``mplang.v1.core.*`` submodule imports at the top of this module.
__all__ = [
    # Upper-case names imported from ``mplang.v1.core.dtypes``.
    "BINARY",
    "BOOL",
    "COMPLEX64",
    "COMPLEX128",
    "DATE",
    "DECIMAL",
    "FLOAT16",
    "FLOAT32",
    "FLOAT64",
    "INT8",
    "INT16",
    "INT32",
    "INT64",
    "INTERVAL",
    "JSON",
    "STRING",
    "TIME",
    "TIMESTAMP",
    "UINT8",
    "UINT16",
    "UINT32",
    "UINT64",
    "UUID",
    # CamelCase names (classes / types) from the various core submodules.
    "ClusterSpec",
    "CollectiveMixin",
    "CommunicatorBase",
    "DType",
    "Device",
    "ICollective",
    "ICommunicator",
    "InterpContext",
    "InterpVar",
    "IrReader",
    "IrWriter",
    "MPContext",
    "MPObject",
    "MPType",
    "Mask",
    "Node",
    "PFunction",
    "Rank",
    "RuntimeInfo",
    "ScalarType",
    "Shape",
    "TableLike",
    "TableType",
    "TensorLike",
    "TensorType",
    "TraceContext",
    "TraceVar",
    "TracedFunction",
    "VarNamer",
    # Lower-case names from ``primitive``, ``pfunc``, ``tracer`` and
    # ``context_mgr``.
    "builtin_function",
    "cur_ctx",
    "function",
    "get_fn_name",
    "pconv",
    "peval",
    "pmask",
    "pshfl",
    "pshfl_s",
    "psize",
    "set_ctx",
    "trace",
    "uniform_cond",
    "while_loop",
    "with_ctx",
]
|
mplang/v1/core/cluster.py
DELETED
|
@@ -1,343 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""
|
|
16
|
-
This module provides the formal data structures and parsing logic for the
|
|
17
|
-
MPLang cluster configuration.
|
|
18
|
-
"""
|
|
19
|
-
|
|
20
|
-
from __future__ import annotations
|
|
21
|
-
|
|
22
|
-
from dataclasses import dataclass, field
|
|
23
|
-
from functools import cached_property
|
|
24
|
-
from typing import Any
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
@dataclass(frozen=True)
class RuntimeInfo:
    """Per-physical-node runtime configuration.

    ``op_bindings`` is a per-node override map (logical_op -> kernel_id) merged
    into that node's ``RuntimeContext``. Unknown future / auxiliary fields are
    preserved in ``extra``.

    Attributes:
        version: Runtime version string.
        platform: Runtime platform string.
        op_bindings: Per-node partial override dispatch table.
        extra: Catch-all for custom or future properties. Its keys should not
            collide with the reserved names ``version``, ``platform`` and
            ``op_bindings``; if they do, ``to_dict`` keeps the reserved field's
            value.
    """

    version: str
    platform: str
    # Per-node partial override dispatch table (merged over project defaults).
    op_bindings: dict[str, str] = field(default_factory=dict)

    # A catch-all for any other custom or future properties (must not collide
    # with reserved keys: version, platform, op_bindings).
    extra: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Convert RuntimeInfo to a dictionary (stable field names).

        The reserved fields come first, followed by the entries of ``extra``.
        An ``extra`` key that collides with a reserved field is ignored, so it
        cannot overwrite the real value. ``op_bindings`` is shallow-copied so
        that mutating the returned dict cannot change this frozen instance.

        Returns:
            A new dict with ``version``, ``platform``, ``op_bindings`` and all
            non-colliding ``extra`` entries.
        """
        result: dict[str, Any] = {
            "version": self.version,
            "platform": self.platform,
            "op_bindings": dict(self.op_bindings),
        }
        for key, value in self.extra.items():
            # setdefault keeps the reserved field's value on a key collision.
            result.setdefault(key, value)
        return result
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
@dataclass(frozen=True)
class Node:
    """A single physical node (PN) of the cluster.

    Immutable description of one compute resource.

    Attributes:
        name: Node name (the key used for this node in ``ClusterSpec.nodes``).
        rank: Integer rank of the node within the cluster.
        endpoint: Endpoint string for the node, e.g. ``localhost:5000``.
        runtime_info: Per-node runtime configuration.
    """

    name: str
    rank: int
    endpoint: str
    runtime_info: RuntimeInfo

    def to_dict(self) -> dict[str, Any]:
        """Serialize this node to a plain dictionary.

        ``rank`` is not included; ``ClusterSpec.from_dict`` derives ranks from
        each node's position in the ``nodes`` list instead.
        """
        serialized_runtime = self.runtime_info.to_dict()
        return dict(
            name=self.name,
            endpoint=self.endpoint,
            runtime_info=serialized_runtime,
        )
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
@dataclass(frozen=True)
class Device:
    """A logical device (LD): the user-facing computational unit.

    A device groups one or more physical ``Node`` objects under one name.

    Attributes:
        name: Device name (the key used for it in ``ClusterSpec.devices``).
        kind: Device kind string such as ``"ppu"`` or ``"SPU"``; ``ClusterSpec``
            compares it case-insensitively.
        members: Physical nodes backing this device.
        config: Free-form, kind-specific settings (empty dict by default).
    """

    name: str
    kind: str
    members: list[Node]
    config: dict[str, Any] = field(default_factory=dict)

    @property
    def member_ranks(self) -> list[int]:
        """Ranks of the member nodes, in ascending order."""
        ranks = [member.rank for member in self.members]
        ranks.sort()
        return ranks

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict, referencing members by node name.

        The device's own name is omitted; ``ClusterSpec.to_dict`` uses it as the
        key of the surrounding mapping.
        """
        member_names = [member.name for member in self.members]
        return dict(kind=self.kind, members=member_names, config=self.config)
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
@dataclass(frozen=True)
class ClusterSpec:
    """
    The formal, validated representation of the entire cluster.
    This object is the "first-class citizen" representing the cluster topology.

    ``nodes`` maps node name -> ``Node`` and ``devices`` maps device name ->
    ``Device``. Construction validates that:

    * every mapping key equals the value's ``name``;
    * node ranks are unique (``get_node_by_rank`` and ``endpoints`` rely on it);
    * every device member is a node defined in ``nodes``;
    * every device of kind ``ppu`` (case-insensitive) has exactly one member.
    """

    nodes: dict[str, Node]
    devices: dict[str, Device]

    def __post_init__(self) -> None:
        """Validate the invariants listed in the class docstring.

        Raises:
            ValueError: if any invariant is violated.
        """
        for key, node in self.nodes.items():
            if key != node.name:
                raise ValueError(
                    f"Node key '{key}' does not match node.name '{node.name}'"
                )

        # Ranks identify parties; duplicates would make rank lookups and the
        # rank-ordered endpoint list ambiguous.
        rank_owner: dict[int, str] = {}
        for node in self.nodes.values():
            if node.rank in rank_owner:
                raise ValueError(
                    f"Duplicate node rank {node.rank}: used by both "
                    f"'{rank_owner[node.rank]}' and '{node.name}'"
                )
            rank_owner[node.rank] = node.name

        for key, device in self.devices.items():
            if key != device.name:
                raise ValueError(
                    f"Device key '{key}' does not match device.name '{device.name}'"
                )

        # check all device members are valid nodes
        for device in self.devices.values():
            for member in device.members:
                if member.name not in self.nodes:
                    raise ValueError(
                        f"Device '{device.name}' has member '{member.name}' "
                        "which is not defined in nodes"
                    )

        # ensure ppu devices have exactly one member
        for device in self.devices.values():
            if device.kind.lower() == "ppu" and len(device.members) != 1:
                raise ValueError(
                    f"PPU device '{device.name}' must have exactly one member"
                )

    def get_node(self, name: str) -> Node:
        """Get a Physical Node by its unique name (raises KeyError if absent)."""
        return self.nodes[name]

    def get_device(self, name: str) -> Device:
        """Get a Logical Device by its unique name (raises KeyError if absent)."""
        return self.devices[name]

    def get_devices_by_kind(self, kind: str) -> list[Device]:
        """Get all Logical Devices of a specific kind (case-insensitive)."""
        lowered = kind.lower()
        return [dev for dev in self.devices.values() if dev.kind.lower() == lowered]

    def get_node_by_rank(self, rank: int) -> Node:
        """Get a Physical Node by its unique rank.

        Raises:
            KeyError: if no node has the given rank.
        """
        # Linear scan; ranks are unique (enforced in __post_init__).
        for node in self.nodes.values():
            if node.rank == rank:
                return node
        raise KeyError(f"No Physical Node found with rank {rank}")

    def _nodes_in_rank_order(self) -> list[Node]:
        """Return all nodes sorted by ascending rank."""
        return sorted(
            self.nodes.values(),
            key=lambda x: x.rank,  # type: ignore[attr-defined]
        )

    def to_dict(self) -> dict[str, Any]:
        """Convert ClusterSpec to a dictionary in the format ``from_dict`` reads.

        Nodes are emitted in ascending rank order. ``from_dict`` assigns ranks
        from list positions, so emitting them in dict insertion order would
        silently reassign ranks whenever insertion order differs from rank
        order. (Non-contiguous ranks are renumbered 0..n-1 on a round trip, but
        their relative order is kept.)
        """
        return {
            "nodes": [node.to_dict() for node in self._nodes_in_rank_order()],
            "devices": {
                name: device.to_dict() for name, device in self.devices.items()
            },
        }

    @cached_property
    def endpoints(self) -> list[str]:
        """Endpoints of all nodes, ordered by ascending rank.

        Computed once and cached on the instance.
        """
        return [node.endpoint for node in self._nodes_in_rank_order()]

    @classmethod
    def from_dict(cls, config: dict[str, Any]) -> ClusterSpec:
        """Parses a raw config dictionary and returns a validated ClusterSpec.

        Expected shape (as read below)::

            {
                "nodes": [
                    {"name": ..., "endpoint": ..., "runtime_info": {...}},
                    ...
                ],
                "devices": {
                    dev_name: {"kind": ..., "members": [node_name, ...],
                               "config": {...}},
                    ...
                },
            }

        Node ranks are assigned from list positions; an explicit ``rank`` key in
        a node entry is ignored. ``runtime_info`` keys other than ``version``,
        ``platform`` and ``op_bindings`` are preserved in ``RuntimeInfo.extra``.

        Raises:
            ValueError: if a top-level section is missing, a node name is
                duplicated, a device references an undefined node, or any
                ``__post_init__`` invariant fails.
            KeyError: if a required entry key (``name``, ``endpoint``, ``kind``,
                ``members``) is missing.
        """
        # 1. Validate top-level keys
        if "nodes" not in config or "devices" not in config:
            raise ValueError(
                "Cluster config must contain 'nodes' and 'devices' sections."
            )

        # 2. Parse Physical Nodes. The list index is the rank; any explicit
        # 'rank' key in a node entry is ignored.
        nodes_map: dict[str, Node] = {}
        # Reserved runtime info keys we recognize explicitly.
        known_runtime_fields = {"version", "platform", "op_bindings"}
        for i, node_cfg in enumerate(config["nodes"]):
            # `or {}` also tolerates an explicit null (e.g. an empty YAML key).
            runtime_info_cfg = node_cfg.get("runtime_info") or {}
            extra_runtime_info = {
                k: v
                for k, v in runtime_info_cfg.items()
                if k not in known_runtime_fields
            }
            runtime_info = RuntimeInfo(
                version=runtime_info_cfg.get("version", "N/A"),
                platform=runtime_info_cfg.get("platform", "N/A"),
                op_bindings=runtime_info_cfg.get("op_bindings", {}) or {},
                extra=extra_runtime_info,
            )

            node = Node(
                name=node_cfg["name"],
                rank=i,  # Implicit rank assignment
                endpoint=node_cfg["endpoint"],
                runtime_info=runtime_info,
            )

            if node.name in nodes_map:
                raise ValueError(f"Duplicate node name found: {node.name}")
            nodes_map[node.name] = node

        # 3. Parse Logical Devices
        devices_map: dict[str, Device] = {}
        for dev_name, dev_cfg in config["devices"].items():
            member_nodes = []
            for member_name in dev_cfg["members"]:
                if member_name not in nodes_map:
                    raise ValueError(
                        f"Node '{member_name}' in device '{dev_name}' not defined in 'nodes' section."
                    )
                member_nodes.append(nodes_map[member_name])

            devices_map[dev_name] = Device(
                name=dev_name,
                kind=dev_cfg["kind"],
                members=member_nodes,
                config=dev_cfg.get("config") or {},
            )

        return cls(nodes=nodes_map, devices=devices_map)

    @classmethod
    def simple(
        cls,
        world_size: int,
        *,
        endpoints: list[str] | None = None,
        spu_world_size: int | None = None,
        spu_protocol: str = "SEMI2K",
        spu_field: str = "FM128",
        runtime_version: str = "simulated",
        runtime_platform: str = "simulated",
        op_bindings: list[dict[str, str]] | None = None,
        enable_ppu_device: bool = True,
        enable_spu_device: bool = True,
    ) -> ClusterSpec:
        """Convenience constructor used heavily in tests.

        Parameters
        ----------
        world_size:
            Number of parties (physical nodes); must be at least 1. Nodes are
            named ``node{i}`` with rank ``i``.
        endpoints:
            Optional explicit endpoint list of length ``world_size``. Each element may
            include scheme (``http://``) or not; stored verbatim. If not provided we
            synthesize ``localhost:{5000 + i}`` (5000 is a fixed default; pass explicit
            endpoints for control).
        spu_world_size:
            Number of nodes (``node0`` .. ``node{spu_world_size - 1}``) in the
            SPU device. Defaults to ``world_size``; must lie in
            ``[1, world_size]``. Ignored when ``enable_spu_device`` is False.
        spu_protocol / spu_field:
            SPU device config values.
        runtime_version / runtime_platform:
            Populated into each node's ``RuntimeInfo``.
        op_bindings:
            Optional list of length ``world_size`` supplying per-node op_bindings
            override dicts (defaults to empty dicts).
        enable_ppu_device:
            If True (default), create one ``P{rank}`` PPU device per node.
        enable_spu_device:
            If True (default) create a shared SPU device named ``SP0``.

        Raises
        ------
        ValueError
            If ``world_size`` < 1, a list argument has the wrong length, both
            device kinds are disabled, or ``spu_world_size`` is out of range.
        """
        base_port = 5000

        # Validate every argument before building anything.
        if world_size < 1:
            raise ValueError(f"world_size must be positive, got {world_size}")

        if endpoints is not None and len(endpoints) != world_size:
            raise ValueError(
                "len(endpoints) must equal world_size when provided: "
                f"{len(endpoints)} != {world_size}"
            )

        if op_bindings is not None and len(op_bindings) != world_size:
            raise ValueError(
                "len(op_bindings) must equal world_size when provided: "
                f"{len(op_bindings)} != {world_size}"
            )

        if not enable_ppu_device and not enable_spu_device:
            raise ValueError(
                "At least one of enable_ppu_device or enable_spu_device must be True"
            )

        if enable_spu_device:
            if spu_world_size is None:
                spu_world_size = world_size
            # Out-of-range values previously surfaced as KeyError('node{n}') or
            # produced an SPU device with no members.
            if not 1 <= spu_world_size <= world_size:
                raise ValueError(
                    "spu_world_size must be in [1, world_size]: "
                    f"{spu_world_size} not in [1, {world_size}]"
                )

        nodes: dict[str, Node] = {}
        for i in range(world_size):
            ep = endpoints[i] if endpoints is not None else f"localhost:{base_port + i}"
            node_op_bindings = op_bindings[i] if op_bindings is not None else {}
            nodes[f"node{i}"] = Node(
                name=f"node{i}",
                rank=i,
                endpoint=ep,
                runtime_info=RuntimeInfo(
                    version=runtime_version,
                    platform=runtime_platform,
                    op_bindings=node_op_bindings,
                ),
            )

        devices: dict[str, Device] = {}
        # Optional per-node PPU devices
        if enable_ppu_device:
            for i in range(world_size):
                devices[f"P{i}"] = Device(
                    name=f"P{i}",
                    kind="ppu",
                    members=[nodes[f"node{i}"]],
                )

        # Shared SPU device (spu_world_size was resolved and validated above).
        if enable_spu_device:
            assert spu_world_size is not None  # narrowed above; for type checkers
            spu_members = [nodes[f"node{i}"] for i in range(spu_world_size)]

            devices["SP0"] = Device(
                name="SP0",
                kind="SPU",
                members=spu_members,
                config={
                    "protocol": spu_protocol,
                    "field": spu_field,
                },
            )

        return cls(nodes=nodes, devices=devices)
|