mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
mplang/v1/simp/api.py DELETED
@@ -1,353 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- from collections.abc import Callable
18
- from typing import Any, cast
19
-
20
- from mplang.v1.core import (
21
- Mask,
22
- MPObject,
23
- Rank,
24
- ScalarType,
25
- Shape,
26
- TableLike,
27
- TensorLike,
28
- builtin_function,
29
- peval,
30
- )
31
- from mplang.v1.ops import basic, jax_cc, nnx_cc, sql_cc
32
- from mplang.v1.ops.base import FeOperation
33
-
34
-
35
- def run(
36
- pmask: Mask | None,
37
- fe_op: FeOperation,
38
- *args: Any,
39
- **kwargs: Any,
40
- ) -> Any:
41
- """Run an operation in the current context."""
42
- pfunc, eval_args, out_tree = fe_op(*args, **kwargs)
43
- results = peval(pfunc, eval_args, pmask)
44
- return out_tree.unflatten(results)
45
-
46
-
47
- def run_at(rank: Rank, op: Any, *args: Any, **kwargs: Any) -> Any:
48
- """Run an operation at a specific rank."""
49
- return run(Mask.from_ranks(rank), op, *args, **kwargs)
50
-
51
-
52
- @builtin_function
53
- def prank() -> MPObject:
54
- """Multi-party get the rank (party identifier) of each party.
55
-
56
- This function returns a scalar tensor containing the rank (party identifier)
57
- for each party in the current party mask. Each party independently produces
58
- its own rank value, which serves as a unique identifier within the multi-party
59
- computation context.
60
-
61
- The rank values range from 0 to world_size-1, where world_size is the total
62
- number of parties in the computation. Each party's rank is private to that
63
- party and represents its position in the multi-party protocol.
64
-
65
- Returns:
66
- MPObject: A variable representing a scalar tensor with:
67
- - dtype: UINT64
68
- - shape: () (scalar)
69
-
70
- Note:
71
- Each party in the current party mask independently produces its own rank value.
72
- """
73
- return cast(MPObject, run(None, basic.rank))
74
-
75
-
76
- @builtin_function
77
- def prand(shape: Shape = ()) -> MPObject:
78
- """Multi-party generate a private random (uint64) tensor with the given shape.
79
-
80
- This function creates a private random tensor where each party independently
81
- generates its own local random values. Each party's random values are private
82
- and unknown to other parties. The output tensor contains 64-bit unsigned
83
- integers, with each party holding its own privately generated values.
84
-
85
- Args:
86
- shape: The shape of the random tensor to generate.
87
- Must be a tuple of positive integers. Defaults to () for scalar.
88
-
89
- Returns:
90
- MPObject: A variable representing the generated private random tensor with:
91
- - dtype: UINT64
92
- - shape: As specified by the shape parameter
93
-
94
- Note:
95
- Each party in the current party mask independently generates its own
96
- private random values. The randomness is local to each party and is
97
- not shared or revealed to other parties.
98
- """
99
- return cast(MPObject, run(None, basic.prand, shape))
100
-
101
-
102
- def constant(data: TensorLike | ScalarType | TableLike) -> MPObject:
103
- """Create a constant tensor or table from data.
104
-
105
- This function creates a constant that can be used in multi-party
106
- computations. The constant value is embedded directly into the computation
107
- graph and is available to all parties in the current party mask.
108
-
109
- Args:
110
- data: The constant data to embed. Can be:
111
- - A scalar value (int, float, bool)
112
- - A numpy array or other tensor-like object
113
- - A pandas DataFrame or other table-like object
114
- - Any object that can be converted to tensor
115
-
116
- Returns:
117
- MPObject: A variable representing the constant tensor or table with:
118
- - dtype: Inferred from the input data
119
- - shape: Inferred from the input data (for tensors)
120
- - schema: Inferred from the input data (for tables)
121
- - data: The embedded constant values
122
-
123
- Note:
124
- The constant data is embedded at graph construction time and is available
125
- to all parties during execution. Large constants may impact graph size.
126
-
127
- For table-like objects (e.g., pandas DataFrame), JSON serialization is used.
128
- Note that the constant primitive is not designed to carry large tables efficiently -
129
- consider using dedicated table loading mechanisms for substantial datasets.
130
- """
131
- return cast(MPObject, run(None, basic.constant, data))
132
-
133
-
134
- @builtin_function
135
- def debug_print(obj: MPObject, prefix: str = "") -> MPObject:
136
- """Print local value of obj on owning parties and pass it through.
137
-
138
- This function prints the value of an MPObject at runtime on each party that
139
- owns the value, and returns the same MPObject unchanged. This is useful for
140
- debugging multi-party computations without affecting the computation flow.
141
-
142
- Args:
143
- obj: The MPObject whose value should be printed.
144
- prefix: Optional text prefix for the printed output. Defaults to "".
145
-
146
- Returns:
147
- MPObject: The same MPObject value passed in, unchanged. This allows
148
- the function to be used in chains like: x = debug_print(x, "x=")
149
- and prevents dead code elimination (DCE) from removing the print.
150
-
151
- Note:
152
- The print operation occurs at runtime on each party that holds the value.
153
- If obj has a static pmask, only parties in that mask will print.
154
- If obj has a dynamic pmask, the parties are determined at runtime.
155
- """
156
- pfunc, eval_args, out_tree = basic.debug_print(obj, prefix=prefix)
157
- results = peval(pfunc, eval_args)
158
- return cast(MPObject, out_tree.unflatten(results))
159
-
160
-
161
- def set_mask(arg: MPObject, mask: Mask) -> MPObject:
162
- """Set the mask of an MPObject to a new value.
163
-
164
- This function allows changing the party mask of an existing MPObject variable.
165
- The behavior depends on whether the input MPObject has a dynamic or static pmask:
166
-
167
- **Case 1: Dynamic pmask (arg.pmask is None)**
168
- - The input MPObject has a runtime-determined pmask
169
- - The return value's pmask will be exactly the specified mask
170
- - No validation is performed at compile time
171
-
172
- **Case 2: Static pmask (arg.pmask is not None)**
173
- - If mask is a subset of arg.pmask: return_var.pmask == arg.pmask (unchanged)
174
- - If mask is NOT a subset of arg.pmask: raises ValueError at compile time
175
-
176
- Args:
177
- arg: The MPObject whose mask needs to be changed.
178
- mask: The target mask to apply. Must be a valid party mask.
179
-
180
- Returns:
181
- MPObject: A new variable with the specified mask behavior:
182
- - For dynamic inputs: pmask = mask
183
- - For static inputs (valid subset): pmask = arg.pmask
184
-
185
- Raises:
186
- ValueError: When arg has a static pmask and mask is not a subset of arg.pmask.
187
- This validation occurs at compile time during graph construction.
188
-
189
- Examples:
190
- **Example 1: Dynamic pmask - mask assignment**
191
- P0 P1 P2
192
- -- -- --
193
- Input: ? ? ? (pmask=None, runtime-determined)
194
- mask: [0,2] (target mask)
195
- -----------------------------------------------------------
196
- Output: x0 - x2 (pmask=[0,2])
197
-
198
- **Example 2: Static pmask - valid subset**
199
- P0 P1 P2
200
- -- -- --
201
- Input: x0 x1 x2 (pmask=[0,1,2])
202
- mask: [0,2] (subset of input pmask)
203
- -----------------------------------------------------------
204
- Output: x0 - x2 (pmask=[0,2])
205
-
206
- **Example 3: Static pmask - invalid subset (compile error)**
207
- P0 P1 P2
208
- -- -- --
209
- Input: x0 - x2 (pmask=[0,2])
210
- mask: [1,2] (NOT subset of [0,2])
211
- -----------------------------------------------------------
212
- Result: ValueError at compile time
213
-
214
- Note:
215
- This function is typically used for constraining the execution scope
216
- of variables or for type casting between different pmask contexts.
217
- The underlying implementation uses JAX identity function with the
218
- specified execution mask.
219
- """
220
- pfunc, eval_args, out_tree = basic.identity(arg)
221
- results = peval(pfunc, eval_args, mask)
222
- return cast(MPObject, out_tree.unflatten(results))
223
-
224
-
225
- def run_jax(jax_fn: Callable, *args: Any, **kwargs: Any) -> Any:
226
- """Run a JAX function.
227
-
228
- Args:
229
- jax_fn: The JAX function to be executed.
230
- *args: Positional arguments to pass to the JAX function.
231
- **kwargs: Keyword arguments to pass to the JAX function.
232
-
233
- Returns:
234
- The result of evaluating the JAX function through the mplang system.
235
-
236
- Raises:
237
- TypeError: If the function compilation or evaluation fails.
238
- RuntimeError: If the underlying peval execution encounters errors.
239
-
240
- Notes:
241
- Argument binding semantics with respect to JAX static arguments:
242
-
243
- - If an argument (or any leaf within a PyTree argument) is an
244
- :class:`~mplang.core.mpobject.MPObject`, it is captured as a runtime
245
- variable (dynamic value) in the traced program and is not treated as a
246
- JAX static argument.
247
- - If an argument contains no :class:`MPObject` leaves, it is treated as a
248
- constant configuration with respect to JAX; effectively it behaves
249
- like a static argument and may contribute to JAX compilation cache
250
- keys (similar to ``static_argnums`` semantics). Changing such constant
251
- arguments can lead to different compiled variants/cached entries.
252
-
253
- Examples:
254
- Defining and running a simple JAX function:
255
-
256
- >>> import jax.numpy as jnp
257
- >>> def add_matrices(a, b):
258
- ... return jnp.add(a, b)
259
- >>> result = run_jax(add_matrices, matrix_a, matrix_b)
260
-
261
- Running a more complex JAX function:
262
-
263
- >>> def compute_statistics(data):
264
- ... mean = jnp.mean(data)
265
- ... std = jnp.std(data)
266
- ... return {"mean": mean, "std": std}
267
- >>> stats = run_jax(compute_statistics, dataset)
268
- """
269
- return run(None, jax_cc.run_jax, jax_fn, *args, **kwargs)
270
-
271
-
272
- def run_jax_at(rank: Rank, jax_fn: Callable, *args: Any, **kwargs: Any) -> Any:
273
- return run_at(rank, jax_cc.run_jax, jax_fn, *args, **kwargs)
274
-
275
-
276
- def run_sql(
277
- query: str, out_type: Any, in_tables: dict[str, MPObject] | None = None
278
- ) -> Any:
279
- # TODO(jint): add docstring, drop out_type.
280
- return run(None, sql_cc.run_sql_raw, query, out_type, in_tables)
281
-
282
-
283
- def run_sql_at(
284
- rank: Rank, query: str, out_type: Any, in_tables: dict[str, MPObject] | None = None
285
- ) -> Any:
286
- return run_at(rank, sql_cc.run_sql_raw, query, out_type, in_tables)
287
-
288
-
289
- def run_nnx(nnx_fn: Callable, *args: Any, **kwargs: Any) -> Any:
290
- """Run an NNX function.
291
-
292
- Args:
293
- nnx_fn: The NNX function to be executed.
294
- *args: Positional arguments to pass to the NNX function.
295
- **kwargs: Keyword arguments to pass to the NNX function.
296
-
297
- Returns:
298
- The result of evaluating the NNX function through the mplang system.
299
-
300
- Raises:
301
- TypeError: If the function compilation or evaluation fails.
302
- RuntimeError: If the underlying peval execution encounters errors.
303
-
304
- Notes:
305
- Argument binding semantics with respect to NNX static arguments:
306
-
307
- - If an argument (or any leaf within a PyTree argument) is an
308
- :class:`~mplana.v1.core.mpobject.MPObject`, it is captured as a runtime
309
- variable (dynamic value) in the traced program and is not treated as a
310
- NNX static argument.
311
- - If an argument contains no :class:`MPObject` leaves, it is treated as a
312
- constant configuration with respect to NNX; effectively it behaves
313
- like a static argument and may contribute to NNX compilation cache
314
- keys (similar to ``static_argnums`` semantics). Changing such constant
315
- arguments can lead to different compiled variants/cached entries.
316
-
317
- Examples:
318
- Defining and running a simple NNX function:
319
-
320
- >>> from flax import nnx
321
- >>> import jax.numpy as jnp
322
- >>> def nnx_linear(inputs, weights, bias):
323
- ... return jnp.dot(inputs, weights) + bias
324
- >>> result = run_nnx(nnx_linear, inputs, weights, bias)
325
-
326
- Running an NNX model:
327
-
328
- >>> class LinearModel(nnx.Module):
329
- ... def __init__(self, features: int, rngs: nnx.Rngs):
330
- ... self.linear = nnx.Linear(features, features, rngs=rngs)
331
- ...
332
- ... def __call__(self, x):
333
- ... return self.linear(x)
334
- >>> def forward_pass(model, x):
335
- ... return model(x)
336
- >>> output = run_nnx(forward_pass, model, input_data)
337
- """
338
- return run(None, nnx_cc.run_nnx, nnx_fn, *args, **kwargs)
339
-
340
-
341
- def run_nnx_at(rank: Rank, nnx_fn: Callable, *args: Any, **kwargs: Any) -> Any:
342
- """Run an NNX function at a specific rank.
343
-
344
- Args:
345
- rank: The rank where the NNX function should be executed.
346
- nnx_fn: The NNX function to be executed.
347
- *args: Positional arguments to pass to the NNX function.
348
- **kwargs: Keyword arguments to pass to the NNX function.
349
-
350
- Returns:
351
- The result of evaluating the NNX function at the specified rank.
352
- """
353
- return run_at(rank, nnx_cc.run_nnx, nnx_fn, *args, **kwargs)
mplang/v1/simp/mpi.py DELETED
@@ -1,131 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- import logging
18
-
19
- from mplang.v1.core import Mask, MPObject, Rank, function, pconv, pshfl_s
20
-
21
-
22
- # scatter :: [m a] -> m Rank -> m a
23
- @function
24
- def scatter_m(to_mask: Mask, root: Rank, args: list[MPObject]) -> MPObject:
25
- """Scatter the object from root to the parties in pmask.
26
-
27
- Args:
28
- to_mask: The mask of the parties that will receive the object.
29
- root: The rank of the root party.
30
- args: The objects to be scattered, which must hold by root and length of pmask'ed parties.
31
- """
32
- # sanity check, ensure all args are in the to_mask.
33
- for arg in args:
34
- if arg.pmask is None:
35
- logging.warning(f"Scattering dynamic {arg} from static root {root}")
36
- else:
37
- if not Mask.from_ranks(root).is_subset(arg.pmask):
38
- raise ValueError(f"Expect root {root} in {arg.pmask}, got {arg}.")
39
-
40
- to_ranks = list(Mask(to_mask))
41
- if len(args) != len(to_ranks):
42
- raise ValueError(f"Expect {len(to_ranks)} args, got {len(args)}. ")
43
-
44
- scattered = [
45
- pshfl_s(arg, Mask.from_ranks(to_rank), [root])
46
- for to_rank, arg in zip(to_ranks, args, strict=False)
47
- ]
48
-
49
- result = pconv(scattered)
50
- assert result.pmask == to_mask, (result.pmask, to_mask)
51
- return result # type: ignore[no-any-return]
52
-
53
-
54
- # gather :: m a -> m Rank -> [m a]
55
- @function
56
- def gather_m(src_mask: Mask, root: Rank, arg: MPObject) -> list[MPObject]:
57
- """Gather the object from pmask'ed parties to the root party.
58
-
59
- Args:
60
- src_mask: The mask of the parties that will gather the object.
61
- root: The rank of the root party.
62
- arg: The object to be gathered. It must be held by all parties specified in `src_mask`.
63
-
64
- Returns:
65
- A list of objects, with length equal to the number of parties in pmask.
66
- """
67
- # static pmask check.
68
- if arg.pmask is None:
69
- logging.warning(f"Gathering {arg} from {src_mask}, may raise RuntimeError.")
70
- else:
71
- if not Mask(src_mask).is_subset(arg.pmask):
72
- raise ValueError(f"Expect {src_mask} in {arg.pmask}, got {arg}.")
73
-
74
- result = []
75
- root_mask = Mask.from_ranks(root)
76
- for src_rank in Mask(src_mask):
77
- # Shuffle data from src_rank to root
78
- gathered_data = pshfl_s(arg, root_mask, [src_rank])
79
- result.append(gathered_data)
80
-
81
- assert len(result) == Mask(src_mask).num_parties(), (result, src_mask)
82
- return result
83
-
84
-
85
- # bcast :: m a -> m Rank -> m a
86
- @function
87
- def bcast_m(pmask: Mask, root: Rank, obj: MPObject) -> MPObject:
88
- """Broadcast the object from the root party to the parties in pmask."""
89
- if obj.pmask is None:
90
- logging.warning(f"Broadcasting {obj} from {root}, may raise RuntimeError.")
91
- else:
92
- if not Mask.from_ranks(root).is_subset(obj.pmask):
93
- raise ValueError(f"Expect root {root} in obj mask {obj.pmask}.")
94
-
95
- result = pshfl_s(obj, pmask, [root] * Mask(pmask).num_parties())
96
-
97
- assert result.pmask == pmask, (result.pmask, pmask)
98
- return result # type: ignore[no-any-return]
99
-
100
-
101
- # p2p :: m Rank -> m Rank -> m a -> m a
102
- @function
103
- def p2p(frm: Rank, to: Rank, obj: MPObject) -> MPObject:
104
- """Point-to-point communication from frm to to."""
105
-
106
- # sanity check, ensure the object is in the frm mask.
107
- if obj.pmask is None:
108
- logging.warning(f"P2P {obj} from {frm} to {to}, may raise RuntimeError.")
109
- else:
110
- if not Mask.from_ranks(frm).is_subset(obj.pmask):
111
- raise ValueError(f"Expect {frm} in {obj.pmask}, got {obj}.")
112
-
113
- if frm == to:
114
- return obj
115
-
116
- return pshfl_s(obj, Mask.from_ranks(to), [frm]) # type: ignore[no-any-return]
117
-
118
-
119
- # allgather :: m a -> [m a]
120
- @function
121
- def allgather_m(pmask: Mask, arg: MPObject) -> list[MPObject]:
122
- """Gather the object from all parties in pmask and return a list of objects."""
123
-
124
- if arg.pmask is None:
125
- logging.warning(f"Allgathering {arg} from {pmask}, may raise RuntimeError.")
126
- else:
127
- if not Mask(pmask).is_subset(arg.pmask):
128
- raise ValueError(f"Expect {pmask} in {arg.pmask}, got {arg}.")
129
-
130
- # TODO(jint): implement me.
131
- raise NotImplementedError("Allgather not implemented")