mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191) hide show
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,330 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Collective communication library for multi-party data redistribution.
16
+
17
+ This module provides high-level collective operations built on top of
18
+ SIMP dialect primitives (shuffle_static, shuffle_dynamic, converge).
19
+
20
+ Design Philosophy:
21
+ - Single-controller perspective: all operations describe data flow from
22
+ the orchestrator's view, not from an individual party's view
23
+ - MPObject represents distributed values across parties
24
+ - Operations transform the distribution pattern
25
+
26
+ Naming Convention:
27
+ - transfer: point-to-point (1 party → 1 party)
28
+ - replicate: broadcast (1 party → N parties, same value)
29
+ - distribute: scatter (1 party with N values → N parties, one each)
30
+ - collect: gather (N parties → 1 party, returned as a list of N values)
31
+
32
+ Example:
33
+ >>> from mplang.v2.libs.collective import transfer, replicate, distribute, collect
34
+ >>> from mplang.v2.dialects.simp import constant, converge
35
+ >>>
36
+ >>> # Create data on party 0
37
+ >>> x = constant((0,), 42)
38
+ >>>
39
+ >>> # Transfer to party 1
40
+ >>> y = transfer(x, to=1)
41
+ >>>
42
+ >>> # Replicate to all parties
43
+ >>> z = replicate(x, to=(0, 1, 2))
44
+ """
45
+
46
+ from __future__ import annotations
47
+
48
+ from typing import TYPE_CHECKING
49
+
50
+ from mplang.v2.dialects.simp import converge, shuffle_static
51
+ from mplang.v2.edsl import Object
52
+ from mplang.v2.edsl.typing import MPType
53
+
54
+ if TYPE_CHECKING:
55
+ pass
56
+
57
+
58
+ # =============================================================================
59
+ # Helpers
60
+ # =============================================================================
61
+
62
+
63
def _get_parties(obj: Object) -> tuple[int, ...] | None:
    """Return the statically known parties of ``obj``, if available.

    Args:
        obj: Traced object whose ``type`` is inspected.

    Returns:
        ``obj.type.parties`` when ``obj.type`` is an ``MPType``; otherwise
        ``None``. Callers in this module treat a ``None`` result as
        "dynamic parties".
    """
    obj_type = obj.type
    if not isinstance(obj_type, MPType):
        return None
    return obj_type.parties
68
+
69
+
70
def _get_single_party(obj: Object) -> int:
    """Return the one party rank that holds ``obj``.

    Args:
        obj: Object whose static parties must contain exactly one rank.

    Returns:
        That single party rank.

    Raises:
        ValueError: If the parties are not statically known (``None``), or if
            there is not exactly one of them.
    """
    parties = _get_parties(obj)
    # Happy path first: exactly one statically known holder.
    if parties is not None and len(parties) == 1:
        (only_party,) = parties
        return only_party
    if parties is None:
        raise ValueError(
            "Operation requires static parties, got dynamic (parties=None)"
        )
    raise ValueError(
        f"Operation requires single-party source, got parties={parties}"
    )
92
+
93
+
94
def _require_static_parties(obj: Object, op_name: str) -> tuple[int, ...]:
    """Return ``obj``'s static parties, rejecting dynamic ones.

    Args:
        obj: Object to inspect.
        op_name: Name of the calling operation; it is used as the prefix of
            the error message.

    Returns:
        The statically known parties tuple.

    Raises:
        ValueError: If the parties are dynamic (``None``).
    """
    static_parties = _get_parties(obj)
    if static_parties is not None:
        return static_parties
    message = f"{op_name} requires static parties, got dynamic (parties=None)"
    raise ValueError(message)
113
+
114
+
115
+ # =============================================================================
116
+ # Point-to-Point Communication
117
+ # =============================================================================
118
+
119
+
120
def transfer(data: Object, *, to: int) -> Object:
    """Move data held by a single party to another party.

    Single-controller perspective:
    - Input: MPObject held by exactly one party
    - Output: MPObject held by party `to`

    The source rank is not passed in. It is read from ``data.type.parties``.

    Args:
        data: Data to move (static parties with exactly one party)
        to: Destination party rank

    Returns:
        Data held by party `to` (parties=(to,)). When the source already is
        `to`, ``data`` itself is returned unchanged.

    Raises:
        ValueError: If data has dynamic parties or more than one party

    Example:
        >>> x = constant((0,), 42)  # x held by party 0
        >>> y = transfer(x, to=1)  # y held by party 1
        >>> y.type.parties  # (1,)
    """
    src = _get_single_party(data)
    # No shuffle is emitted when the data already lives on the target party.
    return data if src == to else shuffle_static(data, routing={to: src})
148
+
149
+
150
+ # =============================================================================
151
+ # One-to-Many Operations
152
+ # =============================================================================
153
+
154
+
155
def replicate(data: Object, *, to: tuple[int, ...]) -> Object:
    """Replicate data from one party to multiple parties.

    Single-controller perspective:
    - Input: MPObject held by exactly one party
    - Output: MPObject replicated across all parties in `to`

    Each target party receives an identical copy of the data.

    Args:
        data: Data to replicate (must have static parties with exactly one party)
        to: Target party ranks. Any iterable of ints is accepted and
            normalised to a tuple. Duplicate ranks collapse into a single
            routing entry.

    Returns:
        Data replicated across all target parties (parties=to)

    Raises:
        ValueError: If `to` is empty, or if data has dynamic parties or more
            than one party

    Example:
        >>> x = constant((0,), 42)
        >>> y = replicate(x, to=(0, 1, 2))
        >>> y.type.parties  # (0, 1, 2)
        >>> # All three parties now hold the value 42
    """
    targets = tuple(to)
    # An empty target set would build an empty routing dict, i.e. an object
    # held by nobody. Reject it explicitly, as `distribute` does for empty
    # input.
    if not targets:
        raise ValueError("replicate requires at least one target party")
    frm = _get_single_party(data)
    # Every target pulls from the same single source party.
    routing = dict.fromkeys(targets, frm)
    return shuffle_static(data, routing=routing)
183
+
184
+
185
def distribute(values: list[Object], *, frm: int) -> Object:
    """Scatter a list of values held by one party across parties 0..N-1.

    Single-controller perspective:
    - Input: N MPObjects, all held by party `frm`
    - Output: 1 MPObject distributed across N parties (party k holds values[k])

    This is the inverse of collect().

    Args:
        values: N objects, every one held solely by party `frm`
        frm: Rank of the party currently holding all values

    Returns:
        A single MPObject with parties=(0, 1, ..., N-1), where party k holds
        the value taken from values[k]

    Raises:
        ValueError: If values is empty, or any value has dynamic parties or
            is not held exactly by `frm`

    Example:
        >>> xs = [constant((0,), i) for i in range(3)]  # all held by party 0
        >>> y = distribute(xs, frm=0)
        >>> y.type.parties  # (0, 1, 2)
        >>> # Party 0 has 0, party 1 has 1, party 2 has 2
    """
    if not values:
        raise ValueError("distribute requires at least one value")

    # Validate every input before emitting any shuffle.
    expected = (frm,)
    for idx, val in enumerate(values):
        held_by = _get_parties(val)
        if held_by is None:
            raise ValueError(
                f"distribute requires static parties, value[{idx}] has dynamic parties"
            )
        if held_by != expected:
            raise ValueError(
                f"distribute requires all values from party {frm}, "
                f"value[{idx}] has parties={held_by}"
            )

    # values[k] is routed from `frm` to party k. converge then combines the
    # single-party pieces into one object (see the Returns contract above).
    routed = [
        shuffle_static(val, routing={dest: frm}) for dest, val in enumerate(values)
    ]
    return converge(*routed)
229
+
230
+
231
+ # =============================================================================
232
+ # Many-to-One Operations
233
+ # =============================================================================
234
+
235
+
236
def collect(data: Object, *, to: int) -> list[Object]:
    """Collect distributed data to one party.

    Single-controller perspective:
    - Input: 1 MPObject distributed across N parties
    - Output: N MPObjects, each held by party `to`, ordered like
      ``data.type.parties``

    Note: Returns a list because we preserve the logical separation of values
    from different source parties. Use pcall_static to stack/concat if needed.

    Args:
        data: Distributed data (must have static parties)
        to: Target party rank

    Returns:
        List of N objects, all held by party `to`. result[i] contains the
        value from the i-th source party, ``data.type.parties[i]``. That is
        party i only when the parties are exactly ``(0, 1, ..., N-1)``; e.g.
        for parties ``(1, 3)``, result[0] comes from party 1.

    Raises:
        ValueError: If data has dynamic parties

    Example:
        >>> x = converge(x0, x1, x2)  # x.parties = (0, 1, 2)
        >>> ys = collect(x, to=0)  # List of 3 objects
        >>> ys[0].type.parties  # (0,)
        >>> ys[1].type.parties  # (0,)
        >>> # ys[0] has x0's value, ys[1] has x1's value, etc.
    """
    src_parties = _require_static_parties(data, "collect")
    # One point-to-point shuffle per source party, preserving source order.
    return [shuffle_static(data, routing={to: src}) for src in src_parties]
266
+
267
+
268
+ # =============================================================================
269
+ # Many-to-Many Operations
270
+ # =============================================================================
271
+
272
+
273
def allreplicate(data: Object) -> list[Object]:
    """Replicate each party's data to all parties.

    Single-controller perspective:
    - Input: 1 MPObject distributed across N parties
    - Output: N MPObjects, each replicated across all N parties

    result[i] holds the value of the i-th party in ``data.type.parties``,
    replicated to every party in ``data.type.parties``.

    Args:
        data: Distributed data (must have static parties)

    Returns:
        List of N objects, each with parties equal to the original parties.
        result[i] is the value from source party ``data.type.parties[i]``
        (party i only when the parties are exactly ``(0, 1, ..., N-1)``),
        replicated to all parties.

    Raises:
        ValueError: If data has dynamic parties

    Example:
        >>> x = converge(x0, x1, x2)  # x.parties = (0, 1, 2)
        >>> ys = allreplicate(x)  # List of 3 objects
        >>> ys[0].type.parties  # (0, 1, 2) - contains x0's value
        >>> ys[1].type.parties  # (0, 1, 2) - contains x1's value
    """
    src_parties = _require_static_parties(data, "allreplicate")
    # For each source, every original holder pulls that source's value.
    return [
        shuffle_static(data, routing=dict.fromkeys(src_parties, src))
        for src in src_parties
    ]
306
+
307
+
308
def permute(data: Object, *, mapping: dict[int, int]) -> Object:
    """Re-route distributed data between parties according to ``mapping``.

    Single-controller perspective:
    - Input: 1 MPObject distributed across parties
    - Output: 1 MPObject whose per-party values follow `mapping`

    `mapping` is keyed by destination party, and each value is the source
    party that destination reads from. This is exactly the routing format of
    shuffle_static, and this function is a named alias for it.

    Args:
        data: Distributed data
        mapping: Dict of target_party -> source_party

    Returns:
        Permuted data with parties = tuple(sorted(mapping.keys()))

    Example:
        >>> x = converge(x0, x1)  # x.parties = (0, 1)
        >>> y = permute(x, mapping={0: 1, 1: 0})  # swap
        >>> # Party 0 now has x1's value, party 1 has x0's value
    """
    # `mapping` already has shuffle_static's routing shape (target -> source),
    # so it is forwarded as-is.
    routing = mapping
    return shuffle_static(data, routing=routing)
@@ -0,0 +1,51 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Device library for MPLang2.
16
+
17
+ This module provides the high-level device-centric programming interface.
18
+ """
19
+
20
+ from mplang.v2.dialects.tensor import jax_fn
21
+
22
+ from .api import (
23
+ DeviceContext,
24
+ DeviceError,
25
+ DeviceInferenceError,
26
+ DeviceNotFoundError,
27
+ device,
28
+ fetch,
29
+ get_dev_attr,
30
+ is_device_obj,
31
+ put,
32
+ set_dev_attr,
33
+ )
34
+ from .cluster import ClusterSpec, Device, Node
35
+
36
# Public names of this package. They are re-exported from `.cluster`
# (ClusterSpec, Device, Node), from `.api` (device context, errors and
# device/put/fetch helpers), and from `mplang.v2.dialects.tensor` (jax_fn).
__all__ = [
    "ClusterSpec",
    "Device",
    "DeviceContext",
    "DeviceError",
    "DeviceInferenceError",
    "DeviceNotFoundError",
    "Node",
    "device",
    "fetch",
    "get_dev_attr",
    "is_device_obj",
    "jax_fn",
    "put",
    "set_dev_attr",
]