mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in the public registry.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
mplang/v1/ops/fhe.py
DELETED
@@ -1,272 +0,0 @@
-# Copyright 2025 Ant Group Co., Ltd.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from mplang.v1.core import UINT8, TensorType
-from mplang.v1.ops.base import stateless_mod
-
-_fhe_MOD = stateless_mod("fhe")
-
-
-@_fhe_MOD.simple_op()
-def keygen(
-    *,
-    scheme: str = "CKKS",
-    poly_modulus_degree: int = 8192,
-    coeff_mod_bit_sizes: tuple[int, ...] | None = None,
-    global_scale: int | None = None,
-    plain_modulus: int | None = None,
-) -> tuple[TensorType, TensorType, TensorType]:
-    """Generate an FHE key pair for Vector backend: returns (private_context, public_context, evaluation_context).
-
-    Args:
-        scheme: FHE scheme to use ("CKKS" for approximate, "BFV" for exact integer)
-        poly_modulus_degree: Polynomial modulus degree (default: 8192)
-        coeff_mod_bit_sizes: Coefficient modulus bit sizes for CKKS (optional)
-        global_scale: Global scale for CKKS (optional)
-        plain_modulus: Plain modulus for BFV (optional)
-
-    Returns:
-        Tuple of (private_context, public_context, evaluation_context) represented as UINT8[(-1, 0)]
-
-    Contexts are represented with a sentinel TensorType UINT8[(-1, 0)] to indicate
-    non-structural, backend-only handles.
-
-    Note: Vector backend only supports 1D data. For multi-dimensional tensors,
-    use mplang.ops.fhe instead.
-    """
-    if scheme not in ("CKKS", "BFV"):
-        raise ValueError("Unsupported scheme. Choose either 'CKKS' or 'BFV'.")
-    if scheme == "CKKS":
-        assert plain_modulus is None, "plain_modulus is not used in CKKS scheme."
-    context_spec = TensorType(UINT8, (-1, 0))
-    return context_spec, context_spec, context_spec
-
-
-@_fhe_MOD.simple_op()
-def encrypt(plaintext: TensorType, context: TensorType) -> TensorType:
-    """Encrypt plaintext using FHE Vector backend: returns ciphertext with same semantic type.
-
-    Args:
-        plaintext: Data to encrypt (scalar or 1D vector only)
-        context: FHE context (private or public)
-
-    Returns:
-        Ciphertext with same semantic type as plaintext
-
-    Raises:
-        ValueError: If plaintext has more than 1 dimension
-
-    Note: Vector backend only supports scalars (shape=()) and 1D vectors (shape=(n,)).
-    For multi-dimensional data, use mplang.ops.fhe.encrypt instead.
-    """
-    _ = context
-    if len(plaintext.shape) > 1:
-        raise ValueError(
-            f"FHE Vector backend only supports 1D data. Got shape {plaintext.shape}. "
-            "Use mplang.ops.fhe for multi-dimensional tensors."
-        )
-    return plaintext
-
-
-@_fhe_MOD.simple_op()
-def decrypt(ciphertext: TensorType, context: TensorType) -> TensorType:
-    """Decrypt ciphertext using FHE Vector backend: returns plaintext with same semantic type.
-
-    Args:
-        ciphertext: Encrypted data to decrypt (scalar or 1D vector)
-        context: FHE context (must be private context with secret key)
-
-    Returns:
-        Plaintext with same semantic type as ciphertext
-
-    Note: Ciphertext encrypted with public context can be decrypted with
-    the corresponding private context.
-    """
-    _ = context
-    return ciphertext
-
-
-@_fhe_MOD.simple_op()
-def add(operand1: TensorType, operand2: TensorType) -> TensorType:
-    """Add two FHE operands (ciphertext + ciphertext or ciphertext + plaintext).
-
-    Args:
-        operand1: First operand (ciphertext or plaintext, scalar or 1D vector)
-        operand2: Second operand (ciphertext or plaintext, scalar or 1D vector)
-
-    Returns:
-        Result of homomorphic addition
-
-    Raises:
-        ValueError: If operands have incompatible shapes or dtypes
-
-    Note: At least one operand must be ciphertext. Both operands must have
-    the same shape (no broadcasting in Vector backend).
-    """
-    assert operand1.dtype == operand2.dtype, (
-        f"Operand dtypes must match, got {operand1.dtype} and {operand2.dtype}."
-    )
-    assert operand1.shape == operand2.shape, (
-        f"Operand shapes must match, got {operand1.shape} and {operand2.shape}."
-    )
-    return operand1
-
-
-@_fhe_MOD.simple_op()
-def sub(operand1: TensorType, operand2: TensorType) -> TensorType:
-    """Subtract two FHE operands (ciphertext - ciphertext or ciphertext - plaintext).
-
-    Args:
-        operand1: First operand (ciphertext or plaintext, scalar or 1D vector)
-        operand2: Second operand (ciphertext or plaintext, scalar or 1D vector)
-
-    Returns:
-        Result of homomorphic subtraction
-
-    Raises:
-        ValueError: If operands have incompatible shapes or dtypes
-
-    Note: At least one operand must be ciphertext. Both operands must have
-    the same shape (no broadcasting in Vector backend).
-    """
-    assert operand1.dtype == operand2.dtype, (
-        f"Operand dtypes must match, got {operand1.dtype} and {operand2.dtype}."
-    )
-    assert operand1.shape == operand2.shape, (
-        f"Operand shapes must match, got {operand1.shape} and {operand2.shape}."
-    )
-    return operand1
-
-
-@_fhe_MOD.simple_op()
-def mul(operand1: TensorType, operand2: TensorType) -> TensorType:
-    """Multiply two FHE operands (ciphertext * ciphertext or ciphertext * plaintext).
-
-    Args:
-        operand1: First operand (ciphertext or plaintext, scalar or 1D vector)
-        operand2: Second operand (ciphertext or plaintext, scalar or 1D vector)
-
-    Returns:
-        Result of homomorphic multiplication
-
-    Raises:
-        ValueError: If operands have incompatible shapes or dtypes
-
-    Note: At least one operand must be ciphertext. Both operands must have
-    the same shape (no broadcasting in Vector backend).
-    For BFV scheme, plaintext operands must be integers.
-    """
-    assert operand1.dtype == operand2.dtype, (
-        f"Operand dtypes must match, got {operand1.dtype} and {operand2.dtype}."
-    )
-    assert operand1.shape == operand2.shape, (
-        f"Operand shapes must match, got {operand1.shape} and {operand2.shape}."
-    )
-    return operand1
-
-
-@_fhe_MOD.simple_op()
-def dot(operand1: TensorType, operand2: TensorType) -> TensorType:
-    """Compute dot product of FHE operands (ciphertext · ciphertext or ciphertext · plaintext).
-
-    Args:
-        operand1: First operand (ciphertext or plaintext, must be 1D vector)
-        operand2: Second operand (ciphertext or plaintext, must be 1D vector)
-
-    Returns:
-        Scalar result of homomorphic dot product (shape=())
-
-    Raises:
-        ValueError: If operands are not 1D vectors or have different lengths
-
-    Note: Both operands must be 1D vectors (not scalars). For scalar multiplication,
-    use mul() instead. This operation always returns a scalar.
-    """
-    if len(operand1.shape) != 1:
-        raise ValueError(
-            f"Dot product requires 1D vectors, got shape {operand1.shape} for operand1"
-        )
-    if len(operand2.shape) != 1:
-        raise ValueError(
-            f"Dot product requires 1D vectors, got shape {operand2.shape} for operand2"
-        )
-    if operand1.shape[0] != operand2.shape[0]:
-        raise ValueError(
-            f"Dot product dimension mismatch: {operand1.shape[0]} vs {operand2.shape[0]}"
-        )
-
-    # Dot product of 1D vectors returns a scalar
-    return TensorType(operand1.dtype, ())
-
-
-@_fhe_MOD.simple_op()
-def polyval(ciphertext: TensorType, coeffs: TensorType) -> TensorType:
-    """Evaluate polynomial on encrypted data with plaintext coefficients.
-
-    Args:
-        ciphertext: Encrypted data (scalar or 1D vector)
-        coeffs: Plaintext polynomial coefficients as 1D array [c0, c1, c2, ...]
-            representing c0 + c1*x + c2*x^2 + ...
-
-    Returns:
-        Result of polynomial evaluation with same shape and dtype as ciphertext
-
-    Raises:
-        ValueError: If coefficients array is not 1D or has fewer than 2 elements
-
-    Note: Polynomial must have degree >= 1 (at least 2 coefficients required).
-    Constant polynomials (degree 0, single coefficient) are NOT supported due to
-    TenSEAL limitation. For constant values, use: ct * 0 + constant instead.
-    For BFV scheme, coefficients must be integers.
-
-    Common use case - Sigmoid approximation:
-        sigmoid_coeffs = [0.5, 0.15012, 0.0, -0.0018027]
-        result = polyval(ciphertext, sigmoid_coeffs)
-    """
-    if len(coeffs.shape) != 1:
-        raise ValueError(
-            f"Polynomial coefficients must be 1D array, got shape {coeffs.shape}"
-        )
-    _ = coeffs
-    return ciphertext
-
-
-@_fhe_MOD.simple_op()
-def negate(ciphertext: TensorType) -> TensorType:
-    """Negate encrypted data (unary minus).
-
-    Args:
-        ciphertext: Encrypted data (scalar or 1D vector)
-
-    Returns:
-        Negated ciphertext with same shape and dtype
-
-    Note: Equivalent to multiplying by -1.
-    """
-    return ciphertext
-
-
-@_fhe_MOD.simple_op()
-def square(ciphertext: TensorType) -> TensorType:
-    """Square encrypted data (element-wise).
-
-    Args:
-        ciphertext: Encrypted data (scalar or 1D vector)
-
-    Returns:
-        Squared ciphertext with same shape and dtype
-
-    Note: More efficient than mul(ciphertext, ciphertext) in some FHE schemes.
-    """
-    return ciphertext
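The deleted ops above are type-inference stubs: each validates shapes/dtypes and returns only the result's TensorType, while the actual homomorphic arithmetic runs in a backend kernel. The polyval docstring names TenSEAL as the Vector backend's limitation source, so a minimal TenSEAL sketch of the corresponding runtime calls might look like the following (context parameters and sigmoid coefficients are taken from the docstrings above; treating TenSEAL as the backend and its exact usage here is an assumption, not something this diff confirms):

import tenseal as ts  # assumed Vector backend; not shipped with mplang

# keygen(scheme="CKKS", poly_modulus_degree=8192, ...) would build a context like:
ctx = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[60, 40, 40, 60],
)
ctx.global_scale = 2**40
ctx.generate_galois_keys()  # rotation keys are needed for dot()

x = ts.ckks_vector(ctx, [1.0, 2.0, 3.0])  # encrypt(): scalars and 1D vectors only
y = ts.ckks_vector(ctx, [4.0, 5.0, 6.0])

s = x + y     # add(): element-wise, shapes must match (no broadcasting)
d = x.dot(y)  # dot(): 1D x 1D -> scalar ciphertext
sig = x.polyval([0.5, 0.15012, 0.0, -0.0018027])  # the sigmoid approximation from the docstring

print(s.decrypt(), d.decrypt())  # decrypt() requires the secret-key context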
mplang/v1/ops/jax_cc.py
DELETED
@@ -1,147 +0,0 @@
-# Copyright 2025 Ant Group Co., Ltd.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import annotations
-
-import logging
-from collections.abc import Callable
-from typing import Any
-
-import jax
-import jax.numpy as jnp
-from jax import export
-from jax.tree_util import PyTreeDef, tree_flatten
-
-from mplang.v1.core import MPObject, PFunction, TensorType, get_fn_name
-from mplang.v1.ops.base import FeOperation, stateless_mod
-from mplang.v1.utils.func_utils import normalize_fn
-
-# Enable 64-bit precision for JAX to match tensor types
-jax.config.update("jax_enable_x64", True)
-
-
-def jax2stablehlo(
-    is_variable: Callable[[Any], bool], flat_fn: Any, *args: Any, **kwargs: Any
-) -> tuple[PFunction, list, PyTreeDef]:
-    """Compile JAX function to StableHLO MLIR format for remote execution.
-
-    Translates high-level JAX functions into StableHLO MLIR representations,
-    enabling execution on JAX backends across different processes and platforms.
-    Uses a hybrid approach: traditional JAX trace/lower for compilation compatibility,
-    with stable jax.export API for parameter tracking.
-
-    Args:
-        is_variable: Predicate function to classify parameters as variables vs. constants.
-            Returns True for parameters that should be treated as PFunction inputs.
-        flat_fn: JAX function to be compiled into StableHLO format
-        *args: Positional arguments passed to the function during compilation
-        **kwargs: Keyword arguments passed to the function during compilation
-
-    Returns:
-        tuple[PFunction, list, PyTreeDef]: Compilation artifacts containing:
-            - PFunction: Serialized function with embedded MLIR text and type metadata
-            - list: Extracted variable parameters (those satisfying is_variable predicate).
-              Non-variable parameters are captured as compile-time constants within
-              the PFunction body, while variables become runtime input parameters.
-            - PyTreeDef: Tree structure template for reconstructing nested output values
-    """
-    # Flatten (args, kwargs) and capture immediates using the moved logic from primitive.py
-    normalized_fn, in_vars = normalize_fn(flat_fn, args, kwargs, is_variable)
-
-    # Convert TensorType in_vars to ShapeDtypeStruct for JAX tracing
-    jax_params = [
-        jax.ShapeDtypeStruct(arg.shape, jnp.dtype(arg.dtype.name)) for arg in in_vars
-    ]
-
-    # Hybrid approach: Use standard JAX trace/lower for compatibility, but jax.export for parameter tracking
-    jitted_fn = jax.jit(normalized_fn)
-    traced = jitted_fn.trace(jax_params)
-    lowered = traced.lower()
-
-    # Get StableHLO MLIR representation using traditional approach
-    stablehlo_mlir = lowered.compiler_ir("stablehlo")
-    mlir_text = str(stablehlo_mlir)
-
-    # Get output info using traditional approach
-    out_info_flat, out_tree = tree_flatten(lowered.out_info)
-    out_info_flat = [TensorType.from_obj(info) for info in out_info_flat]
-
-    # Extract argument keep mapping using stable jax.export API for parameter tracking
-    # We use jax.export only for getting the kept_var_idx information, not for the main compilation
-    arg_keep_map = None
-    original_arg_count = len(in_vars)
-
-    try:
-        # Use jax.export just to get the stable parameter tracking information
-        export_fn = export.export(jitted_fn)
-        exported = export_fn(jax_params)
-        kept_var_idx = exported.module_kept_var_idx
-        if kept_var_idx is not None and len(kept_var_idx) < original_arg_count:
-            # JAX eliminated some unused parameters during compilation
-            # Keep the indices in sorted order for consistent mapping
-            arg_keep_map = sorted(kept_var_idx)
-    except Exception as e:
-        # Fallback: if jax.export fails, we can still use the compiled result without parameter tracking
-        # This ensures backward compatibility even if export has issues
-        logging.warning(
-            f"jax.export failed to get kept_var_idx, proceeding without it. Error: {e}"
-        )
-
-    # This format tells JaxRT how to handle the compiled result
-    pfn_kwargs: dict[str, Any] = {
-        "fn_type": "mlir.stablehlo",  # Key: specify StableHLO MLIR format
-        "ins_info": tuple(TensorType.from_obj(x) for x in in_vars),
-        "outs_info": tuple(out_info_flat),
-        "fn_name": get_fn_name(flat_fn),
-        "fn_text": mlir_text,  # MLIR text, serializable for transmission
-    }
-
-    if arg_keep_map is not None:
-        pfn_kwargs["arg_keep_map"] = arg_keep_map
-
-    pfn = PFunction(**pfn_kwargs)
-    return pfn, in_vars, out_tree
-
-
-class JaxRunner(FeOperation):
-    """JAX function runner frontend operation."""
-
-    def trace(
-        self, jax_fn: Callable, *args: Any, **kwargs: Any
-    ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
-        """
-        JAX compilation helper function.
-
-        Compiles a JAX function to StableHLO format and returns the PFunction
-        along with variable arguments for evaluation.
-
-        Args:
-            jax_fn: The JAX function to compile
-            *args: Positional arguments to the function
-            **kwargs: Keyword arguments to the function
-
-        Returns:
-            tuple[PFunction, list[MPObject], PyTreeDef]: The compiled PFunction, input variables, and output tree
-        """
-
-        def is_variable(arg: Any) -> bool:
-            return isinstance(arg, MPObject)
-
-        pfunc, in_vars, out_tree = jax2stablehlo(is_variable, jax_fn, *args, **kwargs)
-        return pfunc, in_vars, out_tree
-
-
-_JAX_MOD = stateless_mod("jax")
-
-run_jax = JaxRunner(_JAX_MOD, "run")
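The deleted jax2stablehlo wraps the standard JAX ahead-of-time pipeline. A minimal standalone sketch of the two pieces it combines, trace/lower for the StableHLO text and jax.export for module_kept_var_idx, might look like this (the function f and its input specs are hypothetical illustrations, not part of the package):

import jax
import jax.numpy as jnp
from jax import export

def f(x, unused):
    return jnp.tanh(x) * 2.0  # 'unused' is dead and may be eliminated

specs = (
    jax.ShapeDtypeStruct((4,), jnp.float32),
    jax.ShapeDtypeStruct((4,), jnp.float32),
)

# Trace/lower path: yields the serializable StableHLO MLIR text
# that jax2stablehlo stores as fn_text.
lowered = jax.jit(f).lower(*specs)
mlir_text = str(lowered.compiler_ir("stablehlo"))

# jax.export path: reports which inputs survived dead-argument
# elimination, which jax2stablehlo records as arg_keep_map.
exported = export.export(jax.jit(f))(*specs)
print(exported.module_kept_var_idx)  # e.g. (0,): 'unused' was dropped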
mplang/v1/ops/nnx_cc.py
DELETED
@@ -1,168 +0,0 @@
-# Copyright 2025 Ant Group Co., Ltd.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import annotations
-
-import logging
-from collections.abc import Callable
-from typing import Any
-
-import jax
-import jax.numpy as jnp
-from flax import nnx
-from jax import export
-from jax.tree_util import PyTreeDef, tree_flatten
-
-from mplang.v1.core import MPObject, PFunction, TensorType, get_fn_name
-from mplang.v1.ops.base import FeOperation, stateless_mod
-from mplang.v1.utils.func_utils import normalize_fn
-
-# Enable 64-bit precision for JAX to match tensor types
-jax.config.update("jax_enable_x64", True)
-
-
-def nnx2stablehlo(
-    is_variable: Callable[[Any], bool], flat_fn: Any, *args: Any, **kwargs: Any
-) -> tuple[PFunction, list[Any], PyTreeDef]:
-    """Compile NNX function to StableHLO MLIR format for remote execution.
-
-    Translates high-level NNX functions into StableHLO MLIR representations,
-    enabling execution on JAX backends across different processes and platforms.
-    Uses a hybrid approach: traditional NNX trace/lower for compilation compatibility,
-    with stable jax.export API for parameter tracking.
-
-    Args:
-        is_variable: Predicate function to classify parameters as variables vs. constants.
-            Returns True for parameters that should be treated as PFunction inputs.
-        flat_fn: NNX function to be compiled into StableHLO format
-        *args: Positional arguments passed to the function during compilation
-        **kwargs: Keyword arguments passed to the function during compilation
-
-    Returns:
-        tuple[PFunction, list, PyTreeDef]: Compilation artifacts containing:
-            - PFunction: Serialized function with embedded MLIR text and type metadata
-            - list: Extracted variable parameters (those satisfying is_variable predicate).
-              Non-variable parameters are captured as compile-time constants within
-              the PFunction body, while variables become runtime input parameters.
-            - PyTreeDef: Tree structure template for reconstructing nested output values
-    """
-    # Flatten (args, kwargs) and capture immediates using the moved logic from primitive.py
-    normalized_fn, in_vars = normalize_fn(flat_fn, args, kwargs, is_variable)
-
-    # Convert TensorType in_vars to ShapeDtypeStruct for JAX tracing
-    jax_params = [
-        jax.ShapeDtypeStruct(arg.shape, jnp.dtype(arg.dtype.name)) for arg in in_vars
-    ]
-
-    # NNX compilation pipeline using JAX export API: nnx.jit → jax.export → StableHLO MLIR
-    # Use nnx.jit for NNX-specific functionality, then jax.export for stable parameter handling
-    nnx_jitted = nnx.jit(normalized_fn)
-
-    # Extract the underlying JAX function for jax.export compatibility
-    # nnx.jit wraps a JAX function, and we can access it via .fun attribute
-    underlying_jax_fn = nnx_jitted.fun
-
-    # Hybrid approach: Use NNX trace/lower for compilation, but jax.export for parameter tracking
-    # Use traditional nnx.jit → trace → lower for compatibility with argument structure
-    nnx_traced = nnx_jitted.trace(jax_params)
-    nnx_lowered = nnx_traced.lower()
-
-    # Get StableHLO MLIR representation using traditional NNX approach
-    # NNX lowered object wraps JAX lowered, so we access the inner JAX lowered object
-    jax_lowered = nnx_lowered.lowered
-    stablehlo_mlir = jax_lowered.compiler_ir("stablehlo")
-    mlir_text = str(stablehlo_mlir)
-
-    # Get output info using traditional NNX approach
-    # NNX captures output in (args, kwargs, result) format, so we need to extract just the result part
-    raw_out_info = jax_lowered.out_info
-    if isinstance(raw_out_info, tuple) and len(raw_out_info) == 3:
-        # NNX format: (args, kwargs, result) - extract just the result
-        _, _, actual_out_info = raw_out_info
-        out_info_flat, out_tree = tree_flatten(actual_out_info)
-    else:
-        # Fallback to direct format (shouldn't happen with NNX, but just in case)
-        out_info_flat, out_tree = tree_flatten(raw_out_info)
-
-    out_info_flat = [TensorType.from_obj(info) for info in out_info_flat]
-
-    # Extract argument keep mapping using stable jax.export API for parameter tracking
-    # We use the underlying JAX function with jax.export only for parameter tracking
-    arg_keep_map = None
-    original_arg_count = len(in_vars)
-
-    try:
-        # Use jax.export with the underlying JAX function just to get stable parameter tracking
-        export_fn = export.export(jax.jit(underlying_jax_fn))
-        exported = export_fn(jax_params)
-        kept_var_idx = exported.module_kept_var_idx
-        if kept_var_idx is not None and len(kept_var_idx) < original_arg_count:
-            # JAX eliminated some unused parameters during compilation
-            # Keep the indices in sorted order for consistent mapping
-            arg_keep_map = sorted(kept_var_idx)
-    except Exception as e:
-        # Fallback: if jax.export fails, we can still use the compiled result without parameter tracking
-        # This ensures backward compatibility even if export has issues
-        logging.warning(
-            f"jax.export failed to get kept_var_idx, proceeding without it. Error: {e}"
-        )
-
-    # This format tells JaxRT how to handle the compiled result
-    # Use the same format as JAX since NNX compiles to the same backend
-    pfn_kwargs: dict[str, Any] = {
-        "fn_type": "mlir.stablehlo",  # Key: specify StableHLO MLIR format
-        "ins_info": tuple(TensorType.from_obj(x) for x in in_vars),
-        "outs_info": tuple(out_info_flat),
-        "fn_name": get_fn_name(flat_fn),
-        "fn_text": mlir_text,  # MLIR text, serializable for transmission
-    }
-
-    if arg_keep_map is not None:
-        pfn_kwargs["arg_keep_map"] = arg_keep_map
-
-    pfn = PFunction(**pfn_kwargs)
-    return pfn, in_vars, out_tree
-
-
-class NnxRunner(FeOperation):
-    """NNX function runner frontend operation."""
-
-    def trace(
-        self, nnx_fn: Callable, *args: Any, **kwargs: Any
-    ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
-        """
-        NNX compilation helper function.
-
-        Compiles an NNX function to StableHLO format and returns the PFunction
-        along with variable arguments for evaluation.
-
-        Args:
-            nnx_fn: The NNX function to compile
-            *args: Positional arguments to the function
-            **kwargs: Keyword arguments to the function
-
-        Returns:
-            tuple[PFunction, list[MPObject], PyTreeDef]: The compiled PFunction, input variables, and output tree
-        """
-
-        def is_variable(arg: Any) -> bool:
-            return isinstance(arg, MPObject)
-
-        pfunc, in_vars, out_tree = nnx2stablehlo(is_variable, nnx_fn, *args, **kwargs)
-        return pfunc, in_vars, out_tree
-
-
-_NNX_MOD = stateless_mod("nnx")
-
-run_nnx = NnxRunner(_NNX_MOD, "run")
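For context, the kind of callable the deleted run_nnx compiled is a function over a Flax NNX module, jitted with nnx.jit as in nnx2stablehlo above. A minimal sketch of such a function follows (the Model class and its sizes are illustrative assumptions, not taken from this package):

import jax.numpy as jnp
from flax import nnx

class Model(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(4, 2, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)

model = Model(nnx.Rngs(0))

@nnx.jit
def forward(model, x):
    return model(x)

# nnx2stablehlo would trace/lower a function like 'forward' to StableHLO text.
y = forward(model, jnp.ones((1, 4)))
print(y.shape)  # (1, 2)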