mplang-nightly 0.1.dev285__py3-none-any.whl → 0.1.dev287__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/backends/simp_worker/collective_algorithms.py +228 -0
- mplang/backends/simp_worker/collectives.py +275 -0
- mplang/backends/simp_worker/ops.py +5 -2
- mplang/edsl/typing.py +25 -9
- {mplang_nightly-0.1.dev285.dist-info → mplang_nightly-0.1.dev287.dist-info}/METADATA +1 -1
- {mplang_nightly-0.1.dev285.dist-info → mplang_nightly-0.1.dev287.dist-info}/RECORD +9 -7
- {mplang_nightly-0.1.dev285.dist-info → mplang_nightly-0.1.dev287.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev285.dist-info → mplang_nightly-0.1.dev287.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev285.dist-info → mplang_nightly-0.1.dev287.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
# Copyright 2026 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Collective communication algorithms (communicator-only).
|
|
16
|
+
|
|
17
|
+
This module contains *pure* collective algorithms implemented only in terms of
|
|
18
|
+
(a) a communicator and (b) an explicit participant set.
|
|
19
|
+
|
|
20
|
+
It intentionally does NOT depend on:
|
|
21
|
+
- Interpreter execution IDs / graph keys
|
|
22
|
+
- SimpWorker current_parties
|
|
23
|
+
- Operation objects
|
|
24
|
+
|
|
25
|
+
Callers are expected to provide a collision-free key prefix for each collective
|
|
26
|
+
instance.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
from __future__ import annotations
|
|
30
|
+
|
|
31
|
+
import operator
|
|
32
|
+
from collections.abc import Callable, Sequence
|
|
33
|
+
from typing import Any, Protocol
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class Communicator(Protocol):
    """Structural interface the collective algorithms are written against.

    Any object exposing these attributes and methods can drive the algorithms
    in this module; no inheritance is required.
    """

    # Rank of the calling process; compared against the participant set.
    rank: int
    # Number of ranks; participants must lie in ``[0, world_size)``.
    world_size: int

    def send(
        self, to: int, key: str, data: Any, *, is_raw_bytes: bool = False
    ) -> None:
        """Send ``data`` to rank ``to`` under message key ``key``."""
        ...

    def recv(self, frm: int, key: str) -> Any:
        """Receive and return the message sent by rank ``frm`` under ``key``."""
        ...
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def normalize_participants(
    comm: Communicator, participants: Sequence[int]
) -> tuple[int, ...]:
    """Canonicalize and validate a participant rank set.

    Every rank is coerced with ``int``, duplicates are dropped and the result
    is sorted ascending.

    Args:
        comm: Communicator supplying ``rank`` and ``world_size``.
        participants: Candidate participant ranks (duplicates allowed).

    Returns:
        A sorted tuple of unique participant ranks.

    Raises:
        ValueError: If the set is empty, if any rank lies outside
            ``[0, world_size)``, or if ``comm.rank`` is not a participant
            (checked in that order).
    """
    ordered = tuple(sorted({int(rank) for rank in participants}))

    if len(ordered) == 0:
        raise ValueError("participants must be non-empty")

    world = comm.world_size
    if not all(0 <= rank < world for rank in ordered):
        raise ValueError(
            f"participants out of range: {ordered}, world_size={world}"
        )

    me = comm.rank
    if me in ordered:
        return ordered
    raise ValueError(f"rank {me} is not in participants {ordered}")
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def barrier(
    comm: Communicator, *, participants: Sequence[int], key_prefix: str
) -> None:
    """Block until every participant has reached the barrier.

    The lowest-ranked participant acts as coordinator. It first collects an
    "arrive" token from every other participant, then sends each of them a
    "release" token. Every other participant sends its arrival and waits for
    the release.
    """
    ranks = normalize_participants(comm, participants)
    coordinator, *others = ranks

    arrive_tag = f"{key_prefix}_arrive"
    release_tag = f"{key_prefix}_release"

    if comm.rank == coordinator:
        for peer in others:
            _ = comm.recv(peer, arrive_tag)
        for peer in others:
            comm.send(peer, release_tag, True)
        return

    comm.send(coordinator, arrive_tag, True)
    comm.recv(coordinator, release_tag)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def broadcast(
    comm: Communicator,
    value: Any,
    *,
    root: int,
    participants: Sequence[int],
    key_prefix: str,
) -> Any:
    """Distribute ``value`` from ``root`` to every participant.

    Returns:
        On ``root``, the ``value`` argument itself. On every other
        participant, the value received from ``root``; the local ``value``
        argument is ignored there.

    Raises:
        ValueError: From participant normalization, or if ``root`` is not a
            participant.
    """
    ranks = normalize_participants(comm, participants)
    if root not in ranks:
        raise ValueError(f"root {root} must be in participants {ranks}")

    tag = f"{key_prefix}_bcast"

    if comm.rank != root:
        return comm.recv(root, tag)

    for peer in ranks:
        if peer != root:
            comm.send(peer, tag, value)
    return value
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def allgather(
    comm: Communicator, value: Any, *, participants: Sequence[int], key_prefix: str
) -> list[Any]:
    """Collect one value per participant and deliver the full list to all.

    The lowest-ranked participant gathers everyone's value, then broadcasts
    the assembled list back to the other participants.

    Returns:
        A list with one entry per participant, ordered by ascending rank.

    Raises:
        TypeError: On a non-root participant, if the payload broadcast by the
            root is not a ``list``.
    """
    ranks = normalize_participants(comm, participants)
    root, *peers = ranks

    gather_tag = f"{key_prefix}_gather"
    bcast_tag = f"{key_prefix}_bcast"

    if comm.rank == root:
        # The root holds the smallest rank, so its own value leads the list.
        # Peers are then received in ascending rank order.
        collected = [value]
        collected.extend(comm.recv(peer, gather_tag) for peer in peers)
        for peer in peers:
            comm.send(peer, bcast_tag, collected)
        return collected

    comm.send(root, gather_tag, value)
    result = comm.recv(root, bcast_tag)
    if not isinstance(result, list):
        raise TypeError(f"expected list from root broadcast, got {type(result)}")
    return result
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def allreduce_bool_and(
    comm: Communicator,
    value: bool,
    *,
    participants: Sequence[int],
    key_prefix: str,
) -> bool:
    """Logical AND of ``value`` over all participants; every rank gets it."""
    reducer = operator.and_
    return _allreduce_bool(
        comm,
        value,
        combine=reducer,
        key_prefix=key_prefix,
        participants=participants,
    )
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def allreduce_bool_or(
    comm: Communicator,
    value: bool,
    *,
    participants: Sequence[int],
    key_prefix: str,
) -> bool:
    """Logical OR of ``value`` over all participants; every rank gets it."""
    reducer = operator.or_
    return _allreduce_bool(
        comm,
        value,
        combine=reducer,
        key_prefix=key_prefix,
        participants=participants,
    )
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def allreduce_bool_xor(
    comm: Communicator,
    value: bool,
    *,
    participants: Sequence[int],
    key_prefix: str,
) -> bool:
    """Logical XOR of ``value`` over all participants; every rank gets it."""
    reducer = operator.xor
    return _allreduce_bool(
        comm,
        value,
        combine=reducer,
        key_prefix=key_prefix,
        participants=participants,
    )
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def _allreduce_bool(
    comm: Communicator,
    value: bool,
    *,
    participants: Sequence[int],
    key_prefix: str,
    combine: Callable[[bool, bool], bool],
) -> bool:
    """Reduce booleans at the root with ``combine`` and broadcast the result.

    The lowest-ranked participant starts from its own ``bool(value)``. It
    folds in each peer's value, received in ascending rank order, using
    ``combine``, then sends the final result back to every peer.
    """
    ranks = normalize_participants(comm, participants)
    root, *peers = ranks

    gather_tag = f"{key_prefix}_gather"
    bcast_tag = f"{key_prefix}_bcast"

    if comm.rank != root:
        comm.send(root, gather_tag, bool(value))
        return bool(comm.recv(root, bcast_tag))

    result = bool(value)
    for peer in peers:
        incoming = bool(comm.recv(peer, gather_tag))
        result = combine(result, incoming)

    for peer in peers:
        comm.send(peer, bcast_tag, result)
    return result
|
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
# Copyright 2026 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Simp worker-side collectives (wrapper layer).
|
|
16
|
+
|
|
17
|
+
This module is the *context-aware wrapper* on top of
|
|
18
|
+
`mplang.backends.simp_worker.collective_algorithms`.
|
|
19
|
+
|
|
20
|
+
Responsibilities here:
|
|
21
|
+
- Resolve "participants" from (explicit arg / op.attrs["parties"] /
|
|
22
|
+
worker.current_parties / world).
|
|
23
|
+
- Build collision-free `key_prefix` using interpreter execution IDs.
|
|
24
|
+
|
|
25
|
+
The underlying algorithms only depend on the communicator interface.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
from __future__ import annotations
|
|
29
|
+
|
|
30
|
+
from collections.abc import Sequence
|
|
31
|
+
from typing import Any, Protocol
|
|
32
|
+
|
|
33
|
+
from mplang.backends.simp_worker import collective_algorithms as algo
|
|
34
|
+
from mplang.edsl.graph import Operation
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class _ExecContext(Protocol):
|
|
38
|
+
def current_op_exec_id(self) -> int: ...
|
|
39
|
+
|
|
40
|
+
def current_graph_exec_key(self) -> str: ...
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class _Worker(Protocol):
|
|
44
|
+
rank: int
|
|
45
|
+
world_size: int
|
|
46
|
+
communicator: algo.Communicator
|
|
47
|
+
current_parties: tuple[int, ...] | None
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def resolve_participants(
    worker: _Worker,
    *,
    op: Operation | None = None,
    participants: Sequence[int] | None = None,
) -> Sequence[int]:
    """Resolve participant ranks.

    Priority:
      1) explicit participants argument
      2) op.attrs["parties"] if present
      3) worker.current_parties if set (pcall_static dynamic scope)
      4) all ranks [0, world_size)

    Args:
        worker: Worker providing ``current_parties`` and ``world_size``.
        op: Optional operation whose ``attrs["parties"]`` may name the ranks.
        participants: Explicit ranks; returned unchanged when given.

    Returns:
        The ranks from the first source that is set. Ranks taken from
        ``op.attrs["parties"]`` are coerced with ``int`` and returned as a
        tuple; other sources are returned as-is.

    Raises:
        TypeError: If ``op.attrs["parties"]`` is present but is not a
            sequence, or is a ``str``/``bytes``/``bytearray``.

    Note:
        Normalization/validation (sorting, emptiness, range checks, rank
        inclusion) is intentionally delegated to the lower-level algorithms in
        `collective_algorithms`.
    """

    if participants is not None:
        return participants

    if op is not None:
        parties = op.attrs.get("parties")
        if parties is not None:
            # str/bytes/bytearray satisfy the Sequence ABC, but iterating them
            # yields characters or byte values rather than ranks (e.g. "01"
            # would be silently misread as ranks 0 and 1), so reject them.
            if isinstance(parties, (str, bytes, bytearray)) or not isinstance(
                parties, Sequence
            ):
                raise TypeError(
                    "op.attrs['parties'] must be a sequence of rank integers"
                )
            return tuple(int(r) for r in parties)

    if worker.current_parties is not None:
        return worker.current_parties

    return tuple(range(worker.world_size))
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _collective_prefix(
|
|
89
|
+
interpreter: _ExecContext, *, op: Operation | None, name: str
|
|
90
|
+
) -> str:
|
|
91
|
+
exec_id = interpreter.current_op_exec_id()
|
|
92
|
+
graph_key = interpreter.current_graph_exec_key()
|
|
93
|
+
op_name = op.name if op is not None else "_"
|
|
94
|
+
return f"coll_{graph_key}_{op_name}_{exec_id}_{name}"
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def barrier(
    interpreter: _ExecContext,
    worker: _Worker,
    *,
    op: Operation | None = None,
    participants: Sequence[int] | None = None,
    name: str = "barrier",
) -> None:
    """Run a barrier over the resolved participants.

    Participants come from ``resolve_participants``. Message keys are
    namespaced by ``_collective_prefix``.
    """
    ranks = resolve_participants(worker, op=op, participants=participants)
    key_prefix = _collective_prefix(interpreter, op=op, name=name)
    algo.barrier(worker.communicator, key_prefix=key_prefix, participants=ranks)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def broadcast(
    interpreter: _ExecContext,
    worker: _Worker,
    value: Any,
    *,
    root: int,
    op: Operation | None = None,
    participants: Sequence[int] | None = None,
    name: str = "broadcast",
) -> Any:
    """Broadcast ``value`` from ``root`` to the resolved participants.

    ``root`` is coerced with ``int``. The return value is whatever
    ``algo.broadcast`` returns on this rank.
    """
    ranks = resolve_participants(worker, op=op, participants=participants)
    key_prefix = _collective_prefix(interpreter, op=op, name=name)
    result = algo.broadcast(
        worker.communicator,
        value,
        root=int(root),
        key_prefix=key_prefix,
        participants=ranks,
    )
    return result
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def allgather_obj(
    interpreter: _ExecContext,
    worker: _Worker,
    value: Any,
    *,
    op: Operation | None = None,
    participants: Sequence[int] | None = None,
    name: str = "allgather_obj",
) -> list[Any]:
    """Gather an arbitrary object from every resolved participant onto all.

    Returns the list produced by ``algo.allgather``.
    """
    ranks = resolve_participants(worker, op=op, participants=participants)
    key_prefix = _collective_prefix(interpreter, op=op, name=name)
    gathered = algo.allgather(
        worker.communicator, value, key_prefix=key_prefix, participants=ranks
    )
    return gathered
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def allgather_bool(
    interpreter: _ExecContext,
    worker: _Worker,
    value: bool,
    *,
    op: Operation | None = None,
    participants: Sequence[int] | None = None,
    name: str = "allgather_bool",
) -> list[bool]:
    """Allgather ``bool(value)`` and return every gathered entry as a ``bool``."""
    raw = allgather_obj(
        interpreter,
        worker,
        bool(value),
        op=op,
        participants=participants,
        name=name,
    )
    return list(map(bool, raw))
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def allreduce_bool_and(
    interpreter: _ExecContext,
    worker: _Worker,
    value: bool,
    *,
    op: Operation | None = None,
    participants: Sequence[int] | None = None,
    name: str = "allreduce_bool_and",
) -> bool:
    """Logical AND of ``bool(value)`` over the resolved participants."""
    ranks = resolve_participants(worker, op=op, participants=participants)
    key_prefix = _collective_prefix(interpreter, op=op, name=name)
    return algo.allreduce_bool_and(
        worker.communicator,
        bool(value),
        key_prefix=key_prefix,
        participants=ranks,
    )
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def allreduce_bool_or(
    interpreter: _ExecContext,
    worker: _Worker,
    value: bool,
    *,
    op: Operation | None = None,
    participants: Sequence[int] | None = None,
    name: str = "allreduce_bool_or",
) -> bool:
    """Logical OR of ``bool(value)`` over the resolved participants."""
    ranks = resolve_participants(worker, op=op, participants=participants)
    key_prefix = _collective_prefix(interpreter, op=op, name=name)
    return algo.allreduce_bool_or(
        worker.communicator,
        bool(value),
        key_prefix=key_prefix,
        participants=ranks,
    )
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def allreduce_bool_xor(
    interpreter: _ExecContext,
    worker: _Worker,
    value: bool,
    *,
    op: Operation | None = None,
    participants: Sequence[int] | None = None,
    name: str = "allreduce_bool_xor",
) -> bool:
    """Logical XOR of ``bool(value)`` over the resolved participants."""
    ranks = resolve_participants(worker, op=op, participants=participants)
    key_prefix = _collective_prefix(interpreter, op=op, name=name)
    return algo.allreduce_bool_xor(
        worker.communicator,
        bool(value),
        key_prefix=key_prefix,
        participants=ranks,
    )
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def verify_uniform_predicate(
    interpreter: _ExecContext,
    worker: _Worker,
    pred: bool,
    *,
    op: Operation | None = None,
    participants: Sequence[int] | None = None,
    name: str = "uniform_predicate",
) -> bool:
    """Check that ``pred`` has the same value on every participant.

    An AND all-reduce and an OR all-reduce are compared; they agree exactly
    when the predicate is uniform. On disagreement, an extra allgather
    collects the per-rank values for the error message.

    Returns:
        ``bool(pred)`` when the predicate is uniform.

    Raises:
        RuntimeError: If participants disagree on the predicate.
    """
    ranks = resolve_participants(worker, op=op, participants=participants)
    flag = bool(pred)

    every = allreduce_bool_and(
        interpreter, worker, flag, op=op, participants=ranks, name=f"{name}_and"
    )
    some = allreduce_bool_or(
        interpreter, worker, flag, op=op, participants=ranks, name=f"{name}_or"
    )

    if every == some:
        return flag

    # Every participant received the same AND/OR results from the root, so
    # all of them reach this branch together and run the same allgather.
    values = allgather_bool(
        interpreter, worker, flag, op=op, participants=ranks, name=f"{name}_gather"
    )
    ordered = algo.normalize_participants(worker.communicator, ranks)
    by_rank = dict(zip(ordered, values, strict=True))
    raise RuntimeError(
        "simp.uniform_cond predicate is not uniform across participants: "
        f"participants={ordered}, values={by_rank}"
    )
|
|
@@ -22,6 +22,7 @@ from __future__ import annotations
|
|
|
22
22
|
|
|
23
23
|
from typing import Any
|
|
24
24
|
|
|
25
|
+
from mplang.backends.simp_worker.collectives import verify_uniform_predicate
|
|
25
26
|
from mplang.dialects import simp
|
|
26
27
|
from mplang.edsl.graph import Operation
|
|
27
28
|
from mplang.runtime.interpreter import Interpreter
|
|
@@ -117,12 +118,14 @@ def _uniform_cond_worker_impl(
|
|
|
117
118
|
"""Worker implementation of simp.uniform_cond."""
|
|
118
119
|
from mplang.backends.tensor_impl import TensorValue
|
|
119
120
|
|
|
120
|
-
|
|
121
|
-
pass # TODO: Implement AllReduce verification
|
|
121
|
+
worker = _ensure_worker_context(interpreter, "uniform_cond_impl")
|
|
122
122
|
|
|
123
123
|
if isinstance(pred, TensorValue):
|
|
124
124
|
pred = bool(pred.unwrap())
|
|
125
125
|
|
|
126
|
+
if op.attrs.get("verify_uniform", True):
|
|
127
|
+
pred = verify_uniform_predicate(interpreter, worker, bool(pred), op=op)
|
|
128
|
+
|
|
126
129
|
if pred:
|
|
127
130
|
result = interpreter.evaluate_graph(op.regions[0], list(args))
|
|
128
131
|
else:
|
mplang/edsl/typing.py
CHANGED
|
@@ -82,11 +82,18 @@ underlying HE libraries.
|
|
|
82
82
|
- **Core Type**: `Tensor[Scalar, ...]`
|
|
83
83
|
- **API Standard**: Follows NumPy/JAX conventions. All layout and arithmetic operations are valid.
|
|
84
84
|
|
|
85
|
+
- **World 2: The Element-wise HE World**
|
|
85
86
|
- **Core Type**: `Tensor[EncryptedScalar, ...]` (e.g., `Tensor[phe.CiphertextType, ...]`)
|
|
86
87
|
- **API Standard**: Follows TenSEAL-like (Tensor-level) conventions. Layout operations
|
|
87
88
|
(`transpose`, `reshape`) are valid as they merely shuffle independent ciphertext objects.
|
|
88
89
|
Arithmetic operations are overloaded for element-wise HE computation.
|
|
89
90
|
|
|
91
|
+
- **World 3: The SIMD HE World**
|
|
92
|
+
- **Core Type**: `Tensor[Vector[...], ...]` where the inner `Vector` holds SIMD-encrypted slots
|
|
93
|
+
- **API Standard**: SIMD-HE specific (e.g., BFV/CKKS). Only specific operations are supported
|
|
94
|
+
due to the batched nature of SIMD encryption. Layout operations must account for slot packing,
|
|
95
|
+
and arithmetic operates on batched ciphertexts with slot-wise semantics.
|
|
96
|
+
|
|
90
97
|
===========================
|
|
91
98
|
Principle 3: Contracts via Protocols
|
|
92
99
|
===========================
|
|
@@ -137,7 +144,15 @@ class BaseType:
|
|
|
137
144
|
"""Base class for all MPLang types."""
|
|
138
145
|
|
|
139
146
|
def __repr__(self) -> str:
    # Delegate to ``__str__`` only when a subclass overrides it; otherwise use
    # ``object.__repr__`` so repr and str can never bounce between each other.
    overrides_str = type(self).__str__ is not BaseType.__str__
    return str(self) if overrides_str else object.__repr__(self)
|
|
152
|
+
|
|
153
|
+
def __str__(self) -> str:
    # Generic rendering for subclasses without their own ``__str__``:
    # the class name followed by empty parentheses.
    class_name = self.__class__.__name__
    return class_name + "()"
|
|
141
156
|
|
|
142
157
|
|
|
143
158
|
# ==============================================================================
|
|
@@ -175,11 +190,11 @@ class ScalarType(BaseType):
|
|
|
175
190
|
|
|
176
191
|
@serde.register_class
|
|
177
192
|
class IntegerType(ScalarType):
|
|
178
|
-
"""Represents a
|
|
193
|
+
"""Represents a fixed-width integer type with configurable bitwidth.
|
|
179
194
|
|
|
180
|
-
This is a standard integer type with
|
|
181
|
-
arbitrary-precision arithmetic.
|
|
182
|
-
the range of
|
|
195
|
+
This is a standard integer type with parameterized bit width, used for
|
|
196
|
+
wide-integer arithmetic. By configuring larger bitwidths (e.g., 128, 256),
|
|
197
|
+
this type can represent integers that exceed the range of standard types like i64.
|
|
183
198
|
|
|
184
199
|
Examples:
|
|
185
200
|
>>> i128 = IntegerType(bitwidth=128, signed=True) # i128
|
|
@@ -229,15 +244,16 @@ class IntegerType(ScalarType):
|
|
|
229
244
|
|
|
230
245
|
@serde.register_class
|
|
231
246
|
class FloatType(ScalarType):
|
|
232
|
-
"""Represents a floating-point type.
|
|
247
|
+
"""Represents a fixed-width floating-point type with configurable bitwidth.
|
|
233
248
|
|
|
234
|
-
This supports standard IEEE 754 floating-point types with
|
|
235
|
-
precision
|
|
249
|
+
This supports standard IEEE 754 floating-point types with parameterized
|
|
250
|
+
bit width for different precision requirements.
|
|
236
251
|
|
|
237
252
|
Examples:
|
|
238
253
|
>>> f16 = FloatType(bitwidth=16) # half precision
|
|
239
254
|
>>> f32 = FloatType(bitwidth=32) # single precision
|
|
240
255
|
>>> f64 = FloatType(bitwidth=64) # double precision
|
|
256
|
+
>>> f128 = FloatType(bitwidth=128) # quadruple precision
|
|
241
257
|
"""
|
|
242
258
|
|
|
243
259
|
def __init__(self, *, bitwidth: int = 32):
|
|
@@ -245,7 +261,7 @@ class FloatType(ScalarType):
|
|
|
245
261
|
|
|
246
262
|
Args:
|
|
247
263
|
bitwidth: Number of bits for the float representation.
|
|
248
|
-
Standard values: 16 (half), 32 (single), 64 (double).
|
|
264
|
+
Standard values: 16 (half), 32 (single), 64 (double), 128 (quadruple).
|
|
249
265
|
"""
|
|
250
266
|
if bitwidth not in (16, 32, 64, 128):
|
|
251
267
|
raise ValueError(f"bitwidth must be 16, 32, 64, or 128, got {bitwidth}")
|
|
@@ -23,9 +23,11 @@ mplang/backends/simp_driver/ops.py,sha256=WYObWDRCsiXH0UBWZX5vD5W98ZPkd88U_qBV8S
|
|
|
23
23
|
mplang/backends/simp_driver/state.py,sha256=dNmYMFN2D2BBdgs6C0YLaHrfaBRMgs05UNxMWw6tZIs,1713
|
|
24
24
|
mplang/backends/simp_driver/values.py,sha256=Lz1utNSIzH-dCzZAEjU6JRcxPsfKGfUJrYl6gIuMOGw,1509
|
|
25
25
|
mplang/backends/simp_worker/__init__.py,sha256=gdrSY1-MDkupCoJ8xwwH7em7fgVWv3J4gBJ45uHdzgg,961
|
|
26
|
+
mplang/backends/simp_worker/collective_algorithms.py,sha256=tjqQrXRpV71g7uKBGlITLM29VjsnYfBr7MBfx1uKzvM,5707
|
|
27
|
+
mplang/backends/simp_worker/collectives.py,sha256=Jzo-eU4QXM8zXfbUdb0xvghh6LWDSfUsxglSYi3oGY8,7688
|
|
26
28
|
mplang/backends/simp_worker/http.py,sha256=UiqXKcH5DZG_eTXcW7qDSjIsBWtA37tkqBEowJJ0hQE,13186
|
|
27
29
|
mplang/backends/simp_worker/mem.py,sha256=hLJftzSMQIw64vjBdrSrFReFzWS6rlzf-0Q7SWkfhak,3605
|
|
28
|
-
mplang/backends/simp_worker/ops.py,sha256=
|
|
30
|
+
mplang/backends/simp_worker/ops.py,sha256=qMfXH9lUc9nq9paWa-53XgDZQcL29nDF6Qt9XwKouQ0,5834
|
|
29
31
|
mplang/backends/simp_worker/state.py,sha256=nIu0ybvdYqRqp0TkoSneUF2u31evDHucCRduVBaDals,1445
|
|
30
32
|
mplang/dialects/__init__.py,sha256=CYMmkeQVU0Znr9n3_5clZKb16u7acJ5jl5Zjbx4Tn1U,1478
|
|
31
33
|
mplang/dialects/bfv.py,sha256=m5YfobFCBqn0lg2zBM9RNs2AC7i4PUQH2qXjHLHwSy4,22332
|
|
@@ -52,7 +54,7 @@ mplang/edsl/program.py,sha256=_JdEU2-nb79VlFLcgMJf4JS30TARBeUIzno0y0SFVsg,4467
|
|
|
52
54
|
mplang/edsl/registry.py,sha256=hudXZPUrUUueEwgksDKN0cnE3iiXucuTaDdDK8uSPmk,6822
|
|
53
55
|
mplang/edsl/serde.py,sha256=8K94laE8ObeGuBoF6m7g3A-xEe98EvqQ_6ZPPspddAY,11641
|
|
54
56
|
mplang/edsl/tracer.py,sha256=WQFNL2ZgXSLjxD4JA7cXIDUKIQXe3aZ94qer57IKPXc,23128
|
|
55
|
-
mplang/edsl/typing.py,sha256=
|
|
57
|
+
mplang/edsl/typing.py,sha256=23DjgISpJO7ofG1qZeWFJ8hl6Au_OLw4qlgWSOuujh4,30352
|
|
56
58
|
mplang/kernels/Makefile,sha256=5PoPpajcb_8ByPGNHzVytmovXUwkjJs_K8MbXX9qDYs,1033
|
|
57
59
|
mplang/kernels/__init__.py,sha256=J_rDl9lAXd7QL3Nt_P3YX6j9yge7ssguSaHuafPZNKE,876
|
|
58
60
|
mplang/kernels/gf128.cpp,sha256=WIvCr3MijzwJxMi1Wnfhm8aWT8oL0fia6FeyTmFJtPQ,5975
|
|
@@ -99,8 +101,8 @@ mplang/tool/program.py,sha256=W3H8bpPirnoJ4ZrmyPYuMCPadJis20o__n_1MKqCsWU,11058
|
|
|
99
101
|
mplang/utils/__init__.py,sha256=Hwrwti2nfPxWUXV8DN6T1QaqXH_Jsd27k8UMSdBGUns,1073
|
|
100
102
|
mplang/utils/func_utils.py,sha256=aZ-X43w8JKJgiF-IUMS0G7QqrNeoTM5ZPzRNd-tKxpw,5180
|
|
101
103
|
mplang/utils/logging.py,sha256=9dMhwprVbx1WMGJrgoQbWmV50vyYuLU4NSPnetcl1Go,7237
|
|
102
|
-
mplang_nightly-0.1.
|
|
103
|
-
mplang_nightly-0.1.
|
|
104
|
-
mplang_nightly-0.1.
|
|
105
|
-
mplang_nightly-0.1.
|
|
106
|
-
mplang_nightly-0.1.
|
|
104
|
+
mplang_nightly-0.1.dev287.dist-info/METADATA,sha256=pWxy2dwWffgYSrLRrMGLLkScZjVroOEgBI8YDa0YfGQ,16783
|
|
105
|
+
mplang_nightly-0.1.dev287.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
106
|
+
mplang_nightly-0.1.dev287.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
|
|
107
|
+
mplang_nightly-0.1.dev287.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
108
|
+
mplang_nightly-0.1.dev287.dist-info/RECORD,,
|
|
File without changes
|
{mplang_nightly-0.1.dev285.dist-info → mplang_nightly-0.1.dev287.dist-info}/entry_points.txt
RENAMED
|
File without changes
|
{mplang_nightly-0.1.dev285.dist-info → mplang_nightly-0.1.dev287.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|