mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Collective communication library for multi-party data redistribution.
|
|
16
|
+
|
|
17
|
+
This module provides high-level collective operations built on top of
|
|
18
|
+
SIMP dialect primitives (shuffle_static, shuffle_dynamic, converge).
|
|
19
|
+
|
|
20
|
+
Design Philosophy:
|
|
21
|
+
- Single-controller perspective: all operations describe data flow from
|
|
22
|
+
the orchestrator's view, not individual party's view
|
|
23
|
+
- MPObject represents distributed values across parties
|
|
24
|
+
- Operations transform the distribution pattern
|
|
25
|
+
|
|
26
|
+
Naming Convention:
|
|
27
|
+
- transfer: point-to-point (1 party → 1 party)
|
|
28
|
+
- replicate: broadcast (1 party → N parties, same value)
|
|
29
|
+
- distribute: scatter (1 party with N values → N parties, one each)
|
|
30
|
+
- collect: gather (N parties → 1 party, returned as a list ordered by source party)
|
|
31
|
+
|
|
32
|
+
Example:
|
|
33
|
+
>>> from mplang.v2.libs.collective import transfer, replicate, distribute, collect
|
|
34
|
+
>>> from mplang.v2.dialects.simp import constant, converge
|
|
35
|
+
>>>
|
|
36
|
+
>>> # Create data on party 0
|
|
37
|
+
>>> x = constant((0,), 42)
|
|
38
|
+
>>>
|
|
39
|
+
>>> # Transfer to party 1
|
|
40
|
+
>>> y = transfer(x, to=1)
|
|
41
|
+
>>>
|
|
42
|
+
>>> # Replicate to all parties
|
|
43
|
+
>>> z = replicate(x, to=(0, 1, 2))
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
from __future__ import annotations
|
|
47
|
+
|
|
48
|
+
from typing import TYPE_CHECKING
|
|
49
|
+
|
|
50
|
+
from mplang.v2.dialects.simp import converge, shuffle_static
|
|
51
|
+
from mplang.v2.edsl import Object
|
|
52
|
+
from mplang.v2.edsl.typing import MPType
|
|
53
|
+
|
|
54
|
+
if TYPE_CHECKING:
|
|
55
|
+
pass
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# =============================================================================
|
|
59
|
+
# Helpers
|
|
60
|
+
# =============================================================================
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _get_parties(obj: Object) -> tuple[int, ...] | None:
    """Look up the statically known holder parties of ``obj``.

    Args:
        obj: Traced object whose type is inspected.

    Returns:
        ``obj.type.parties`` when the type is an ``MPType``, otherwise ``None``.
        Callers in this module treat a ``None`` result as "dynamic parties".
    """
    obj_type = obj.type
    return obj_type.parties if isinstance(obj_type, MPType) else None
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _get_single_party(obj: Object) -> int:
    """Return the rank of the one party that holds ``obj``.

    Args:
        obj: Object whose static party set must contain exactly one rank.

    Returns:
        The sole holder's rank.

    Raises:
        ValueError: If the parties are dynamic (``None``) or the static party
            set does not contain exactly one rank.
    """
    holders = _get_parties(obj)
    if holders is None:
        raise ValueError(
            "Operation requires static parties, got dynamic (parties=None)"
        )
    if len(holders) == 1:
        return holders[0]
    raise ValueError(
        f"Operation requires single-party source, got parties={holders}"
    )
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _require_static_parties(obj: Object, op_name: str) -> tuple[int, ...]:
    """Return the static parties of ``obj``, rejecting dynamic ones.

    Args:
        obj: Object whose parties are required to be known statically.
        op_name: Name of the calling operation, used in the error message.

    Returns:
        The static parties tuple of ``obj``.

    Raises:
        ValueError: If ``obj`` has dynamic parties (``None``).
    """
    holders = _get_parties(obj)
    if holders is not None:
        return holders
    raise ValueError(
        f"{op_name} requires static parties, got dynamic (parties=None)"
    )
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
# =============================================================================
|
|
116
|
+
# Point-to-Point Communication
|
|
117
|
+
# =============================================================================
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def transfer(data: Object, *, to: int) -> Object:
    """Move a single-party value to another party.

    From the orchestrator's point of view, the input is an object held by
    exactly one party and the output is the same value held by party ``to``.
    The source party is read from ``data.type.parties``.

    Args:
        data: Value to move; its static parties must contain exactly one rank.
        to: Rank of the destination party.

    Returns:
        An object held by party ``to``. If ``to`` already holds ``data``, then
        ``data`` itself is returned unchanged.

    Raises:
        ValueError: If ``data`` has dynamic parties or is not held by exactly
            one party.

    Example:
        >>> x = constant((0,), 42)  # held by party 0
        >>> y = transfer(x, to=1)  # held by party 1
        >>> y.type.parties  # (1,)
    """
    src = _get_single_party(data)
    # Sending data to the party that already owns it needs no communication.
    return data if src == to else shuffle_static(data, routing={to: src})
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
# =============================================================================
|
|
151
|
+
# One-to-Many Operations
|
|
152
|
+
# =============================================================================
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def replicate(data: Object, *, to: tuple[int, ...]) -> Object:
    """Broadcast a single-party value to several parties.

    From the orchestrator's point of view, the input is an object held by
    exactly one party. The output is one object, and every party listed in
    ``to`` holds an identical copy of the data.

    Args:
        data: Value to replicate; its static parties must contain exactly one
            rank.
        to: Ranks of the destination parties. Duplicate ranks are collapsed.

    Returns:
        An object held by the distinct parties listed in ``to``.

    Raises:
        ValueError: If ``data`` has dynamic parties, if it is not held by
            exactly one party, or if ``to`` is empty.
        TypeError: If ``to`` is not iterable (e.g. a bare ``int``).

    Example:
        >>> x = constant((0,), 42)
        >>> y = replicate(x, to=(0, 1, 2))
        >>> y.type.parties  # (0, 1, 2)
        >>> # All three parties now hold the value 42
    """
    frm = _get_single_party(data)
    # Materialise once so any iterable of ranks works and emptiness is checkable.
    targets = tuple(to)
    if not targets:
        # An empty routing table would describe a value held by no party.
        # Reject it up front, mirroring distribute()'s empty-input check.
        raise ValueError("replicate requires at least one target party")
    # shuffle_static routing maps target_party -> source_party.
    routing = dict.fromkeys(targets, frm)
    return shuffle_static(data, routing=routing)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def distribute(values: list[Object], *, frm: int) -> Object:
    """Scatter a list of values held by one party across parties 0..N-1.

    From the orchestrator's point of view, the input is N objects all held by
    party ``frm``. The output is a single object in which party ``k`` holds
    ``values[k]``. This is the inverse of ``collect()``.

    Args:
        values: N objects, each statically held by party ``frm`` only.
        frm: Rank of the party that currently holds every value.

    Returns:
        One object spanning parties ``(0, 1, ..., N-1)``, where party ``k``
        holds the value taken from ``values[k]``.

    Raises:
        ValueError: If ``values`` is empty, if any value has dynamic parties,
            or if any value is not held by exactly ``frm``.

    Example:
        >>> xs = [constant((0,), i) for i in range(3)]  # all on party 0
        >>> y = distribute(xs, frm=0)
        >>> y.type.parties  # (0, 1, 2)
        >>> # Party 0 has 0, party 1 has 1, party 2 has 2
    """
    if not values:
        raise ValueError("distribute requires at least one value")

    # Check ownership of every input before emitting any communication.
    expected_owner = (frm,)
    for idx, val in enumerate(values):
        owners = _get_parties(val)
        if owners is None:
            raise ValueError(
                f"distribute requires static parties, value[{idx}] has dynamic parties"
            )
        if owners != expected_owner:
            raise ValueError(
                f"distribute requires all values from party {frm}, "
                f"value[{idx}] has parties={owners}"
            )

    # Route values[dst] from `frm` to party `dst`, then merge the pieces.
    pieces = []
    for dst, val in enumerate(values):
        pieces.append(shuffle_static(val, routing={dst: frm}))
    return converge(*pieces)
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
# =============================================================================
|
|
232
|
+
# Many-to-One Operations
|
|
233
|
+
# =============================================================================
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def collect(data: Object, *, to: int) -> list[Object]:
    """Gather every party's piece of a distributed value onto one party.

    From the orchestrator's point of view, the input is one object spread
    over N parties. The output is N objects, all held by party ``to``, in the
    order of the source parties.

    The pieces are returned as a list rather than combined, so that values
    from different source parties stay logically separate. Use
    ``pcall_static`` to stack or concatenate them if needed.

    Args:
        data: Distributed value; its parties must be static.
        to: Rank of the gathering party.

    Returns:
        A list with one entry per source party, each held by party ``to``.
        ``result[i]`` carries the piece from the i-th party in
        ``data.type.parties``.

    Raises:
        ValueError: If ``data`` has dynamic parties.

    Example:
        >>> x = converge(x0, x1, x2)  # x.parties = (0, 1, 2)
        >>> ys = collect(x, to=0)  # list of 3 objects
        >>> ys[0].type.parties  # (0,)
        >>> ys[1].type.parties  # (0,)
        >>> # ys[0] has x0's value, ys[1] has x1's value, etc.
    """
    sources = _require_static_parties(data, "collect")
    gathered: list[Object] = []
    for src in sources:
        # Pull the piece held by `src` over to party `to`.
        gathered.append(shuffle_static(data, routing={to: src}))
    return gathered
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
# =============================================================================
|
|
269
|
+
# Many-to-Many Operations
|
|
270
|
+
# =============================================================================
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def allreplicate(data: Object) -> list[Object]:
    """Broadcast every party's piece of a distributed value to all its parties.

    From the orchestrator's point of view, the input is one object spread
    over N parties. The output is N objects, each replicated across those
    same N parties.

    Args:
        data: Distributed value; its parties must be static.

    Returns:
        A list with one entry per source party. Entry ``i`` holds the piece
        that came from the i-th party in ``data.type.parties``, and that piece
        is copied to every party in that same tuple.

    Raises:
        ValueError: If ``data`` has dynamic parties.

    Example:
        >>> x = converge(x0, x1, x2)  # x.parties = (0, 1, 2)
        >>> ys = allreplicate(x)  # list of 3 objects
        >>> ys[0].type.parties  # (0, 1, 2) - contains x0's value
        >>> ys[1].type.parties  # (0, 1, 2) - contains x1's value
    """
    parties = _require_static_parties(data, "allreplicate")
    # One broadcast per source: every destination in `parties` reads from `src`.
    return [
        shuffle_static(data, routing={dst: src for dst in parties})
        for src in parties
    ]
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def permute(data: Object, *, mapping: dict[int, int]) -> Object:
    """Rearrange which party holds which piece of a distributed value.

    From the orchestrator's point of view, one distributed object goes in
    and one comes out, with its pieces moved according to ``mapping``. Each
    entry reads ``target_party -> source_party``. This is a thin, explicitly
    named wrapper over ``shuffle_static``.

    Args:
        data: Distributed value to rearrange.
        mapping: For each destination rank, the rank whose piece it receives.

    Returns:
        The rearranged value, with ``parties == tuple(sorted(mapping.keys()))``.

    Example:
        >>> x = converge(x0, x1)  # x.parties = (0, 1)
        >>> y = permute(x, mapping={0: 1, 1: 0})  # swap
        >>> # Party 0 now has x1's value, party 1 has x0's value
    """
    routing = mapping  # already in shuffle_static's target -> source form
    return shuffle_static(data, routing=routing)
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Device library for MPLang2.
|
|
16
|
+
|
|
17
|
+
This module provides the high-level device-centric programming interface.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from mplang.v2.dialects.tensor import jax_fn
|
|
21
|
+
|
|
22
|
+
from .api import (
|
|
23
|
+
DeviceContext,
|
|
24
|
+
DeviceError,
|
|
25
|
+
DeviceInferenceError,
|
|
26
|
+
DeviceNotFoundError,
|
|
27
|
+
device,
|
|
28
|
+
fetch,
|
|
29
|
+
get_dev_attr,
|
|
30
|
+
is_device_obj,
|
|
31
|
+
put,
|
|
32
|
+
set_dev_attr,
|
|
33
|
+
)
|
|
34
|
+
from .cluster import ClusterSpec, Device, Node
|
|
35
|
+
|
|
36
|
+
# Public API of the device library. It includes the device-centric entry
# points and error types imported from ``.api``, the cluster description
# types imported from ``.cluster``, and ``jax_fn``, which is re-exported from
# ``mplang.v2.dialects.tensor``.
__all__ = [
    "ClusterSpec",
    "Device",
    "DeviceContext",
    "DeviceError",
    "DeviceInferenceError",
    "DeviceNotFoundError",
    "Node",
    "device",
    "fetch",
    "get_dev_attr",
    "is_device_obj",
    "jax_fn",
    "put",
    "set_dev_attr",
]
|