mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/v1/simp/api.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from collections.abc import Callable
|
|
18
|
+
from typing import Any, cast
|
|
19
|
+
|
|
20
|
+
from mplang.v1.core import (
|
|
21
|
+
Mask,
|
|
22
|
+
MPObject,
|
|
23
|
+
Rank,
|
|
24
|
+
ScalarType,
|
|
25
|
+
Shape,
|
|
26
|
+
TableLike,
|
|
27
|
+
TensorLike,
|
|
28
|
+
builtin_function,
|
|
29
|
+
peval,
|
|
30
|
+
)
|
|
31
|
+
from mplang.v1.ops import basic, jax_cc, nnx_cc, sql_cc
|
|
32
|
+
from mplang.v1.ops.base import FeOperation
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def run(
    pmask: Mask | None,
    fe_op: FeOperation,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Trace ``fe_op`` on the given arguments and evaluate it in the current context.

    Args:
        pmask: Party mask handed to ``peval``. ``None`` means the parties are
            not fixed here and are left for ``peval`` to deduce.
        fe_op: Frontend operation. Calling it with ``*args, **kwargs`` yields a
            ``(pfunc, eval_args, out_tree)`` triple.
        *args: Positional arguments forwarded to ``fe_op``.
        **kwargs: Keyword arguments forwarded to ``fe_op``.

    Returns:
        The flat ``peval`` results, re-assembled into the structure described
        by ``out_tree``.
    """
    traced = fe_op(*args, **kwargs)
    pfunc, flat_inputs, output_tree = traced
    flat_outputs = peval(pfunc, flat_inputs, pmask)
    return output_tree.unflatten(flat_outputs)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def run_at(rank: Rank, op: Any, *args: Any, **kwargs: Any) -> Any:
    """Run an operation on exactly one party.

    This is :func:`run` with a party mask that contains only ``rank``.

    Args:
        rank: The party that executes the operation.
        op: Frontend operation, forwarded to :func:`run` as ``fe_op``.
        *args: Positional arguments forwarded to ``op``.
        **kwargs: Keyword arguments forwarded to ``op``.

    Returns:
        Whatever :func:`run` returns for ``op``.
    """
    single_party = Mask.from_ranks(rank)
    return run(single_party, op, *args, **kwargs)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@builtin_function
def prank() -> MPObject:
    """Multi-party get the rank (party identifier) of each party.

    Every party in the current party mask locally produces a scalar tensor
    holding its own rank. Ranks range from 0 to world_size-1, where world_size
    is the total number of parties in the computation. The value a party holds
    is private to that party and identifies its position in the multi-party
    protocol.

    Returns:
        MPObject: A variable representing a scalar tensor with:
            - dtype: UINT64
            - shape: () (scalar)

    Note:
        Each party in the current party mask independently produces its own rank value.
    """
    rank_obj = run(None, basic.rank)
    return cast(MPObject, rank_obj)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@builtin_function
def prand(shape: Shape = ()) -> MPObject:
    """Multi-party generate a private random (uint64) tensor with the given shape.

    Each party in the current party mask draws its own local random values.
    Those values are private to that party and are never shared with, or
    revealed to, the other parties.

    Args:
        shape: Shape of the random tensor to generate, given as a tuple of
            positive integers. Defaults to ``()``, which produces a scalar.

    Returns:
        MPObject: A variable representing the per-party random tensor with:
            - dtype: UINT64
            - shape: As specified by ``shape``

    Note:
        The randomness is local to each party; nothing about it is exchanged
        between parties.
    """
    random_obj = run(None, basic.prand, shape)
    return cast(MPObject, random_obj)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def constant(data: TensorLike | ScalarType | TableLike) -> MPObject:
    """Create a constant tensor or table from data.

    The value is embedded directly into the computation graph, so it is
    available to every party in the current party mask.

    Args:
        data: The constant data to embed. Can be:
            - A scalar value (int, float, bool)
            - A numpy array or other tensor-like object
            - A pandas DataFrame or other table-like object
            - Any object that can be converted to tensor

    Returns:
        MPObject: A variable representing the constant tensor or table with:
            - dtype: Inferred from the input data
            - shape: Inferred from the input data (for tensors)
            - schema: Inferred from the input data (for tables)
            - data: The embedded constant values

    Note:
        Because the data is baked in at graph construction time, large
        constants grow the graph accordingly.

        Table-like objects (e.g., pandas DataFrame) are serialized as JSON.
        The constant primitive is not meant to carry large tables efficiently;
        prefer a dedicated table loading mechanism for substantial datasets.
    """
    const_obj = run(None, basic.constant, data)
    return cast(MPObject, const_obj)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
@builtin_function
def debug_print(obj: MPObject, prefix: str = "") -> MPObject:
    """Print the local value of ``obj`` on its owning parties and pass it through.

    At runtime, each party that holds ``obj`` prints its local value. The same
    MPObject is then returned unchanged, so multi-party computations can be
    debugged without altering their data flow.

    Args:
        obj: The MPObject whose value should be printed.
        prefix: Optional text prefix for the printed output. Defaults to "".

    Returns:
        MPObject: The value passed in, unchanged. This allows chaining, as in
        ``x = debug_print(x, "x=")``. Using the returned value also keeps dead
        code elimination (DCE) from dropping the print.

    Note:
        Printing happens at runtime on each party holding the value. With a
        static pmask, only parties in that mask print. With a dynamic pmask,
        the printing parties are determined at runtime.
    """
    pfunc, flat_args, tree = basic.debug_print(obj, prefix=prefix)
    flat_outputs = peval(pfunc, flat_args)
    passthrough = tree.unflatten(flat_outputs)
    return cast(MPObject, passthrough)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def set_mask(arg: MPObject, mask: Mask) -> MPObject:
    """Set the mask of an MPObject to a new value.

    This function allows changing the party mask of an existing MPObject variable.
    The behavior depends on whether the input MPObject has a dynamic or static pmask:

    **Case 1: Dynamic pmask (arg.pmask is None)**
    - The input MPObject has a runtime-determined pmask
    - The return value's pmask will be exactly the specified mask
    - No validation is performed at compile time

    **Case 2: Static pmask (arg.pmask is not None)**
    - If mask is a subset of arg.pmask: the result is restricted to ``mask``
      (see Example 2)
    - If mask is NOT a subset of arg.pmask: raises ValueError at compile time

    NOTE(review): the result pmask is whatever ``peval`` deduces when given
    ``mask`` as its requested mask. Example 2 shows it equal to ``mask`` for a
    static input; confirm this against ``peval`` before relying on it.

    Args:
        arg: The MPObject whose mask needs to be changed.
        mask: The target mask to apply. Must be a valid party mask.

    Returns:
        MPObject: A new variable with the specified mask behavior:
            - For dynamic inputs: pmask = mask
            - For static inputs (valid subset): pmask = mask, per Example 2
              (see the review note above)

    Raises:
        ValueError: When arg has a static pmask and mask is not a subset of arg.pmask.
            This validation occurs at compile time during graph construction.

    Examples:
        **Example 1: Dynamic pmask - mask assignment**
                    P0    P1    P2
                    --    --    --
        Input:      ?     ?     ?     (pmask=None, runtime-determined)
        mask:       [0,2]             (target mask)
        -----------------------------------------------------------
        Output:     x0    -     x2    (pmask=[0,2])

        **Example 2: Static pmask - valid subset**
                    P0    P1    P2
                    --    --    --
        Input:      x0    x1    x2    (pmask=[0,1,2])
        mask:       [0,2]             (subset of input pmask)
        -----------------------------------------------------------
        Output:     x0    -     x2    (pmask=[0,2])

        **Example 3: Static pmask - invalid subset (compile error)**
                    P0    P1    P2
                    --    --    --
        Input:      x0    -     x2    (pmask=[0,2])
        mask:       [1,2]             (NOT subset of [0,2])
        -----------------------------------------------------------
        Result:     ValueError at compile time

    Note:
        This function is typically used for constraining the execution scope
        of variables or for type casting between different pmask contexts.
        It traces ``basic.identity`` on ``arg`` (not a JAX function) and
        evaluates the result with ``peval``, passing ``mask`` as the requested
        result mask.
    """
    # Identity op: the value is unchanged; only the mask given to peval matters.
    pfunc, eval_args, out_tree = basic.identity(arg)
    results = peval(pfunc, eval_args, mask)
    return cast(MPObject, out_tree.unflatten(results))
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def run_jax(jax_fn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Run a JAX function.

    Args:
        jax_fn: The JAX function to be executed.
        *args: Positional arguments to pass to the JAX function.
        **kwargs: Keyword arguments to pass to the JAX function.

    Returns:
        The result of evaluating the JAX function through the mplang system.

    Raises:
        Any exception raised while tracing/compiling ``jax_fn`` (via
        ``jax_cc.run_jax``) or during ``peval`` propagates unchanged.
        NOTE(review): earlier docs listed TypeError for compilation/evaluation
        failures and RuntimeError for peval errors; confirm the exact types.

    Notes:
        Argument binding semantics with respect to JAX static arguments:

        - If an argument (or any leaf within a PyTree argument) is an
          :class:`~mplang.v1.core.mpobject.MPObject`, it is captured as a runtime
          variable (dynamic value) in the traced program and is not treated as a
          JAX static argument.
        - If an argument contains no :class:`MPObject` leaves, it is treated as a
          constant configuration with respect to JAX; effectively it behaves
          like a static argument and may contribute to JAX compilation cache
          keys (similar to ``static_argnums`` semantics). Changing such constant
          arguments can lead to different compiled variants/cached entries.

    Examples:
        Defining and running a simple JAX function:

        >>> import jax.numpy as jnp
        >>> def add_matrices(a, b):
        ...     return jnp.add(a, b)
        >>> result = run_jax(add_matrices, matrix_a, matrix_b)

        Running a more complex JAX function:

        >>> def compute_statistics(data):
        ...     mean = jnp.mean(data)
        ...     std = jnp.std(data)
        ...     return {"mean": mean, "std": std}
        >>> stats = run_jax(compute_statistics, dataset)
    """
    # pmask=None: the parties are not fixed here and are left to peval.
    return run(None, jax_cc.run_jax, jax_fn, *args, **kwargs)
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def run_jax_at(rank: Rank, jax_fn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Run a JAX function on a single party.

    This is :func:`run_jax` restricted to the party mask containing only ``rank``.

    Args:
        rank: The party that executes ``jax_fn``.
        jax_fn: The JAX function to be executed.
        *args: Positional arguments to pass to the JAX function.
        **kwargs: Keyword arguments to pass to the JAX function.

    Returns:
        The result of evaluating ``jax_fn`` on party ``rank``.
    """
    only_rank = Mask.from_ranks(rank)
    return run(only_rank, jax_cc.run_jax, jax_fn, *args, **kwargs)
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def run_sql(
    query: str, out_type: Any, in_tables: dict[str, MPObject] | None = None
) -> Any:
    """Run a SQL query through the ``sql_cc.run_sql_raw`` frontend operation.

    Args:
        query: The SQL query text.
        out_type: Declared type of the query result, forwarded as-is to
            ``run_sql_raw``. NOTE(review): presumably a table type describing
            the output schema; confirm against ``sql_cc``.
        in_tables: Optional mapping from a table name (presumably as referenced
            in ``query``) to the MPObject supplying that table.

    Returns:
        The result of :func:`run` for the query. The party mask is left
        unspecified (``None``).
    """
    # TODO(jint): drop out_type.
    sql_op = sql_cc.run_sql_raw
    return run(None, sql_op, query, out_type, in_tables)
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def run_sql_at(
    rank: Rank, query: str, out_type: Any, in_tables: dict[str, MPObject] | None = None
) -> Any:
    """Run a SQL query on a single party.

    Same as :func:`run_sql`, but restricted to the party mask containing only
    ``rank``.

    Args:
        rank: The party that executes the query.
        query: The SQL query text.
        out_type: Declared type of the query result, forwarded as-is to
            ``sql_cc.run_sql_raw``.
        in_tables: Optional mapping from table name to the MPObject supplying it.

    Returns:
        The result of running the query on party ``rank``.
    """
    only_rank = Mask.from_ranks(rank)
    return run(only_rank, sql_cc.run_sql_raw, query, out_type, in_tables)
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def run_nnx(nnx_fn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Run an NNX function.

    Args:
        nnx_fn: The NNX function to be executed.
        *args: Positional arguments to pass to the NNX function.
        **kwargs: Keyword arguments to pass to the NNX function.

    Returns:
        The result of evaluating the NNX function through the mplang system.

    Raises:
        Any exception raised while tracing/compiling ``nnx_fn`` (via
        ``nnx_cc.run_nnx``) or during ``peval`` propagates unchanged.
        NOTE(review): earlier docs listed TypeError for compilation/evaluation
        failures and RuntimeError for peval errors; confirm the exact types.

    Notes:
        Argument binding semantics with respect to NNX static arguments:

        - If an argument (or any leaf within a PyTree argument) is an
          :class:`~mplang.v1.core.mpobject.MPObject`, it is captured as a runtime
          variable (dynamic value) in the traced program and is not treated as a
          NNX static argument.
        - If an argument contains no :class:`MPObject` leaves, it is treated as a
          constant configuration with respect to NNX; effectively it behaves
          like a static argument and may contribute to NNX compilation cache
          keys (similar to ``static_argnums`` semantics). Changing such constant
          arguments can lead to different compiled variants/cached entries.

    Examples:
        Defining and running a simple NNX function:

        >>> from flax import nnx
        >>> import jax.numpy as jnp
        >>> def nnx_linear(inputs, weights, bias):
        ...     return jnp.dot(inputs, weights) + bias
        >>> result = run_nnx(nnx_linear, inputs, weights, bias)

        Running an NNX model:

        >>> class LinearModel(nnx.Module):
        ...     def __init__(self, features: int, rngs: nnx.Rngs):
        ...         self.linear = nnx.Linear(features, features, rngs=rngs)
        ...
        ...     def __call__(self, x):
        ...         return self.linear(x)
        >>> def forward_pass(model, x):
        ...     return model(x)
        >>> output = run_nnx(forward_pass, model, input_data)
    """
    # pmask=None: the parties are not fixed here and are left to peval.
    return run(None, nnx_cc.run_nnx, nnx_fn, *args, **kwargs)
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def run_nnx_at(rank: Rank, nnx_fn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Run an NNX function on a single party.

    This is :func:`run_nnx` restricted to the party mask containing only ``rank``.

    Args:
        rank: The rank where the NNX function should be executed.
        nnx_fn: The NNX function to be executed.
        *args: Positional arguments to pass to the NNX function.
        **kwargs: Keyword arguments to pass to the NNX function.

    Returns:
        The result of evaluating the NNX function at the specified rank.
    """
    only_rank = Mask.from_ranks(rank)
    return run(only_rank, nnx_cc.run_nnx, nnx_fn, *args, **kwargs)
|
mplang/{simp → v1/simp}/mpi.py
RENAMED
|
@@ -16,8 +16,7 @@ from __future__ import annotations
|
|
|
16
16
|
|
|
17
17
|
import logging
|
|
18
18
|
|
|
19
|
-
|
|
20
|
-
from mplang.core import Mask, MPObject, Rank, function
|
|
19
|
+
from mplang.v1.core import Mask, MPObject, Rank, function, pconv, pshfl_s
|
|
21
20
|
|
|
22
21
|
|
|
23
22
|
# scatter :: [m a] -> m Rank -> m a
|
|
@@ -43,11 +42,11 @@ def scatter_m(to_mask: Mask, root: Rank, args: list[MPObject]) -> MPObject:
|
|
|
43
42
|
raise ValueError(f"Expect {len(to_ranks)} args, got {len(args)}. ")
|
|
44
43
|
|
|
45
44
|
scattered = [
|
|
46
|
-
|
|
45
|
+
pshfl_s(arg, Mask.from_ranks(to_rank), [root])
|
|
47
46
|
for to_rank, arg in zip(to_ranks, args, strict=False)
|
|
48
47
|
]
|
|
49
48
|
|
|
50
|
-
result =
|
|
49
|
+
result = pconv(scattered)
|
|
51
50
|
assert result.pmask == to_mask, (result.pmask, to_mask)
|
|
52
51
|
return result # type: ignore[no-any-return]
|
|
53
52
|
|
|
@@ -58,9 +57,9 @@ def gather_m(src_mask: Mask, root: Rank, arg: MPObject) -> list[MPObject]:
|
|
|
58
57
|
"""Gather the object from pmask'ed parties to the root party.
|
|
59
58
|
|
|
60
59
|
Args:
|
|
61
|
-
|
|
60
|
+
src_mask: The mask of the parties that will gather the object.
|
|
62
61
|
root: The rank of the root party.
|
|
63
|
-
arg: The object to be gathered
|
|
62
|
+
arg: The object to be gathered. It must be held by all parties specified in `src_mask`.
|
|
64
63
|
|
|
65
64
|
Returns:
|
|
66
65
|
A list of objects, with length equal to the number of parties in pmask.
|
|
@@ -76,7 +75,7 @@ def gather_m(src_mask: Mask, root: Rank, arg: MPObject) -> list[MPObject]:
|
|
|
76
75
|
root_mask = Mask.from_ranks(root)
|
|
77
76
|
for src_rank in Mask(src_mask):
|
|
78
77
|
# Shuffle data from src_rank to root
|
|
79
|
-
gathered_data =
|
|
78
|
+
gathered_data = pshfl_s(arg, root_mask, [src_rank])
|
|
80
79
|
result.append(gathered_data)
|
|
81
80
|
|
|
82
81
|
assert len(result) == Mask(src_mask).num_parties(), (result, src_mask)
|
|
@@ -93,7 +92,7 @@ def bcast_m(pmask: Mask, root: Rank, obj: MPObject) -> MPObject:
|
|
|
93
92
|
if not Mask.from_ranks(root).is_subset(obj.pmask):
|
|
94
93
|
raise ValueError(f"Expect root {root} in obj mask {obj.pmask}.")
|
|
95
94
|
|
|
96
|
-
result =
|
|
95
|
+
result = pshfl_s(obj, pmask, [root] * Mask(pmask).num_parties())
|
|
97
96
|
|
|
98
97
|
assert result.pmask == pmask, (result.pmask, pmask)
|
|
99
98
|
return result # type: ignore[no-any-return]
|
|
@@ -114,7 +113,7 @@ def p2p(frm: Rank, to: Rank, obj: MPObject) -> MPObject:
|
|
|
114
113
|
if frm == to:
|
|
115
114
|
return obj
|
|
116
115
|
|
|
117
|
-
return
|
|
116
|
+
return pshfl_s(obj, Mask.from_ranks(to), [frm]) # type: ignore[no-any-return]
|
|
118
117
|
|
|
119
118
|
|
|
120
119
|
# allgather :: m a -> [m a]
|
|
@@ -18,144 +18,13 @@ import importlib
|
|
|
18
18
|
import pathlib
|
|
19
19
|
import pkgutil
|
|
20
20
|
from collections.abc import Callable
|
|
21
|
-
from functools import
|
|
21
|
+
from functools import wraps
|
|
22
22
|
from types import ModuleType
|
|
23
23
|
from typing import Any
|
|
24
24
|
|
|
25
|
-
from mplang.
|
|
26
|
-
from mplang.
|
|
27
|
-
from mplang.
|
|
28
|
-
from mplang.core.primitive import (
|
|
29
|
-
constant,
|
|
30
|
-
pconv,
|
|
31
|
-
peval,
|
|
32
|
-
prand,
|
|
33
|
-
prank,
|
|
34
|
-
pshfl,
|
|
35
|
-
pshfl_s,
|
|
36
|
-
uniform_cond,
|
|
37
|
-
while_loop,
|
|
38
|
-
)
|
|
39
|
-
from mplang.ops import ibis_cc, jax_cc
|
|
40
|
-
from mplang.ops.base import FeOperation
|
|
41
|
-
from mplang.simp.mpi import allgather_m, bcast_m, gather_m, p2p, scatter_m
|
|
42
|
-
from mplang.simp.random import key_split, pperm, prandint, ukey, urandint
|
|
43
|
-
from mplang.simp.smpc import reveal, revealTo, seal, sealFrom, srun
|
|
44
|
-
|
|
45
|
-
# Public exports of the simplified party execution API.
|
|
46
|
-
# NOTE: Replaces previous internal __reexport__ (not a Python convention)
|
|
47
|
-
# to make star-imports explicit and tooling-friendly.
|
|
48
|
-
__all__ = [ # noqa: RUF022
|
|
49
|
-
"MPObject",
|
|
50
|
-
"P",
|
|
51
|
-
"P0",
|
|
52
|
-
"P1",
|
|
53
|
-
"P2",
|
|
54
|
-
"P2P",
|
|
55
|
-
"Party",
|
|
56
|
-
"allgather_m",
|
|
57
|
-
"bcast_m",
|
|
58
|
-
"constant",
|
|
59
|
-
"gather_m",
|
|
60
|
-
"key_split",
|
|
61
|
-
"load_module",
|
|
62
|
-
"p2p",
|
|
63
|
-
"pconv",
|
|
64
|
-
"peval",
|
|
65
|
-
"pperm",
|
|
66
|
-
"prand",
|
|
67
|
-
"prandint",
|
|
68
|
-
"prank",
|
|
69
|
-
"pshfl",
|
|
70
|
-
"pshfl_s",
|
|
71
|
-
"reveal",
|
|
72
|
-
"revealTo",
|
|
73
|
-
"run",
|
|
74
|
-
"runAt",
|
|
75
|
-
"scatter_m",
|
|
76
|
-
"seal",
|
|
77
|
-
"sealFrom",
|
|
78
|
-
"srun",
|
|
79
|
-
"ukey",
|
|
80
|
-
"uniform_cond",
|
|
81
|
-
"urandint",
|
|
82
|
-
"while_loop",
|
|
83
|
-
]
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
def run_impl(
|
|
87
|
-
pmask: Mask | None,
|
|
88
|
-
func: Callable,
|
|
89
|
-
*args: Any,
|
|
90
|
-
**kwargs: Any,
|
|
91
|
-
) -> Any:
|
|
92
|
-
"""
|
|
93
|
-
Run a function that can be evaluated by the mplang system.
|
|
94
|
-
|
|
95
|
-
This function provides a dispatch mechanism based on the first argument
|
|
96
|
-
to route different function types to appropriate handlers.
|
|
97
|
-
|
|
98
|
-
Args:
|
|
99
|
-
pmask: The party mask of this function, None indicates auto deduce parties.
|
|
100
|
-
func: The function to be dispatched and executed
|
|
101
|
-
*args: Positional arguments to pass to the function
|
|
102
|
-
**kwargs: Keyword arguments to pass to the function
|
|
103
|
-
|
|
104
|
-
Returns:
|
|
105
|
-
The result of evaluating the function through the appropriate handler
|
|
106
|
-
|
|
107
|
-
Raises:
|
|
108
|
-
ValueError: If builtin.write is called without required arguments
|
|
109
|
-
TypeError: If the function compilation or evaluation fails
|
|
110
|
-
RuntimeError: If the underlying peval execution encounters errors
|
|
111
|
-
|
|
112
|
-
Examples:
|
|
113
|
-
Reading data from a file:
|
|
114
|
-
|
|
115
|
-
>>> tensor_info = TensorType(shape=(10, 10), dtype=np.float32)
|
|
116
|
-
>>> attrs = {"format": "binary"}
|
|
117
|
-
>>> result = run_impl(builtin.read, "data/input.bin", tensor_info, attrs)
|
|
118
|
-
|
|
119
|
-
Writing data to a file:
|
|
120
|
-
|
|
121
|
-
>>> run_impl(builtin.write, data, "data/output.bin")
|
|
122
|
-
|
|
123
|
-
Running a JAX function:
|
|
124
|
-
|
|
125
|
-
>>> def matrix_multiply(a, b):
|
|
126
|
-
... return jnp.dot(a, b)
|
|
127
|
-
>>> result = run_impl(matrix_multiply, mat_a, mat_b)
|
|
128
|
-
|
|
129
|
-
Running a custom computation function:
|
|
130
|
-
|
|
131
|
-
>>> def compute_statistics(data):
|
|
132
|
-
... mean = jnp.mean(data)
|
|
133
|
-
... std = jnp.std(data)
|
|
134
|
-
... return {"mean": mean, "std": std}
|
|
135
|
-
>>> stats = run_impl(compute_statistics, dataset)
|
|
136
|
-
"""
|
|
137
|
-
|
|
138
|
-
if isinstance(func, FeOperation):
|
|
139
|
-
pfunc, eval_args, out_tree = func(*args, **kwargs)
|
|
140
|
-
else:
|
|
141
|
-
if ibis_cc.is_ibis_function(func):
|
|
142
|
-
pfunc, eval_args, out_tree = ibis_cc.ibis_compile(func, *args, **kwargs)
|
|
143
|
-
else:
|
|
144
|
-
# unknown python callable, treat it as jax function
|
|
145
|
-
pfunc, eval_args, out_tree = jax_cc.jax_compile(func, *args, **kwargs)
|
|
146
|
-
results = peval(pfunc, eval_args, pmask)
|
|
147
|
-
return out_tree.unflatten(results)
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
# run :: (a -> a) -> m a -> m a
|
|
151
|
-
def run(pyfn: Callable) -> Callable:
|
|
152
|
-
return partial(run_impl, None, pyfn)
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
# runAt :: Rank -> (a -> a) -> m a -> m a
|
|
156
|
-
def runAt(rank: Rank, pyfn: Callable) -> Callable:
|
|
157
|
-
pmask = Mask.from_ranks(rank)
|
|
158
|
-
return partial(run_impl, pmask, pyfn)
|
|
25
|
+
from mplang.v1.ops.base import FeOperation
|
|
26
|
+
from mplang.v1.simp.api import run_at, run_jax_at
|
|
27
|
+
from mplang.v1.simp.mpi import p2p
|
|
159
28
|
|
|
160
29
|
|
|
161
30
|
def P2P(src: Party, dst: Party, value: Any) -> Any:
|
|
@@ -229,22 +98,22 @@ class _PartyModuleProxy:
|
|
|
229
98
|
|
|
230
99
|
def __getattr__(self, item: str) -> Callable[..., Any]:
|
|
231
100
|
self._ensure()
|
|
232
|
-
|
|
233
|
-
if not callable(
|
|
101
|
+
op = getattr(self._module, item)
|
|
102
|
+
if not callable(op):
|
|
234
103
|
raise AttributeError(
|
|
235
|
-
f"Attribute '{item}' of party module '{self._name}' is not callable (got {type(
|
|
104
|
+
f"Attribute '{item}' of party module '{self._name}' is not callable (got {type(op).__name__})"
|
|
236
105
|
)
|
|
237
106
|
|
|
238
|
-
@wraps(
|
|
107
|
+
@wraps(op)
|
|
239
108
|
def _wrapped(*args: Any, **kw: Any) -> Any:
|
|
240
109
|
# Inline runAt to reduce an extra partial layer while preserving semantics.
|
|
241
|
-
return
|
|
110
|
+
return run_at(self._party.rank, op, *args, **kw)
|
|
242
111
|
|
|
243
112
|
# Provide a party-qualified name for debugging / logs without losing original metadata.
|
|
244
|
-
base_name = getattr(
|
|
113
|
+
base_name = getattr(op, "__name__", None)
|
|
245
114
|
if base_name is None:
|
|
246
115
|
# Frontend FeOperation or object without __name__; try .name attribute (FeOperation contract) or fallback to repr
|
|
247
|
-
base_name = getattr(
|
|
116
|
+
base_name = getattr(op, "name", None) or type(op).__name__
|
|
248
117
|
try:
|
|
249
118
|
_wrapped.__name__ = f"{base_name}@P{self._party.rank}"
|
|
250
119
|
except Exception: # pragma: no cover - assignment may fail for exotic wrappers
|
|
@@ -264,7 +133,12 @@ class Party:
|
|
|
264
133
|
raise TypeError(
|
|
265
134
|
f"First argument to Party({self.rank}) must be callable, got {fn!r}"
|
|
266
135
|
)
|
|
267
|
-
|
|
136
|
+
# Use run_op_at for FeOperation, run_jax_at for plain callables
|
|
137
|
+
if isinstance(fn, FeOperation):
|
|
138
|
+
return run_at(self.rank, fn, *args, **kwargs)
|
|
139
|
+
else:
|
|
140
|
+
# TODO(jint): implicitly assume non-FeOperation as JAX function is a bit too magical?
|
|
141
|
+
return run_jax_at(self.rank, fn, *args, **kwargs)
|
|
268
142
|
|
|
269
143
|
def __getattr__(self, name: str) -> _PartyModuleProxy:
|
|
270
144
|
if name in _NAMESPACE_REGISTRY:
|
|
@@ -289,7 +163,7 @@ def _load_prelude_modules() -> None:
|
|
|
289
163
|
unwieldy we can switch to an allowlist.
|
|
290
164
|
"""
|
|
291
165
|
try:
|
|
292
|
-
import mplang.ops as _fe # type: ignore
|
|
166
|
+
import mplang.v1.ops as _fe # type: ignore
|
|
293
167
|
except (ImportError, ModuleNotFoundError): # pragma: no cover
|
|
294
168
|
# Frontend package not present (minimal install); safe to skip.
|
|
295
169
|
return
|
|
@@ -299,7 +173,7 @@ def _load_prelude_modules() -> None:
|
|
|
299
173
|
if m.name.startswith("_"):
|
|
300
174
|
continue
|
|
301
175
|
if m.name not in _NAMESPACE_REGISTRY:
|
|
302
|
-
_NAMESPACE_REGISTRY[m.name] = f"mplang.ops.{m.name}"
|
|
176
|
+
_NAMESPACE_REGISTRY[m.name] = f"mplang.v1.ops.{m.name}"
|
|
303
177
|
|
|
304
178
|
|
|
305
179
|
def load_module(module: str, alias: str | None = None) -> None:
|