mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,349 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""SPU (Secure Processing Unit) dialect for the EDSL.
|
|
16
|
+
|
|
17
|
+
This dialect implements an "Encrypted Virtual Machine" model where the SPU
|
|
18
|
+
is treated as a logical device composed of multiple parties. It leverages
|
|
19
|
+
the `simp` dialect for data movement (encryption/decryption) and execution.
|
|
20
|
+
|
|
21
|
+
Concepts:
- SPUConfig: Protocol, field, and fixed-point settings passed to every SPU op.
- make_shares: Generates secret shares on the source party.
- reconstruct: Reconstructs secret from shares on the target party.
- run_jax: Executes JAX computations on the SPU.

Example:
```python
import jax.numpy as jnp
from mplang.v2.dialects import spu, tensor, simp
import mplang.v2.edsl.typing as elt

# 0. Setup
spu_parties = (0, 1, 2)
spu_config = spu.SPUConfig()


# 1. Define computation
def secure_add(x, y):
    return x + y


# 2. Encrypt (Public -> SPU)
# Assume x, y are on party 0.
# Generate shares locally (make_shares is a local op; call it inside a
# simp.pcall region on party 0).
x_shares = spu.make_shares(spu_config, x, count=3)
y_shares = spu.make_shares(spu_config, y, count=3)

# Distribute shares to SPU parties
x_dist = []
y_dist = []
for i, target in enumerate(spu_parties):
    x_dist.append(simp.shuffle_static(x_shares[i], {target: 0}))
    y_dist.append(simp.shuffle_static(y_shares[i], {target: 0}))

# Converge to logical SPU variables
x_enc = simp.converge(*x_dist)
y_enc = simp.converge(*y_dist)

# 3. Execute (SPU -> SPU)
# run_jax is also a local op; call it inside a simp.pcall region over
# spu_parties.
z_enc = spu.run_jax(spu_config, secure_add, x_enc, y_enc)

# 4. Decrypt (SPU -> Public)
# Gather shares to party 0
z_shares = []
for source in spu_parties:
    # Extract share from logical variable
    share = simp.pcall_static((source,), lambda x: x, z_enc)
    # Move to target
    z_shares.append(simp.shuffle_static(share, {0: source}))

# Reconstruct
z = spu.reconstruct(spu_config, tuple(z_shares))
|
|
73
|
+
```
|
|
74
|
+
"""
|
|
75
|
+
|
|
76
|
+
from __future__ import annotations
|
|
77
|
+
|
|
78
|
+
from collections.abc import Callable
|
|
79
|
+
from dataclasses import dataclass
|
|
80
|
+
from typing import Any, ClassVar, Literal, cast
|
|
81
|
+
|
|
82
|
+
import spu.utils.frontend as spu_fe
|
|
83
|
+
from jax import ShapeDtypeStruct
|
|
84
|
+
from jax.tree_util import tree_flatten, tree_unflatten
|
|
85
|
+
|
|
86
|
+
import mplang.v2.edsl as el
|
|
87
|
+
import mplang.v2.edsl.typing as elt
|
|
88
|
+
from mplang.v1.utils.func_utils import normalize_fn
|
|
89
|
+
from mplang.v2.dialects import dtypes
|
|
90
|
+
from mplang.v2.edsl import serde
|
|
91
|
+
|
|
92
|
+
# ==============================================================================
|
|
93
|
+
# --- Configuration
|
|
94
|
+
# ==============================================================================
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@serde.register_class
@dataclass(frozen=True)
class SPUConfig:
    """Immutable SPU runtime settings (a subset of libspu.RuntimeConfig).

    Attributes:
        protocol: Name of the SPU protocol, e.g. "SEMI2K" or "ABY3".
        field: Name of the SPU field type, e.g. "FM64" or "FM128".
        fxp_fraction_bits: Number of fraction bits for fixed-point values.
    """

    protocol: str = "SEMI2K"
    field: str = "FM128"
    fxp_fraction_bits: int = 18

    # Identifier under which the serde registry knows this class.
    _serde_kind: ClassVar[str] = "spu.SPUConfig"

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> SPUConfig:
        """Build a config from a possibly partial dict.

        Missing keys fall back to the defaults below; unknown keys are ignored.
        """
        defaults: dict[str, Any] = {
            "protocol": "SEMI2K",
            "field": "FM128",
            "fxp_fraction_bits": 18,
        }
        return cls(**{key: d.get(key, fallback) for key, fallback in defaults.items()})

    def to_json(self) -> dict[str, Any]:
        """Serialize all three settings into a plain dict."""
        payload: dict[str, Any] = {}
        payload["protocol"] = self.protocol
        payload["field"] = self.field
        payload["fxp_fraction_bits"] = self.fxp_fraction_bits
        return payload

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> SPUConfig:
        """Inverse of ``to_json``; raises KeyError if any setting is missing."""
        required = ("protocol", "field", "fxp_fraction_bits")
        return cls(**{key: data[key] for key in required})
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
# ==============================================================================
|
|
140
|
+
# --- Primitives (Local Operations)
|
|
141
|
+
# ==============================================================================
|
|
142
|
+
|
|
143
|
+
# These primitives operate locally on a single party.
|
|
144
|
+
# They are used inside simp.pcall to construct the distributed protocols.
|
|
145
|
+
|
|
146
|
+
# Primitive handles. Each one gets its type rule from the matching
# ``def_abstract_eval`` function and is bound by a high-level helper below.

# tensor -> tuple of shares
makeshares_p = el.Primitive[tuple[el.Object, ...]]("spu.makeshares")
# shares -> tensor
reconstruct_p = el.Primitive[el.Object]("spu.reconstruct")
# run a compiled SPU executable
exec_p = el.Primitive[Any]("spu.exec")
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@makeshares_p.def_abstract_eval
def _makeshares_ae(
    data: elt.TensorType, *, count: int, config: SPUConfig
) -> tuple[elt.SSType, ...]:
    """Type rule for ``spu.makeshares``: split a tensor into ``count`` shares.

    Args:
        data: Plaintext tensor type to be shared.
        count: Number of shares to produce; must be at least 1.
        config: SPU configuration (not consulted by the type rule).

    Returns:
        A tuple of ``count`` ``SS`` types, each wrapping ``data``.

    Raises:
        TypeError: If ``data`` is not a ``TensorType``.
        ValueError: If ``count`` is less than 1.
    """
    if not isinstance(data, elt.TensorType):
        raise TypeError(f"makeshares expects TensorType, got {data}")
    # A non-positive count would silently produce zero shares, which
    # reconstruct rejects later; fail at trace time instead.
    if count < 1:
        raise ValueError(f"makeshares requires count >= 1, got {count}")
    # Shares have the same shape/dtype as data (simplified additive sharing),
    # so each share is typed as SS over the plaintext tensor type.
    return tuple(elt.SS(data) for _ in range(count))
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
@reconstruct_p.def_abstract_eval
def _reconstruct_ae(*shares: elt.SSType, config: SPUConfig) -> elt.TensorType:
    """Type rule for ``spu.reconstruct``: recover a tensor type from shares.

    Args:
        *shares: Share types; each must be ``SS[Tensor]``.
        config: SPU configuration (not consulted by the type rule).

    Returns:
        The plaintext tensor type carried by the first share.

    Raises:
        ValueError: If no shares are given.
        TypeError: If any share is not ``SSType`` or does not wrap a tensor.
    """
    if not shares:
        raise ValueError("reconstruct requires at least one share")
    # Check every share, not only the first, so a stray non-share input is
    # reported at trace time rather than failing later in the backend.
    for share in shares:
        if not isinstance(share, elt.SSType):
            raise TypeError(f"reconstruct expects SSType shares, got {share}")
        if not isinstance(share.pt_type, elt.TensorType):
            raise TypeError(f"reconstruct expects SS[Tensor], got {share}")
    # NOTE(review): shares are not checked for identical pt_type; add an
    # equality check if elt types define structural __eq__.
    return shares[0].pt_type
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
# Visibility is stored in IR attributes as a plain string so the IR does not
# depend on libspu; run_jax converts it to libspu.Visibility when compiling.
Visibility = Literal["secret", "public", "private"]


@exec_p.def_abstract_eval
def _exec_ae(
    *args: elt.SSType | elt.TensorType,
    executable: bytes,
    input_vis: list[Visibility],
    output_vis: list[Visibility],
    output_shapes: list[tuple[int, ...]],
    output_dtypes: list[elt.ScalarType],
    input_names: list[str],
    output_names: list[str],
    config: SPUConfig,
) -> tuple[elt.SSType, ...] | elt.SSType:
    """Type rule for ``spu.exec``: every output is an ``SS[Tensor]``.

    Only ``output_shapes`` and ``output_dtypes`` determine the result type;
    the remaining attributes are not read by this rule.

    Returns:
        A single ``SS`` type when there is exactly one output, otherwise a
        tuple of them (possibly empty).

    Raises:
        TypeError: If an input is neither ``SSType`` nor ``TensorType``.
    """
    accepted = (elt.SSType, elt.TensorType)
    for candidate in args:
        if isinstance(candidate, accepted):
            continue
        raise TypeError(
            f"spu.exec expects SSType or TensorType inputs, got {candidate}"
        )

    results: list[elt.SSType[Any]] = [
        elt.SS(elt.Tensor(dtype, shape))
        for shape, dtype in zip(output_shapes, output_dtypes, strict=True)
    ]
    return results[0] if len(results) == 1 else tuple(results)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
# ==============================================================================
|
|
210
|
+
# --- High-Level API (Distributed Protocols)
|
|
211
|
+
# ==============================================================================
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def make_shares(
    config: SPUConfig, data: el.Object, count: int
) -> tuple[el.Object, ...]:
    """Split ``data`` into ``count`` secret shares on the current party.

    Nothing is transferred; moving the shares to other parties is the
    caller's job. Call this from within a ``simp.pcall`` region.

    Args:
        config: SPU configuration attached to the primitive.
        data: Local object whose type is a ``TensorType``.
        count: How many shares to create.

    Returns:
        Tuple of share objects (SSType).
    """
    shares = makeshares_p.bind(data, count=count, config=config)
    return shares
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def reconstruct(config: SPUConfig, shares: tuple[el.Object, ...]) -> el.Object:
    """Combine secret shares back into a plaintext object on this party.

    Nothing is transferred; the shares must already be local. Call this from
    within a ``simp.pcall`` region.

    Args:
        config: SPU configuration attached to the primitive.
        shares: The share objects (SSType) to combine.

    Returns:
        The reconstructed TensorType object.
    """
    recovered = reconstruct_p.bind(*shares, config=config)
    return recovered
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def _abstract_input(arg_type: Any) -> tuple[ShapeDtypeStruct, Visibility]:
    """Map one EDSL input type to a JAX abstract value and its IR visibility."""
    vis: Visibility
    if isinstance(arg_type, elt.SSType):
        pt_type = arg_type.pt_type
        vis = "secret"
    elif isinstance(arg_type, elt.TensorType):
        pt_type = arg_type
        vis = "public"
    else:
        raise TypeError(
            f"spu.run_jax inputs must be SSType or TensorType, got {arg_type}"
        )

    if not isinstance(pt_type, elt.TensorType):
        raise TypeError(f"spu.run_jax inputs must be Tensor-based, got {pt_type}")

    jax_dtype = dtypes.to_jax(cast(elt.ScalarType, pt_type.element_type))
    # Dynamic dimensions (-1) are traced with size 1.
    # NOTE(review): the compiled executable is then specialized to that size.
    shape = tuple(1 if d == -1 else d for d in pt_type.shape)
    return ShapeDtypeStruct(shape, jax_dtype), vis


def run_jax(config: SPUConfig, fn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Compile ``fn`` for SPU and bind it as a single ``spu.exec`` op.

    This function should be called inside a ``simp.pcall`` region. EDSL
    objects among ``args``/``kwargs`` become SPU inputs. Other values are
    separated out by ``normalize_fn`` and are not SPU inputs.

    Args:
        config: SPU configuration.
        fn: The JAX function to execute.
        *args: Positional arguments; EDSL objects must be typed ``SS[Tensor]``
            (secret input) or ``Tensor`` (public input).
        **kwargs: Keyword arguments, handled the same way as ``args``.

    Returns:
        The outputs of ``fn`` as ``SS[Tensor]`` objects, in the same pytree
        structure that ``fn`` returns.

    Raises:
        TypeError: If an EDSL input is not ``SS[Tensor]`` or ``Tensor``.
    """
    # 1. Separate EDSL objects (variables) from raw values (immediates).
    normalized_fn, in_vars = normalize_fn(
        fn, args, kwargs, lambda arg: isinstance(arg, el.Object)
    )

    # 2. Validate inputs and build abstract JAX arguments plus IR visibilities
    #    in a single pass.
    jax_args_flat: list[ShapeDtypeStruct] = []
    input_vis: list[Visibility] = []
    for var in in_vars:
        spec, vis = _abstract_input(var.type)
        jax_args_flat.append(spec)
        input_vis.append(vis)

    # 3. Compile. libspu is imported only here so it never leaks into the IR.
    import spu.libspu as libspu

    # Explicit mapping: an unexpected visibility raises KeyError instead of
    # silently degrading to VIS_PUBLIC. _abstract_input only emits these two.
    vis_to_libspu = {
        "secret": libspu.Visibility.VIS_SECRET,
        "public": libspu.Visibility.VIS_PUBLIC,
    }
    executable, output_info = spu_fe.compile(
        spu_fe.Kind.JAX,
        normalized_fn,
        # normalized_fn takes the list of variables as its single argument.
        [jax_args_flat],
        {},
        input_names=[f"in{i}" for i in range(len(in_vars))],
        input_vis=[vis_to_libspu[v] for v in input_vis],
        outputNameGen=lambda outs: [f"out{i}" for i in range(len(outs))],
    )

    # 4. Bind the SPU kernel; every output is produced as a secret share.
    flat_outputs_info, out_tree = tree_flatten(output_info)
    output_vis_list: list[Visibility] = ["secret"] * len(flat_outputs_info)
    res_shares = exec_p.bind(
        *in_vars,
        executable=executable.code,
        input_vis=input_vis,
        output_vis=output_vis_list,
        output_shapes=[out.shape for out in flat_outputs_info],
        output_dtypes=[dtypes.from_dtype(out.dtype) for out in flat_outputs_info],
        input_names=executable.input_names,
        output_names=executable.output_names,
        config=config,
    )

    # 5. exec yields a bare object for one output and a tuple otherwise;
    #    normalize to a flat leaf list, then restore fn's output pytree.
    if isinstance(res_shares, (tuple, list)):
        leaves = list(res_shares)
    else:
        leaves = [res_shares]
    return tree_unflatten(out_tree, leaves)
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Store dialect: save/load primitives for internal state."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import mplang.v2.edsl as el
|
|
20
|
+
import mplang.v2.edsl.typing as elt
|
|
21
|
+
|
|
22
|
+
# Both store primitives produce a single EDSL object.
save_p: el.Primitive[el.Object] = el.Primitive("store.save")
load_p: el.Primitive[el.Object] = el.Primitive("store.load")


@save_p.def_abstract_eval
def _save_abstract(obj: elt.BaseType, *, uri_base: str) -> elt.BaseType:
    """Type rule for ``store.save``: an identity on the operand's type.

    ``uri_base`` does not affect the result type.
    """
    result_type = obj
    return result_type


@load_p.def_abstract_eval
def _load_abstract(*, uri_base: str, expected_type: elt.BaseType) -> elt.BaseType:
    """Type rule for ``store.load``: the caller supplies the result type.

    The op has no operands, so ``expected_type`` alone fixes the output.
    """
    result_type = expected_type
    return result_type
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def save(obj: el.Object, uri_base: str) -> el.Object:
    """Persist ``obj`` under ``uri_base`` and pass it through unchanged.

    SPMD operation: every party holding ``obj`` writes its own local portion
    to the location specified by ``uri_base``.

    Args:
        obj: The object to persist.
        uri_base: Base URI of the storage location.

    Returns:
        The input object (identity); consume it downstream to order later work
        after the save.
    """
    passthrough = save_p.bind(obj, uri_base=uri_base)
    return passthrough
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def load(uri_base: str, expected_type: elt.BaseType) -> el.Object:
    """Read an object back from persistent storage.

    SPMD operation: every party loads its own local portion from a path
    derived from ``uri_base``.

    Args:
        uri_base: Base URI for the checkpoint package.
        expected_type: Type of the object being loaded (reconstructed from
            the manifest); it becomes the result's type.

    Returns:
        The loaded object.
    """
    loaded = load_p.bind(uri_base=uri_base, expected_type=expected_type)
    return loaded
|