mplang-nightly 0.1.dev268__py3-none-any.whl → 0.1.dev270__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +46 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +19 -13
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/libs/mpc/psi/rr22.py +303 -0
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang/v2/libs/mpc/psi/rr22.py +0 -344
- mplang_nightly-0.1.dev268.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/simp/api.py
DELETED
|
@@ -1,353 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
from collections.abc import Callable
|
|
18
|
-
from typing import Any, cast
|
|
19
|
-
|
|
20
|
-
from mplang.v1.core import (
|
|
21
|
-
Mask,
|
|
22
|
-
MPObject,
|
|
23
|
-
Rank,
|
|
24
|
-
ScalarType,
|
|
25
|
-
Shape,
|
|
26
|
-
TableLike,
|
|
27
|
-
TensorLike,
|
|
28
|
-
builtin_function,
|
|
29
|
-
peval,
|
|
30
|
-
)
|
|
31
|
-
from mplang.v1.ops import basic, jax_cc, nnx_cc, sql_cc
|
|
32
|
-
from mplang.v1.ops.base import FeOperation
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
def run(
    pmask: Mask | None,
    fe_op: FeOperation,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Trace ``fe_op`` on the given arguments and evaluate it under ``pmask``.

    Args:
        pmask: Party mask forwarded to ``peval``. How ``None`` is interpreted
            is decided by ``peval``.
        fe_op: Frontend operation. Calling it yields a
            ``(pfunc, eval_args, out_tree)`` triple.
        *args: Positional arguments forwarded to ``fe_op``.
        **kwargs: Keyword arguments forwarded to ``fe_op``.

    Returns:
        The ``peval`` results reassembled into the structure of ``out_tree``.
    """
    traced = fe_op(*args, **kwargs)
    pfunc, eval_args, out_tree = traced
    evaluated = peval(pfunc, eval_args, pmask)
    return out_tree.unflatten(evaluated)
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
def run_at(rank: Rank, op: Any, *args: Any, **kwargs: Any) -> Any:
    """Evaluate ``op`` on the single party identified by ``rank``.

    This is :func:`run` with a one-party mask built via ``Mask.from_ranks``.
    """
    single_party = Mask.from_ranks(rank)
    return run(single_party, op, *args, **kwargs)
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
@builtin_function
def prank() -> MPObject:
    """Return, on every participating party, that party's own rank.

    ``basic.rank`` is evaluated with no explicit party mask, so each party in
    the current party mask independently produces its own identifier. Per the
    original documentation, ranks range over ``[0, world_size)`` and each
    party's value is private to that party.

    Returns:
        MPObject: A scalar tensor (shape ``()``, dtype UINT64) holding the
        rank of each party.
    """
    rank_obj = run(None, basic.rank)
    return cast(MPObject, rank_obj)
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
@builtin_function
def prand(shape: Shape = ()) -> MPObject:
    """Generate a private uint64 random tensor on every participating party.

    ``basic.prand`` is evaluated with no explicit party mask. Each party
    generates its own local random values, which are not shared with or
    revealed to other parties.

    Args:
        shape: Shape of the tensor to generate, as a tuple of positive
            integers. The default ``()`` yields a scalar.

    Returns:
        MPObject: A private random tensor with dtype UINT64 and the requested
        ``shape``.
    """
    random_obj = run(None, basic.prand, shape)
    return cast(MPObject, random_obj)
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
def constant(data: TensorLike | ScalarType | TableLike) -> MPObject:
    """Embed ``data`` as a constant tensor or table in the computation graph.

    The constant is created via ``basic.constant`` with no explicit party
    mask, so it is available to all parties in the current party mask.

    Args:
        data: The value to embed. This may be a Python scalar (int, float,
            bool), a numpy array or other tensor-like object, or a pandas
            DataFrame or other table-like object.

    Returns:
        MPObject: A constant whose dtype, and shape (tensors) or schema
        (tables), are inferred from ``data``.

    Note:
        The value is captured at graph-construction time, so large constants
        enlarge the graph. According to the original documentation, tables
        are serialized as JSON. Use a dedicated table-loading mechanism for
        substantial datasets.
    """
    const_obj = run(None, basic.constant, data)
    return cast(MPObject, const_obj)
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
@builtin_function
def debug_print(obj: MPObject, prefix: str = "") -> MPObject:
    """Print ``obj``'s local value at runtime and pass it through unchanged.

    Args:
        obj: The MPObject to print.
        prefix: Text printed before the value. Defaults to the empty string.

    Returns:
        MPObject: The value produced by the print operation, which stands in
        for ``obj``. Chain it (``x = debug_print(x, "x=")``) so the print is
        not removed by dead-code elimination.

    Note:
        Printing happens on the parties that hold the value. With a static
        pmask those are known at trace time; with a dynamic pmask they are
        resolved at runtime. ``peval`` is called here without an explicit
        mask.
    """
    traced_print = basic.debug_print(obj, prefix=prefix)
    pfunc, eval_args, out_tree = traced_print
    printed = peval(pfunc, eval_args)
    return cast(MPObject, out_tree.unflatten(printed))
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
def set_mask(arg: MPObject, mask: Mask) -> MPObject:
    """Re-evaluate ``arg`` through an identity operation restricted to ``mask``.

    The result's party mask is decided by ``peval`` running ``basic.identity``
    under ``mask``:

    * Dynamic input (``arg.pmask is None``): the result's pmask is ``mask``.
      No compile-time validation is possible.
    * Static input (``arg.pmask is not None``): ``mask`` must be a subset of
      ``arg.pmask``; otherwise a ``ValueError`` is raised at trace time.

      NOTE(review): the previous documentation contradicted itself here. The
      prose said the result keeps ``arg.pmask``, while its example showed the
      result narrowed to ``mask``. The actual pmask is whatever ``peval``
      assigns; confirm against ``peval``.

    Args:
        arg: The MPObject to re-mask.
        mask: The target party mask.

    Returns:
        MPObject: The identity of ``arg`` evaluated under ``mask``.

    Raises:
        ValueError: When ``arg`` has a static pmask and ``mask`` is not a
            subset of it. This is detected during graph construction.

    Examples:
        Dynamic pmask::

                     P0   P1   P2
            Input:   ?    ?    ?     (pmask=None, runtime-determined)
            mask:    [0,2]
            Output:  x0   -    x2    (pmask=[0,2])

        Static pmask, invalid subset::

                     P0   P1   P2
            Input:   x0   -    x2    (pmask=[0,2])
            mask:    [1,2]           (not a subset of [0,2])
            Result:  ValueError at trace time

    Note:
        Typically used to constrain the execution scope of a variable or to
        convert between pmask contexts.
    """
    identity_pfunc, identity_args, identity_tree = basic.identity(arg)
    masked = peval(identity_pfunc, identity_args, mask)
    return cast(MPObject, identity_tree.unflatten(masked))
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
def run_jax(jax_fn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Trace ``jax_fn`` with ``jax_cc.run_jax`` and evaluate it via :func:`run`.

    Args:
        jax_fn: The JAX function to execute.
        *args: Positional arguments for ``jax_fn``.
        **kwargs: Keyword arguments for ``jax_fn``.

    Returns:
        The result of ``jax_fn``, evaluated by mplang.

    Raises:
        Errors raised while tracing or evaluating propagate unchanged. The
        original documentation lists ``TypeError`` (compilation/evaluation
        failure) and ``RuntimeError`` (``peval`` failure).

    Notes:
        How arguments bind with respect to JAX static arguments:

        - Any argument, or PyTree leaf, that is a
          :class:`~mplang.v1.core.mpobject.MPObject` becomes a runtime
          (dynamic) variable of the traced program.
        - Arguments without MPObject leaves act like JAX static arguments
          (cf. ``static_argnums``). Changing them can produce a different
          compiled variant or cache entry.

    Examples:
        >>> import jax.numpy as jnp
        >>> def add_matrices(a, b):
        ...     return jnp.add(a, b)
        >>> result = run_jax(add_matrices, matrix_a, matrix_b)

        >>> def compute_statistics(data):
        ...     return {"mean": jnp.mean(data), "std": jnp.std(data)}
        >>> stats = run_jax(compute_statistics, dataset)
    """
    frontend = jax_cc.run_jax
    return run(None, frontend, jax_fn, *args, **kwargs)
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
def run_jax_at(rank: Rank, jax_fn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Run ``jax_fn`` on the single party identified by ``rank``.

    Rank-restricted counterpart of :func:`run_jax`. The arguments are
    forwarded unchanged to ``jax_cc.run_jax``.
    """
    frontend = jax_cc.run_jax
    return run_at(rank, frontend, jax_fn, *args, **kwargs)
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
def run_sql(
    query: str, out_type: Any, in_tables: dict[str, MPObject] | None = None
) -> Any:
    """Run a SQL ``query`` through ``sql_cc.run_sql_raw`` via :func:`run`.

    Args:
        query: SQL text to execute.
        out_type: Expected output type, forwarded as-is to ``run_sql_raw``.
        in_tables: Optional mapping of input tables. The keys are presumably
            the table names referenced in ``query``; confirm against
            ``sql_cc``.

    Returns:
        The evaluated query result.
    """
    # TODO(jint): drop out_type.
    frontend = sql_cc.run_sql_raw
    return run(None, frontend, query, out_type, in_tables)
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
def run_sql_at(
    rank: Rank, query: str, out_type: Any, in_tables: dict[str, MPObject] | None = None
) -> Any:
    """Rank-restricted counterpart of :func:`run_sql`.

    Executes ``query`` on the single party identified by ``rank``. The other
    arguments are forwarded unchanged to ``sql_cc.run_sql_raw``.
    """
    frontend = sql_cc.run_sql_raw
    return run_at(rank, frontend, query, out_type, in_tables)
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
def run_nnx(nnx_fn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Trace ``nnx_fn`` with ``nnx_cc.run_nnx`` and evaluate it via :func:`run`.

    Args:
        nnx_fn: The NNX function to execute.
        *args: Positional arguments for ``nnx_fn``.
        **kwargs: Keyword arguments for ``nnx_fn``.

    Returns:
        The result of ``nnx_fn``, evaluated by mplang.

    Raises:
        Errors raised while tracing or evaluating propagate unchanged. The
        original documentation lists ``TypeError`` and ``RuntimeError``.

    Notes:
        How arguments bind with respect to NNX static arguments:

        - Any argument, or PyTree leaf, that is a
          :class:`~mplang.v1.core.mpobject.MPObject` becomes a runtime
          (dynamic) variable of the traced program.
        - Arguments without MPObject leaves act like static arguments
          (cf. ``static_argnums``). Changing them can produce a different
          compiled variant or cache entry.

    Examples:
        >>> from flax import nnx
        >>> import jax.numpy as jnp
        >>> def nnx_linear(inputs, weights, bias):
        ...     return jnp.dot(inputs, weights) + bias
        >>> result = run_nnx(nnx_linear, inputs, weights, bias)

        >>> class LinearModel(nnx.Module):
        ...     def __init__(self, features: int, rngs: nnx.Rngs):
        ...         self.linear = nnx.Linear(features, features, rngs=rngs)
        ...
        ...     def __call__(self, x):
        ...         return self.linear(x)
        >>> def forward_pass(model, x):
        ...     return model(x)
        >>> output = run_nnx(forward_pass, model, input_data)
    """
    frontend = nnx_cc.run_nnx
    return run(None, frontend, nnx_fn, *args, **kwargs)
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
def run_nnx_at(rank: Rank, nnx_fn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Run ``nnx_fn`` on the single party identified by ``rank``.

    Args:
        rank: Rank of the party that executes ``nnx_fn``.
        nnx_fn: The NNX function to execute.
        *args: Positional arguments for ``nnx_fn``.
        **kwargs: Keyword arguments for ``nnx_fn``.

    Returns:
        The result of ``nnx_fn``, evaluated on ``rank``.
    """
    frontend = nnx_cc.run_nnx
    return run_at(rank, frontend, nnx_fn, *args, **kwargs)
|
mplang/v1/simp/mpi.py
DELETED
|
@@ -1,131 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
import logging
|
|
18
|
-
|
|
19
|
-
from mplang.v1.core import Mask, MPObject, Rank, function, pconv, pshfl_s
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
# scatter :: [m a] -> m Rank -> m a
|
|
23
|
-
@function
def scatter_m(to_mask: Mask, root: Rank, args: list[MPObject]) -> MPObject:
    """Send one object per receiving party from ``root``.

    The i-th party of ``to_mask`` (in mask iteration order) receives
    ``args[i]``.

    Args:
        to_mask: Parties that each receive one element of ``args``.
        root: Rank of the party holding every element of ``args``.
        args: Objects held by ``root``, exactly one per party in ``to_mask``.

    Returns:
        An MPObject with pmask ``to_mask`` that carries ``args[i]`` on the
        i-th receiving party.

    Raises:
        ValueError: If a statically-masked element of ``args`` is not held
            by ``root``, or if ``len(args)`` differs from the number of
            parties in ``to_mask``.
    """
    # Every piece must live on root. Dynamic pmasks can only be checked at
    # runtime, so they just produce a warning.
    for piece in args:
        if piece.pmask is None:
            logging.warning(f"Scattering dynamic {piece} from static root {root}")
        elif not Mask.from_ranks(root).is_subset(piece.pmask):
            raise ValueError(f"Expect root {root} in {piece.pmask}, got {piece}.")

    receivers = list(Mask(to_mask))
    if len(receivers) != len(args):
        raise ValueError(f"Expect {len(receivers)} args, got {len(args)}. ")

    # One root -> receiver shuffle per piece. The lengths were checked above.
    pieces = [
        pshfl_s(piece, Mask.from_ranks(receiver), [root])
        for receiver, piece in zip(receivers, args, strict=True)
    ]

    result = pconv(pieces)
    assert result.pmask == to_mask, (result.pmask, to_mask)
    return result  # type: ignore[no-any-return]
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
# gather :: m a -> m Rank -> [m a]
|
|
55
|
-
@function
def gather_m(src_mask: Mask, root: Rank, arg: MPObject) -> list[MPObject]:
    """Collect the copy of ``arg`` held by each party in ``src_mask`` onto ``root``.

    Args:
        src_mask: Parties that contribute their value of ``arg`` (the
            sources).
        root: Rank of the party that receives every contribution.
        arg: Object to gather. It must be held by every party in
            ``src_mask``.

    Returns:
        One MPObject per party in ``src_mask``, in mask iteration order.
        Each is shuffled onto ``root``.

    Raises:
        ValueError: If ``arg`` has a static pmask that does not cover
            ``src_mask``.
    """
    if arg.pmask is None:
        logging.warning(f"Gathering {arg} from {src_mask}, may raise RuntimeError.")
    elif not Mask(src_mask).is_subset(arg.pmask):
        raise ValueError(f"Expect {src_mask} in {arg.pmask}, got {arg}.")

    root_mask = Mask.from_ranks(root)
    # One shuffle per source rank, each delivering that rank's value to root.
    collected = [pshfl_s(arg, root_mask, [src_rank]) for src_rank in Mask(src_mask)]

    assert len(collected) == Mask(src_mask).num_parties(), (collected, src_mask)
    return collected
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
# bcast :: m a -> m Rank -> m a
|
|
86
|
-
@function
def bcast_m(pmask: Mask, root: Rank, obj: MPObject) -> MPObject:
    """Broadcast ``obj`` from ``root`` to every party in ``pmask``.

    Raises:
        ValueError: If ``obj`` has a static pmask that does not contain
            ``root``.
    """
    if obj.pmask is None:
        logging.warning(f"Broadcasting {obj} from {root}, may raise RuntimeError.")
    elif not Mask.from_ranks(root).is_subset(obj.pmask):
        raise ValueError(f"Expect root {root} in obj mask {obj.pmask}.")

    # Every receiving party pulls from the same source: root.
    fan_out = Mask(pmask).num_parties()
    result = pshfl_s(obj, pmask, [root] * fan_out)

    assert result.pmask == pmask, (result.pmask, pmask)
    return result  # type: ignore[no-any-return]
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
# p2p :: m Rank -> m Rank -> m a -> m a
|
|
102
|
-
@function
def p2p(frm: Rank, to: Rank, obj: MPObject) -> MPObject:
    """Send ``obj`` from party ``frm`` to party ``to``.

    If ``frm == to``, ``obj`` is returned unchanged and no shuffle is issued.

    Raises:
        ValueError: If ``obj`` has a static pmask that does not contain
            ``frm``.
    """
    # The sender must hold obj. Dynamic pmasks can only be checked at runtime.
    if obj.pmask is None:
        logging.warning(f"P2P {obj} from {frm} to {to}, may raise RuntimeError.")
    elif not Mask.from_ranks(frm).is_subset(obj.pmask):
        raise ValueError(f"Expect {frm} in {obj.pmask}, got {obj}.")

    if frm != to:
        return pshfl_s(obj, Mask.from_ranks(to), [frm])  # type: ignore[no-any-return]
    return obj
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
# allgather :: m a -> [m a]
|
|
120
|
-
@function
def allgather_m(pmask: Mask, arg: MPObject) -> list[MPObject]:
    """Give every party in ``pmask`` the value of ``arg`` from every party in ``pmask``.

    Implemented as one broadcast per source rank, using the same
    ``pshfl_s(obj, pmask, [src] * num_parties)`` pattern as ``bcast_m``.

    Args:
        pmask: Parties that both contribute and receive.
        arg: Object held by every party in ``pmask``.

    Returns:
        One MPObject per party in ``pmask``, in mask iteration order. The
        i-th entry carries the i-th party's value and is shuffled onto all
        of ``pmask``.

    Raises:
        ValueError: If ``arg`` has a static pmask that does not cover
            ``pmask``.
    """
    if arg.pmask is None:
        logging.warning(f"Allgathering {arg} from {pmask}, may raise RuntimeError.")
    elif not Mask(pmask).is_subset(arg.pmask):
        raise ValueError(f"Expect {pmask} in {arg.pmask}, got {arg}.")

    num_parties = Mask(pmask).num_parties()
    # For each source rank, every party in pmask pulls that rank's value
    # (a broadcast rooted at src_rank).
    result = [
        pshfl_s(arg, pmask, [src_rank] * num_parties) for src_rank in Mask(pmask)
    ]

    assert len(result) == num_parties, (result, pmask)
    return result
|