mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191) hide show
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/v1/simp/api.py ADDED
@@ -0,0 +1,353 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ from collections.abc import Callable
18
+ from typing import Any, cast
19
+
20
+ from mplang.v1.core import (
21
+ Mask,
22
+ MPObject,
23
+ Rank,
24
+ ScalarType,
25
+ Shape,
26
+ TableLike,
27
+ TensorLike,
28
+ builtin_function,
29
+ peval,
30
+ )
31
+ from mplang.v1.ops import basic, jax_cc, nnx_cc, sql_cc
32
+ from mplang.v1.ops.base import FeOperation
33
+
34
+
35
def run(
    pmask: Mask | None,
    fe_op: FeOperation,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Invoke a frontend operation and evaluate it through ``peval``.

    Args:
        pmask: Mask forwarded to ``peval`` as its third argument. ``None`` is
            passed through unchanged.
        fe_op: Frontend operation. Calling it with ``*args``/``**kwargs`` must
            produce a ``(pfunc, eval_args, out_tree)`` triple.
        *args: Positional arguments handed to ``fe_op``.
        **kwargs: Keyword arguments handed to ``fe_op``.

    Returns:
        The flat ``peval`` results, rebuilt into their original structure by
        ``out_tree.unflatten``.
    """
    traced = fe_op(*args, **kwargs)
    pfunc, eval_args, out_tree = traced
    flat_results = peval(pfunc, eval_args, pmask)
    return out_tree.unflatten(flat_results)
45
+
46
+
47
def run_at(rank: Rank, op: Any, *args: Any, **kwargs: Any) -> Any:
    """Run an operation with the mask built from ``rank``.

    Args:
        rank: Rank handed to ``Mask.from_ranks`` to build the execution mask.
        op: Operation forwarded to ``run`` as its ``fe_op`` argument.
        *args: Positional arguments forwarded to ``op``.
        **kwargs: Keyword arguments forwarded to ``op``.

    Returns:
        Whatever ``run`` returns for the operation under that mask.
    """
    rank_mask = Mask.from_ranks(rank)
    return run(rank_mask, op, *args, **kwargs)
50
+
51
+
52
@builtin_function
def prank() -> MPObject:
    """Return each party's own rank (party identifier) as an MPObject.

    Every party in the current party mask independently produces its own
    rank. Ranks run from 0 to world_size-1, where world_size is the total
    number of parties in the computation, so the value uniquely identifies a
    party within the multi-party computation.

    Returns:
        MPObject: A variable representing a scalar tensor with:
            - dtype: UINT64
            - shape: () (scalar)

    Note:
        The value is evaluated by running ``basic.rank`` with no explicit
        mask (``None``).
    """
    rank_obj = run(None, basic.rank)
    return cast(MPObject, rank_obj)
74
+
75
+
76
@builtin_function
def prand(shape: Shape = ()) -> MPObject:
    """Generate a private random uint64 tensor on each party.

    Each party in the current party mask independently draws its own local
    random values. The values are private to the party that generated them
    and are not shared with or revealed to other parties.

    Args:
        shape: Shape of the tensor to generate, as a tuple of positive
            integers. Defaults to ``()`` (a scalar).

    Returns:
        MPObject: A variable representing the private random tensor with:
            - dtype: UINT64
            - shape: As given by ``shape``

    Note:
        The value is evaluated by running ``basic.prand`` with no explicit
        mask (``None``).
    """
    random_obj = run(None, basic.prand, shape)
    return cast(MPObject, random_obj)
100
+
101
+
102
def constant(data: TensorLike | ScalarType | TableLike) -> MPObject:
    """Embed constant data (tensor or table) into the computation.

    The value is embedded directly into the computation graph at graph
    construction time and is available to all parties in the current party
    mask during execution.

    Args:
        data: The constant to embed. Accepted inputs:
            - a scalar value (int, float, bool)
            - a numpy array or other tensor-like object
            - a pandas DataFrame or other table-like object
            - any object that can be converted to a tensor

    Returns:
        MPObject: A variable representing the constant tensor or table with:
            - dtype: Inferred from the input data
            - shape: Inferred from the input data (for tensors)
            - schema: Inferred from the input data (for tables)
            - data: The embedded constant values

    Note:
        Large constants may impact graph size. Table-like objects (e.g. a
        pandas DataFrame) are carried via JSON serialization; the constant
        primitive is not designed to carry large tables efficiently, so
        prefer a dedicated table loading mechanism for substantial datasets.
    """
    const_obj = run(None, basic.constant, data)
    return cast(MPObject, const_obj)
132
+
133
+
134
@builtin_function
def debug_print(obj: MPObject, prefix: str = "") -> MPObject:
    """Print the local value of ``obj`` on its owning parties and return it.

    The value is printed at runtime on each party that holds it, and the
    same value is passed through as the result. This makes it possible to
    inspect a multi-party computation without altering its data flow.

    Args:
        obj: The MPObject whose value should be printed.
        prefix: Optional text placed before the printed value. Defaults to "".

    Returns:
        MPObject: The value passed in, unchanged. Returning it allows chained
        use such as ``x = debug_print(x, "x=")`` and keeps dead code
        elimination (DCE) from dropping the print.

    Note:
        If ``obj`` has a static pmask, only parties in that mask print.
        If ``obj`` has a dynamic pmask, the printing parties are determined
        at runtime.
    """
    traced_print = basic.debug_print(obj, prefix=prefix)
    pfunc, eval_args, out_tree = traced_print
    # No explicit mask is supplied here; peval falls back to its own default.
    flat_results = peval(pfunc, eval_args)
    return cast(MPObject, out_tree.unflatten(flat_results))
159
+
160
+
161
def set_mask(arg: MPObject, mask: Mask) -> MPObject:
    """Re-evaluate ``arg`` through an identity operation under ``mask``.

    The behavior depends on whether the input has a dynamic or a static pmask:

    **Case 1: Dynamic pmask (arg.pmask is None)**
    - The input has a runtime-determined pmask
    - The result's pmask is exactly the specified mask
    - No validation is performed at compile time

    **Case 2: Static pmask (arg.pmask is not None)**
    - If mask is a subset of arg.pmask: return_var.pmask == arg.pmask (unchanged)
    - If mask is NOT a subset of arg.pmask: raises ValueError at compile time

    NOTE(review): Case 2 states that the pmask stays equal to ``arg.pmask``,
    while Example 2 below shows the output narrowed to ``[0,2]``. The two
    descriptions disagree; confirm against ``peval`` which one is accurate.

    Args:
        arg: The MPObject whose mask should be changed.
        mask: The target mask to apply. Must be a valid party mask.

    Returns:
        MPObject: A new variable with the mask behavior described above:
            - For dynamic inputs: pmask = mask
            - For static inputs (valid subset): pmask = arg.pmask

    Raises:
        ValueError: When arg has a static pmask and mask is not a subset of
            arg.pmask. Per the cases above this is reported at compile time,
            during graph construction.

    Examples:
        **Example 1: Dynamic pmask - mask assignment**
                  P0  P1  P2
                  --  --  --
        Input:    ?   ?   ?     (pmask=None, runtime-determined)
        mask:     [0,2]         (target mask)
        -----------------------------------------------------------
        Output:   x0  -   x2    (pmask=[0,2])

        **Example 2: Static pmask - valid subset**
                  P0  P1  P2
                  --  --  --
        Input:    x0  x1  x2    (pmask=[0,1,2])
        mask:     [0,2]         (subset of input pmask)
        -----------------------------------------------------------
        Output:   x0  -   x2    (pmask=[0,2])

        **Example 3: Static pmask - invalid subset (compile error)**
                  P0  P1  P2
                  --  --  --
        Input:    x0  -   x2    (pmask=[0,2])
        mask:     [1,2]         (NOT subset of [0,2])
        -----------------------------------------------------------
        Result:   ValueError at compile time

    Note:
        Typically used to constrain the execution scope of a variable or to
        cast between pmask contexts. The implementation traces
        ``basic.identity`` on ``arg`` and evaluates it with ``mask`` as the
        ``peval`` mask.
    """
    traced_identity = basic.identity(arg)
    pfunc, eval_args, out_tree = traced_identity
    masked_results = peval(pfunc, eval_args, mask)
    return cast(MPObject, out_tree.unflatten(masked_results))
223
+
224
+
225
def run_jax(jax_fn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Evaluate a JAX function through the mplang system.

    Args:
        jax_fn: The JAX function to execute.
        *args: Positional arguments passed to the JAX function.
        **kwargs: Keyword arguments passed to the JAX function.

    Returns:
        The result of evaluating the JAX function through the mplang system.

    Raises:
        TypeError: If the function compilation or evaluation fails.
        RuntimeError: If the underlying peval execution encounters errors.

    Notes:
        How arguments are bound with respect to JAX static arguments:

        - An argument that is (or contains, as a PyTree leaf) an
          :class:`~mplang.v1.core.mpobject.MPObject` is captured as a runtime
          variable (dynamic value) in the traced program and is not treated
          as a JAX static argument.
        - An argument with no :class:`MPObject` leaves is treated as constant
          configuration with respect to JAX. It effectively behaves like a
          static argument and may contribute to JAX compilation cache keys
          (similar to ``static_argnums`` semantics), so changing it can
          produce different compiled variants/cached entries.

    Examples:
        Defining and running a simple JAX function:

        >>> import jax.numpy as jnp
        >>> def add_matrices(a, b):
        ...     return jnp.add(a, b)
        >>> result = run_jax(add_matrices, matrix_a, matrix_b)

        Running a more complex JAX function:

        >>> def compute_statistics(data):
        ...     mean = jnp.mean(data)
        ...     std = jnp.std(data)
        ...     return {"mean": mean, "std": std}
        >>> stats = run_jax(compute_statistics, dataset)
    """
    # No explicit mask: the JAX frontend op is run with pmask=None.
    outcome = run(None, jax_cc.run_jax, jax_fn, *args, **kwargs)
    return outcome
270
+
271
+
272
def run_jax_at(rank: Rank, jax_fn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Evaluate a JAX function at a specific rank.

    Args:
        rank: The rank passed to ``run_at`` to select where the function runs.
        jax_fn: The JAX function to execute.
        *args: Positional arguments passed to the JAX function.
        **kwargs: Keyword arguments passed to the JAX function.

    Returns:
        Whatever ``run_at`` returns for ``jax_cc.run_jax`` at that rank.
    """
    outcome = run_at(rank, jax_cc.run_jax, jax_fn, *args, **kwargs)
    return outcome
274
+
275
+
276
def run_sql(
    query: str, out_type: Any, in_tables: dict[str, MPObject] | None = None
) -> Any:
    """Evaluate a raw SQL query through ``sql_cc.run_sql_raw``.

    Args:
        query: The SQL text, forwarded unchanged.
        out_type: Forwarded unchanged to ``sql_cc.run_sql_raw``.
            NOTE(review): presumably the type of the query result; confirm
            against ``sql_cc``.
        in_tables: Optional mapping from string keys to MPObject inputs,
            forwarded unchanged. NOTE(review): the keys are presumably the
            table names referenced by ``query``; confirm against ``sql_cc``.

    Returns:
        Whatever ``run`` returns for the operation with no explicit mask.
    """
    # TODO(jint): add docstring, drop out_type.
    outcome = run(None, sql_cc.run_sql_raw, query, out_type, in_tables)
    return outcome
281
+
282
+
283
def run_sql_at(
    rank: Rank, query: str, out_type: Any, in_tables: dict[str, MPObject] | None = None
) -> Any:
    """Evaluate a raw SQL query through ``sql_cc.run_sql_raw`` at one rank.

    Args:
        rank: The rank passed to ``run_at`` to select where the query runs.
        query: The SQL text, forwarded unchanged.
        out_type: Forwarded unchanged to ``sql_cc.run_sql_raw``.
            NOTE(review): presumably the type of the query result; confirm
            against ``sql_cc``.
        in_tables: Optional mapping from string keys to MPObject inputs,
            forwarded unchanged.

    Returns:
        Whatever ``run_at`` returns for the operation at that rank.
    """
    outcome = run_at(rank, sql_cc.run_sql_raw, query, out_type, in_tables)
    return outcome
287
+
288
+
289
def run_nnx(nnx_fn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Evaluate an NNX function through the mplang system.

    Args:
        nnx_fn: The NNX function to execute.
        *args: Positional arguments passed to the NNX function.
        **kwargs: Keyword arguments passed to the NNX function.

    Returns:
        The result of evaluating the NNX function through the mplang system.

    Raises:
        TypeError: If the function compilation or evaluation fails.
        RuntimeError: If the underlying peval execution encounters errors.

    Notes:
        How arguments are bound with respect to NNX static arguments:

        - An argument that is (or contains, as a PyTree leaf) an
          :class:`~mplang.v1.core.mpobject.MPObject` is captured as a runtime
          variable (dynamic value) in the traced program and is not treated
          as an NNX static argument.
        - An argument with no :class:`MPObject` leaves is treated as constant
          configuration with respect to NNX. It effectively behaves like a
          static argument and may contribute to NNX compilation cache keys
          (similar to ``static_argnums`` semantics), so changing it can
          produce different compiled variants/cached entries.

    Examples:
        Defining and running a simple NNX function:

        >>> from flax import nnx
        >>> import jax.numpy as jnp
        >>> def nnx_linear(inputs, weights, bias):
        ...     return jnp.dot(inputs, weights) + bias
        >>> result = run_nnx(nnx_linear, inputs, weights, bias)

        Running an NNX model:

        >>> class LinearModel(nnx.Module):
        ...     def __init__(self, features: int, rngs: nnx.Rngs):
        ...         self.linear = nnx.Linear(features, features, rngs=rngs)
        ...
        ...     def __call__(self, x):
        ...         return self.linear(x)
        >>> def forward_pass(model, x):
        ...     return model(x)
        >>> output = run_nnx(forward_pass, model, input_data)
    """
    # No explicit mask: the NNX frontend op is run with pmask=None.
    outcome = run(None, nnx_cc.run_nnx, nnx_fn, *args, **kwargs)
    return outcome
339
+
340
+
341
def run_nnx_at(rank: Rank, nnx_fn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Evaluate an NNX function at a specific rank.

    Args:
        rank: The rank where the NNX function should be executed.
        nnx_fn: The NNX function to execute.
        *args: Positional arguments passed to the NNX function.
        **kwargs: Keyword arguments passed to the NNX function.

    Returns:
        The result of evaluating the NNX function at the specified rank.
    """
    outcome = run_at(rank, nnx_cc.run_nnx, nnx_fn, *args, **kwargs)
    return outcome
@@ -16,8 +16,7 @@ from __future__ import annotations
16
16
 
17
17
  import logging
18
18
 
19
- import mplang.core.primitive as prim
20
- from mplang.core import Mask, MPObject, Rank, function
19
+ from mplang.v1.core import Mask, MPObject, Rank, function, pconv, pshfl_s
21
20
 
22
21
 
23
22
  # scatter :: [m a] -> m Rank -> m a
@@ -43,11 +42,11 @@ def scatter_m(to_mask: Mask, root: Rank, args: list[MPObject]) -> MPObject:
43
42
  raise ValueError(f"Expect {len(to_ranks)} args, got {len(args)}. ")
44
43
 
45
44
  scattered = [
46
- prim.pshfl_s(arg, Mask.from_ranks(to_rank), [root])
45
+ pshfl_s(arg, Mask.from_ranks(to_rank), [root])
47
46
  for to_rank, arg in zip(to_ranks, args, strict=False)
48
47
  ]
49
48
 
50
- result = prim.pconv(scattered)
49
+ result = pconv(scattered)
51
50
  assert result.pmask == to_mask, (result.pmask, to_mask)
52
51
  return result # type: ignore[no-any-return]
53
52
 
@@ -58,9 +57,9 @@ def gather_m(src_mask: Mask, root: Rank, arg: MPObject) -> list[MPObject]:
58
57
  """Gather the object from pmask'ed parties to the root party.
59
58
 
60
59
  Args:
61
- src_pmask: The mask of the parties that will gather the object.
60
+ src_mask: The mask of the parties that will gather the object.
62
61
  root: The rank of the root party.
63
- arg: The object to be gathered, which must be the subset of pmask.
62
+ arg: The object to be gathered. It must be held by all parties specified in `src_mask`.
64
63
 
65
64
  Returns:
66
65
  A list of objects, with length equal to the number of parties in src_mask.
@@ -76,7 +75,7 @@ def gather_m(src_mask: Mask, root: Rank, arg: MPObject) -> list[MPObject]:
76
75
  root_mask = Mask.from_ranks(root)
77
76
  for src_rank in Mask(src_mask):
78
77
  # Shuffle data from src_rank to root
79
- gathered_data = prim.pshfl_s(arg, root_mask, [src_rank])
78
+ gathered_data = pshfl_s(arg, root_mask, [src_rank])
80
79
  result.append(gathered_data)
81
80
 
82
81
  assert len(result) == Mask(src_mask).num_parties(), (result, src_mask)
@@ -93,7 +92,7 @@ def bcast_m(pmask: Mask, root: Rank, obj: MPObject) -> MPObject:
93
92
  if not Mask.from_ranks(root).is_subset(obj.pmask):
94
93
  raise ValueError(f"Expect root {root} in obj mask {obj.pmask}.")
95
94
 
96
- result = prim.pshfl_s(obj, pmask, [root] * Mask(pmask).num_parties())
95
+ result = pshfl_s(obj, pmask, [root] * Mask(pmask).num_parties())
97
96
 
98
97
  assert result.pmask == pmask, (result.pmask, pmask)
99
98
  return result # type: ignore[no-any-return]
@@ -114,7 +113,7 @@ def p2p(frm: Rank, to: Rank, obj: MPObject) -> MPObject:
114
113
  if frm == to:
115
114
  return obj
116
115
 
117
- return prim.pshfl_s(obj, Mask.from_ranks(to), [frm]) # type: ignore[no-any-return]
116
+ return pshfl_s(obj, Mask.from_ranks(to), [frm]) # type: ignore[no-any-return]
118
117
 
119
118
 
120
119
  # allgather :: m a -> [m a]
@@ -18,144 +18,13 @@ import importlib
18
18
  import pathlib
19
19
  import pkgutil
20
20
  from collections.abc import Callable
21
- from functools import partial, wraps
21
+ from functools import wraps
22
22
  from types import ModuleType
23
23
  from typing import Any
24
24
 
25
- from mplang.core.mask import Mask
26
- from mplang.core.mpobject import MPObject
27
- from mplang.core.mptype import Rank
28
- from mplang.core.primitive import (
29
- constant,
30
- pconv,
31
- peval,
32
- prand,
33
- prank,
34
- pshfl,
35
- pshfl_s,
36
- uniform_cond,
37
- while_loop,
38
- )
39
- from mplang.ops import ibis_cc, jax_cc
40
- from mplang.ops.base import FeOperation
41
- from mplang.simp.mpi import allgather_m, bcast_m, gather_m, p2p, scatter_m
42
- from mplang.simp.random import key_split, pperm, prandint, ukey, urandint
43
- from mplang.simp.smpc import reveal, revealTo, seal, sealFrom, srun
44
-
45
- # Public exports of the simplified party execution API.
46
- # NOTE: Replaces previous internal __reexport__ (not a Python convention)
47
- # to make star-imports explicit and tooling-friendly.
48
- __all__ = [ # noqa: RUF022
49
- "MPObject",
50
- "P",
51
- "P0",
52
- "P1",
53
- "P2",
54
- "P2P",
55
- "Party",
56
- "allgather_m",
57
- "bcast_m",
58
- "constant",
59
- "gather_m",
60
- "key_split",
61
- "load_module",
62
- "p2p",
63
- "pconv",
64
- "peval",
65
- "pperm",
66
- "prand",
67
- "prandint",
68
- "prank",
69
- "pshfl",
70
- "pshfl_s",
71
- "reveal",
72
- "revealTo",
73
- "run",
74
- "runAt",
75
- "scatter_m",
76
- "seal",
77
- "sealFrom",
78
- "srun",
79
- "ukey",
80
- "uniform_cond",
81
- "urandint",
82
- "while_loop",
83
- ]
84
-
85
-
86
- def run_impl(
87
- pmask: Mask | None,
88
- func: Callable,
89
- *args: Any,
90
- **kwargs: Any,
91
- ) -> Any:
92
- """
93
- Run a function that can be evaluated by the mplang system.
94
-
95
- This function provides a dispatch mechanism based on the first argument
96
- to route different function types to appropriate handlers.
97
-
98
- Args:
99
- pmask: The party mask of this function, None indicates auto deduce parties.
100
- func: The function to be dispatched and executed
101
- *args: Positional arguments to pass to the function
102
- **kwargs: Keyword arguments to pass to the function
103
-
104
- Returns:
105
- The result of evaluating the function through the appropriate handler
106
-
107
- Raises:
108
- ValueError: If builtin.write is called without required arguments
109
- TypeError: If the function compilation or evaluation fails
110
- RuntimeError: If the underlying peval execution encounters errors
111
-
112
- Examples:
113
- Reading data from a file:
114
-
115
- >>> tensor_info = TensorType(shape=(10, 10), dtype=np.float32)
116
- >>> attrs = {"format": "binary"}
117
- >>> result = run_impl(builtin.read, "data/input.bin", tensor_info, attrs)
118
-
119
- Writing data to a file:
120
-
121
- >>> run_impl(builtin.write, data, "data/output.bin")
122
-
123
- Running a JAX function:
124
-
125
- >>> def matrix_multiply(a, b):
126
- ... return jnp.dot(a, b)
127
- >>> result = run_impl(matrix_multiply, mat_a, mat_b)
128
-
129
- Running a custom computation function:
130
-
131
- >>> def compute_statistics(data):
132
- ... mean = jnp.mean(data)
133
- ... std = jnp.std(data)
134
- ... return {"mean": mean, "std": std}
135
- >>> stats = run_impl(compute_statistics, dataset)
136
- """
137
-
138
- if isinstance(func, FeOperation):
139
- pfunc, eval_args, out_tree = func(*args, **kwargs)
140
- else:
141
- if ibis_cc.is_ibis_function(func):
142
- pfunc, eval_args, out_tree = ibis_cc.ibis_compile(func, *args, **kwargs)
143
- else:
144
- # unknown python callable, treat it as jax function
145
- pfunc, eval_args, out_tree = jax_cc.jax_compile(func, *args, **kwargs)
146
- results = peval(pfunc, eval_args, pmask)
147
- return out_tree.unflatten(results)
148
-
149
-
150
- # run :: (a -> a) -> m a -> m a
151
- def run(pyfn: Callable) -> Callable:
152
- return partial(run_impl, None, pyfn)
153
-
154
-
155
- # runAt :: Rank -> (a -> a) -> m a -> m a
156
- def runAt(rank: Rank, pyfn: Callable) -> Callable:
157
- pmask = Mask.from_ranks(rank)
158
- return partial(run_impl, pmask, pyfn)
25
+ from mplang.v1.ops.base import FeOperation
26
+ from mplang.v1.simp.api import run_at, run_jax_at
27
+ from mplang.v1.simp.mpi import p2p
159
28
 
160
29
 
161
30
  def P2P(src: Party, dst: Party, value: Any) -> Any:
@@ -229,22 +98,22 @@ class _PartyModuleProxy:
229
98
 
230
99
  def __getattr__(self, item: str) -> Callable[..., Any]:
231
100
  self._ensure()
232
- target = getattr(self._module, item)
233
- if not callable(target):
101
+ op = getattr(self._module, item)
102
+ if not callable(op):
234
103
  raise AttributeError(
235
- f"Attribute '{item}' of party module '{self._name}' is not callable (got {type(target).__name__})"
104
+ f"Attribute '{item}' of party module '{self._name}' is not callable (got {type(op).__name__})"
236
105
  )
237
106
 
238
- @wraps(target)
107
+ @wraps(op)
239
108
  def _wrapped(*args: Any, **kw: Any) -> Any:
240
109
  # Dispatch through run_at so the op is evaluated at this party's rank.
241
- return run_impl(Mask.from_ranks(self._party.rank), target, *args, **kw)
110
+ return run_at(self._party.rank, op, *args, **kw)
242
111
 
243
112
  # Provide a party-qualified name for debugging / logs without losing original metadata.
244
- base_name = getattr(target, "__name__", None)
113
+ base_name = getattr(op, "__name__", None)
245
114
  if base_name is None:
246
115
  # Frontend FeOperation or object without __name__; try .name attribute (FeOperation contract) or fallback to repr
247
- base_name = getattr(target, "name", None) or type(target).__name__
116
+ base_name = getattr(op, "name", None) or type(op).__name__
248
117
  try:
249
118
  _wrapped.__name__ = f"{base_name}@P{self._party.rank}"
250
119
  except Exception: # pragma: no cover - assignment may fail for exotic wrappers
@@ -264,7 +133,12 @@ class Party:
264
133
  raise TypeError(
265
134
  f"First argument to Party({self.rank}) must be callable, got {fn!r}"
266
135
  )
267
- return runAt(self.rank, fn)(*args, **kwargs)
136
+ # Use run_at for FeOperation, run_jax_at for plain callables
137
+ if isinstance(fn, FeOperation):
138
+ return run_at(self.rank, fn, *args, **kwargs)
139
+ else:
140
+ # TODO(jint): implicitly assume non-FeOperation as JAX function is a bit too magical?
141
+ return run_jax_at(self.rank, fn, *args, **kwargs)
268
142
 
269
143
  def __getattr__(self, name: str) -> _PartyModuleProxy:
270
144
  if name in _NAMESPACE_REGISTRY:
@@ -289,7 +163,7 @@ def _load_prelude_modules() -> None:
289
163
  unwieldy we can switch to an allowlist.
290
164
  """
291
165
  try:
292
- import mplang.ops as _fe # type: ignore
166
+ import mplang.v1.ops as _fe # type: ignore
293
167
  except (ImportError, ModuleNotFoundError): # pragma: no cover
294
168
  # Frontend package not present (minimal install); safe to skip.
295
169
  return
@@ -299,7 +173,7 @@ def _load_prelude_modules() -> None:
299
173
  if m.name.startswith("_"):
300
174
  continue
301
175
  if m.name not in _NAMESPACE_REGISTRY:
302
- _NAMESPACE_REGISTRY[m.name] = f"mplang.ops.{m.name}"
176
+ _NAMESPACE_REGISTRY[m.name] = f"mplang.v1.ops.{m.name}"
303
177
 
304
178
 
305
179
  def load_module(module: str, alias: str | None = None) -> None: