mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -15,18 +15,25 @@
|
|
|
15
15
|
|
|
16
16
|
from jax.tree_util import PyTreeDef, tree_flatten
|
|
17
17
|
|
|
18
|
-
from mplang.core
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
18
|
+
from mplang.v1.core import (
|
|
19
|
+
UINT8,
|
|
20
|
+
UINT64,
|
|
21
|
+
MPObject,
|
|
22
|
+
PFunction,
|
|
23
|
+
ScalarType,
|
|
24
|
+
Shape,
|
|
25
|
+
TableLike,
|
|
26
|
+
TableType,
|
|
27
|
+
TensorLike,
|
|
28
|
+
TensorType,
|
|
29
|
+
)
|
|
30
|
+
from mplang.v1.ops.base import stateless_mod
|
|
31
|
+
from mplang.v1.utils import table_utils
|
|
32
|
+
|
|
33
|
+
# Op module handle for the "basic" namespace. The ops in this file are wrapped
# by its simple_op()/op_def() decorators and emit fn_type values prefixed with
# "basic." (e.g. "basic.constant").
_BASIC_MOD = stateless_mod("basic")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@_BASIC_MOD.simple_op()
|
|
30
37
|
def identity(x: TensorType) -> TensorType:
|
|
31
38
|
"""Return the input type unchanged.
|
|
32
39
|
|
|
@@ -40,7 +47,7 @@ def identity(x: TensorType) -> TensorType:
|
|
|
40
47
|
return x
|
|
41
48
|
|
|
42
49
|
|
|
43
|
-
@
|
|
50
|
+
@_BASIC_MOD.simple_op()
|
|
44
51
|
def read(*, path: str, ty: TensorType) -> TensorType:
|
|
45
52
|
"""Declare reading a value of type ``ty`` from ``path`` (type-only).
|
|
46
53
|
|
|
@@ -63,7 +70,7 @@ def read(*, path: str, ty: TensorType) -> TensorType:
|
|
|
63
70
|
return ty
|
|
64
71
|
|
|
65
72
|
|
|
66
|
-
@
|
|
73
|
+
@_BASIC_MOD.simple_op()
|
|
67
74
|
def write(x: TensorType, *, path: str) -> TensorType:
|
|
68
75
|
"""Declare writing the input value to ``path`` and return the same type.
|
|
69
76
|
|
|
@@ -77,7 +84,7 @@ def write(x: TensorType, *, path: str) -> TensorType:
|
|
|
77
84
|
return x
|
|
78
85
|
|
|
79
86
|
|
|
80
|
-
@
|
|
87
|
+
@_BASIC_MOD.op_def()
|
|
81
88
|
def constant(
|
|
82
89
|
data: TensorLike | ScalarType | TableLike,
|
|
83
90
|
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
|
|
@@ -89,7 +96,7 @@ def constant(
|
|
|
89
96
|
|
|
90
97
|
Returns:
|
|
91
98
|
Tuple[PFunction, list[MPObject], PyTreeDef]:
|
|
92
|
-
- PFunction: ``fn_type='
|
|
99
|
+
- PFunction: ``fn_type='basic.constant'`` with one output whose type
|
|
93
100
|
matches ``data``; payload serialized via ``data_bytes`` with
|
|
94
101
|
``data_format`` ('bytes[numpy]' or 'bytes[csv]').
|
|
95
102
|
- list[MPObject]: Empty (no inputs captured).
|
|
@@ -101,8 +108,9 @@ def constant(
|
|
|
101
108
|
out_type: TableType | TensorType
|
|
102
109
|
|
|
103
110
|
if isinstance(data, TableLike):
|
|
104
|
-
|
|
105
|
-
|
|
111
|
+
format = "parquet"
|
|
112
|
+
data_bytes = table_utils.encode_table(data, format=format)
|
|
113
|
+
data_format = f"bytes[{format}]"
|
|
106
114
|
out_type = TableType.from_tablelike(data)
|
|
107
115
|
elif isinstance(data, ScalarType):
|
|
108
116
|
out_type = TensorType.from_obj(data)
|
|
@@ -120,7 +128,7 @@ def constant(
|
|
|
120
128
|
data_format = "bytes[numpy]"
|
|
121
129
|
|
|
122
130
|
pfunc = PFunction(
|
|
123
|
-
fn_type="
|
|
131
|
+
fn_type="basic.constant",
|
|
124
132
|
ins_info=(),
|
|
125
133
|
outs_info=(out_type,),
|
|
126
134
|
data_bytes=data_bytes,
|
|
@@ -130,7 +138,7 @@ def constant(
|
|
|
130
138
|
return pfunc, [], treedef
|
|
131
139
|
|
|
132
140
|
|
|
133
|
-
@
|
|
141
|
+
@_BASIC_MOD.simple_op()
|
|
134
142
|
def rank() -> TensorType:
|
|
135
143
|
"""Return the scalar UINT64 tensor type for the current party rank.
|
|
136
144
|
|
|
@@ -140,7 +148,7 @@ def rank() -> TensorType:
|
|
|
140
148
|
return TensorType(UINT64, ())
|
|
141
149
|
|
|
142
150
|
|
|
143
|
-
@
|
|
151
|
+
@_BASIC_MOD.simple_op()
|
|
144
152
|
def prand(*, shape: Shape = ()) -> TensorType:
|
|
145
153
|
"""Declare a private random UINT64 tensor with the given shape.
|
|
146
154
|
|
|
@@ -153,7 +161,7 @@ def prand(*, shape: Shape = ()) -> TensorType:
|
|
|
153
161
|
return TensorType(UINT64, shape)
|
|
154
162
|
|
|
155
163
|
|
|
156
|
-
@
|
|
164
|
+
@_BASIC_MOD.simple_op()
|
|
157
165
|
def debug_print(
|
|
158
166
|
x: TensorType | TableType, *, prefix: str = ""
|
|
159
167
|
) -> TableType | TensorType:
|
|
@@ -169,7 +177,7 @@ def debug_print(
|
|
|
169
177
|
return x
|
|
170
178
|
|
|
171
179
|
|
|
172
|
-
@
|
|
180
|
+
@_BASIC_MOD.simple_op()
|
|
173
181
|
def pack(x: TensorType | TableType) -> TensorType:
|
|
174
182
|
"""Serialize a tensor/table into a byte vector (type-only).
|
|
175
183
|
|
|
@@ -189,7 +197,7 @@ def pack(x: TensorType | TableType) -> TensorType:
|
|
|
189
197
|
return TensorType(UINT8, (-1,))
|
|
190
198
|
|
|
191
199
|
|
|
192
|
-
@
|
|
200
|
+
@_BASIC_MOD.simple_op()
|
|
193
201
|
def unpack(b: TensorType, *, out_ty: TensorType | TableType) -> TensorType | TableType:
|
|
194
202
|
"""Deserialize a byte vector into the explicit output type.
|
|
195
203
|
|
|
@@ -215,7 +223,7 @@ def unpack(b: TensorType, *, out_ty: TensorType | TableType) -> TensorType | Tab
|
|
|
215
223
|
return out_ty
|
|
216
224
|
|
|
217
225
|
|
|
218
|
-
@
|
|
226
|
+
@_BASIC_MOD.simple_op()
|
|
219
227
|
def table_to_tensor(table: TableType, *, number_rows: int) -> TensorType:
|
|
220
228
|
"""Convert a homogeneous-typed table to a dense 2D tensor.
|
|
221
229
|
|
|
@@ -248,7 +256,7 @@ def table_to_tensor(table: TableType, *, number_rows: int) -> TensorType:
|
|
|
248
256
|
return TensorType(first, shape) # type: ignore[arg-type]
|
|
249
257
|
|
|
250
258
|
|
|
251
|
-
@
|
|
259
|
+
@_BASIC_MOD.simple_op()
|
|
252
260
|
def tensor_to_table(tensor: TensorType, *, column_names: list[str]) -> TableType:
|
|
253
261
|
"""Convert a rank-2 tensor into a table with named columns.
|
|
254
262
|
|
mplang/v1/ops/crypto.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
Crypto frontend operations: operation signatures, types, and high-level semantics.
|
|
17
|
+
|
|
18
|
+
Scope and contracts:
|
|
19
|
+
- This module defines portable API shapes; it does not implement cryptography.
|
|
20
|
+
- Backends execute the operations and must meet the security semantics required
|
|
21
|
+
by the deployment (confidentiality, authenticity, correctness, etc.).
|
|
22
|
+
- The enc/dec API in this frontend assumes the nonce is carried as a prefix
  (ciphertext = nonce || payload), and dec expects that format. The total
  ciphertext overhead depends on ``algo`` (16 bytes for "aes-ctr", 28 bytes
  for "aes-gcm"/"sm4-gcm"). Other security properties (e.g., AEAD) are
  backend responsibilities.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
from __future__ import annotations
|
|
28
|
+
|
|
29
|
+
from jax.tree_util import PyTreeDef, tree_flatten
|
|
30
|
+
|
|
31
|
+
from mplang.v1.core import UINT8, TensorType
|
|
32
|
+
from mplang.v1.core.mpobject import MPObject
|
|
33
|
+
from mplang.v1.core.pfunc import PFunction
|
|
34
|
+
from mplang.v1.ops.base import stateless_mod
|
|
35
|
+
|
|
36
|
+
# Op module handle for the "crypto" namespace. The ops below are wrapped by its
# simple_op()/op_def() decorators and emit fn_type values such as "crypto.enc".
_CRYPTO_MOD = stateless_mod("crypto")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _get_algo_overhead(algo: str) -> int:
|
|
40
|
+
"""Get ciphertext overhead for a given encryption algorithm.
|
|
41
|
+
|
|
42
|
+
Args:
|
|
43
|
+
algo: Encryption algorithm identifier
|
|
44
|
+
|
|
45
|
+
Returns:
|
|
46
|
+
int: Number of overhead bytes added to plaintext length
|
|
47
|
+
"""
|
|
48
|
+
overhead_map = {
|
|
49
|
+
"aes-ctr": 16, # nonce only (legacy compatibility)
|
|
50
|
+
"aes-gcm": 28, # nonce(12) + tag(16) for AES-GCM
|
|
51
|
+
"sm4-gcm": 28, # nonce(12) + tag(16) for SM4-GCM
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
if algo not in overhead_map:
|
|
55
|
+
# return unknown overhead as -1
|
|
56
|
+
return -1
|
|
57
|
+
return overhead_map[algo]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@_CRYPTO_MOD.simple_op()
def keygen(*, length: int = 32) -> TensorType:
    """Declare ``length`` random bytes, e.g. for a symmetric key.

    API: keygen(length: int = 32) -> key: u8[length]

    Only the output type is defined here; the backend provides the actual
    randomness.

    Args:
        length: Number of bytes to produce; must be positive.

    Returns:
        TensorType: 1-D UINT8 tensor type of shape ``(length,)``.

    Raises:
        ValueError: If ``length`` <= 0.
    """
    if length <= 0:
        raise ValueError("length must be > 0")
    key_shape = (length,)
    return TensorType(UINT8, key_shape)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@_CRYPTO_MOD.op_def()
def enc(
    plaintext: MPObject, key: MPObject, algo: str = "aes-ctr"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Build a symmetric-encryption op whose output size accounts for ``algo``.

    API: enc(plaintext: u8[N], key: u8[M], algo: str = "aes-ctr") -> ciphertext: u8[N + overhead]

    Overhead per algorithm:
      - "aes-ctr": 16 bytes (nonce only, legacy compatibility)
      - "aes-gcm": 28 bytes (nonce + 16-byte authentication tag)
      - "sm4-gcm": 28 bytes (nonce + 16-byte authentication tag)

    Args:
        plaintext: 1-D UINT8 value to encrypt.
        key: Key value. Its type is recorded in ``ins_info`` but not checked here.
        algo: Algorithm name, stored as a PFunction attribute for the backend.

    Returns:
        (pfunc, inputs, treedef): the ``crypto.enc`` PFunction, the captured
        inputs ``[plaintext, key]``, and the treedef of the single output.
        The output length is ``N + overhead`` when both are statically known,
        otherwise dynamic (-1). An unrecognized ``algo`` is not rejected here.

    Raises:
        TypeError: If ``plaintext`` is not a 1-D UINT8 value.
    """
    if plaintext.dtype != UINT8:
        raise TypeError("enc expects UINT8 plaintext")
    if len(plaintext.shape) != 1:
        raise TypeError("enc expects 1-D plaintext")

    # Unknown algorithms yield overhead -1 and therefore a dynamic output length.
    overhead = _get_algo_overhead(algo)
    n_bytes = plaintext.shape[0]
    sizes_known = n_bytes >= 0 and overhead >= 0
    out_ty = TensorType(UINT8, (n_bytes + overhead,) if sizes_known else (-1,))
    outs_info = (out_ty,)

    pfunc = PFunction(
        fn_type="crypto.enc",
        ins_info=(TensorType.from_obj(plaintext), TensorType.from_obj(key)),
        outs_info=outs_info,
        algo=algo,
    )
    _, treedef = tree_flatten(out_ty)
    return pfunc, [plaintext, key], treedef
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@_CRYPTO_MOD.op_def()
def dec(
    ciphertext: MPObject, key: MPObject, algo: str = "aes-ctr"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Build a symmetric-decryption op whose output size accounts for ``algo``.

    API: dec(ciphertext: u8[N + overhead], key: u8[M], algo: str = "aes-ctr") -> plaintext: u8[N]

    Overhead per algorithm:
      - "aes-ctr": 16 bytes (nonce only, legacy compatibility)
      - "aes-gcm": 28 bytes (nonce + 16-byte authentication tag)
      - "sm4-gcm": 28 bytes (nonce + 16-byte authentication tag)

    Args:
        ciphertext: 1-D UINT8 value to decrypt.
        key: Key value. Its type is recorded in ``ins_info`` but not checked here.
        algo: Algorithm name, stored as a PFunction attribute. The backend is
            responsible for parsing the ciphertext layout for that algorithm.

    Returns:
        (pfunc, inputs, treedef): the ``crypto.dec`` PFunction, the captured
        inputs ``[ciphertext, key]``, and the treedef of the single output.
        The output length is ``N - overhead`` when both are statically known,
        otherwise dynamic (-1).

    Raises:
        TypeError: If ``ciphertext`` is not 1-D UINT8, or if its static length
            is shorter than the known overhead for ``algo``.
    """
    if ciphertext.dtype != UINT8:
        raise TypeError("dec expects UINT8 ciphertext")
    if len(ciphertext.shape) != 1:
        raise TypeError("dec expects 1-D ciphertext")

    # Unknown algorithms yield overhead -1 and therefore a dynamic output length.
    overhead = _get_algo_overhead(algo)
    length = ciphertext.shape[0]
    sizes_known = length >= 0 and overhead >= 0

    # A statically-sized ciphertext cannot be shorter than its fixed overhead.
    if sizes_known and length < overhead:
        raise TypeError(
            f"dec expects ciphertext with at least {overhead} bytes for algo='{algo}', but got {length} bytes"
        )

    out_ty = TensorType(UINT8, (length - overhead,) if sizes_known else (-1,))
    outs_info = (out_ty,)

    pfunc = PFunction(
        fn_type="crypto.dec",
        ins_info=(TensorType.from_obj(ciphertext), TensorType.from_obj(key)),
        outs_info=outs_info,
        algo=algo,
    )
    _, treedef = tree_flatten(out_ty)
    return pfunc, [ciphertext, key], treedef
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@_CRYPTO_MOD.op_def()
def kem_keygen(suite: str = "x25519") -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Build a KEM key-pair generation op producing ``(sk, pk)`` byte vectors.

    API: kem_keygen(suite: str = "x25519") -> (sk: u8[32], pk: u8[32])

    Args:
        suite: KEM suite name, stored as a PFunction attribute for the backend.
            For "x25519" both outputs are 32 bytes. Any other suite gets
            dynamic (-1) lengths.

    Returns:
        (pfunc, inputs, treedef): the ``crypto.kem_keygen`` PFunction, an empty
        input list (nothing is captured), and the treedef of the
        ``(sk, pk)`` output pair.
    """
    key_len = 32 if suite == "x25519" else -1
    outs_info = (TensorType(UINT8, (key_len,)), TensorType(UINT8, (key_len,)))

    pfunc = PFunction(
        fn_type="crypto.kem_keygen",
        ins_info=(),
        outs_info=outs_info,
        suite=suite,
    )
    _, treedef = tree_flatten(outs_info)
    return pfunc, [], treedef
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
@_CRYPTO_MOD.op_def()
def kem_derive(
    sk: MPObject, peer_pk: MPObject, suite: str = "x25519"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Build a KEM shared-secret derivation op.

    API: kem_derive(sk: u8[32], peer_pk: u8[32], suite: str = "x25519") -> secret: u8[32]

    Args:
        sk: Own secret key, 1-D UINT8.
        peer_pk: Peer's public key, 1-D UINT8.
        suite: KEM suite name, stored as a PFunction attribute for the backend.

    Returns:
        (pfunc, inputs, treedef): the ``crypto.kem_derive`` PFunction, the
        captured inputs ``[sk, peer_pk]``, and the treedef of the secret.
        The secret is 32 bytes for "x25519" and dynamic (-1) otherwise.

    Raises:
        TypeError: If either key is not 1-D UINT8, or if "x25519" keys are
            not 32 bytes long.
    """
    if sk.dtype != UINT8:
        raise TypeError("kem_derive expects UINT8 secret key")
    if peer_pk.dtype != UINT8:
        raise TypeError("kem_derive expects UINT8 peer public key")
    if len(sk.shape) != 1 or len(peer_pk.shape) != 1:
        raise TypeError("kem_derive expects 1-D inputs")

    secret_len = -1  # unknown suite: leave the secret length dynamic
    if suite == "x25519":
        if sk.shape[0] != 32 or peer_pk.shape[0] != 32:
            raise TypeError("kem_derive expects 32-byte keys for suite 'x25519'")
        secret_len = 32
    secret_ty = TensorType(UINT8, (secret_len,))

    pfunc = PFunction(
        fn_type="crypto.kem_derive",
        ins_info=(TensorType.from_obj(sk), TensorType.from_obj(peer_pk)),
        outs_info=(secret_ty,),
        suite=suite,
    )
    _, treedef = tree_flatten(secret_ty)
    return pfunc, [sk, peer_pk], treedef
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
@_CRYPTO_MOD.op_def()
def hkdf(
    secret: MPObject, info: str, hash: str = "SHA-256"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Build an HKDF-style key-derivation op.

    API: hkdf(secret: u8[N], info: str, hash: str = "SHA-256") -> key: u8[32]

    Args:
        secret: Input keying material, 1-D UINT8.
        info: Context string, stored as a PFunction attribute.
        hash: Hash name, stored as a PFunction attribute. "SHA-256" and "SM3"
            give a 32-byte key. Any other value gets a dynamic (-1) length.

    Returns:
        (pfunc, inputs, treedef): the ``crypto.hkdf`` PFunction, the captured
        input ``[secret]``, and the treedef of the derived key.

    Raises:
        TypeError: If ``secret`` is not a 1-D UINT8 value.
    """
    if secret.dtype != UINT8:
        raise TypeError("hkdf expects UINT8 secret")
    if len(secret.shape) != 1:
        raise TypeError("hkdf expects 1-D secret")

    key_len = 32 if hash in ("SHA-256", "SM3") else -1
    outs_info = (TensorType(UINT8, (key_len,)),)

    pfunc = PFunction(
        fn_type="crypto.hkdf",
        ins_info=(TensorType.from_obj(secret),),
        outs_info=outs_info,
        hash=hash,
        info=info,
    )
    _, treedef = tree_flatten(outs_info[0])
    return pfunc, [secret], treedef
|
mplang/v1/ops/fhe.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from mplang.v1.core import UINT8, TensorType
|
|
16
|
+
from mplang.v1.ops.base import stateless_mod
|
|
17
|
+
|
|
18
|
+
# Op module handle for the "fhe" namespace; the ops below are wrapped by its
# simple_op() decorator.
_fhe_MOD = stateless_mod("fhe")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@_fhe_MOD.simple_op()
def keygen(
    *,
    scheme: str = "CKKS",
    poly_modulus_degree: int = 8192,
    coeff_mod_bit_sizes: tuple[int, ...] | None = None,
    global_scale: int | None = None,
    plain_modulus: int | None = None,
) -> tuple[TensorType, TensorType, TensorType]:
    """Declare FHE key generation for the Vector backend.

    Produces the types of ``(private_context, public_context, evaluation_context)``.
    Apart from the ``scheme``/``plain_modulus`` checks, the parameters are not
    inspected by this type-level function.

    Args:
        scheme: FHE scheme, either "CKKS" (approximate) or "BFV" (exact integer).
        poly_modulus_degree: Polynomial modulus degree (default: 8192).
        coeff_mod_bit_sizes: Coefficient modulus bit sizes for CKKS (optional).
        global_scale: Global scale for CKKS (optional).
        plain_modulus: Plain modulus for BFV (optional). Must be None for CKKS.

    Returns:
        Three identical sentinel types UINT8[(-1, 0)], marking the contexts as
        non-structural, backend-only handles.

    Raises:
        ValueError: If ``scheme`` is not "CKKS" or "BFV", or if
            ``plain_modulus`` is given with ``scheme="CKKS"``.

    Note: the Vector backend only supports scalars and 1D data.
    NOTE(review): earlier docs pointed multi-dimensional users to
    ``mplang.ops.fhe``. Ops now live under ``mplang.v1.ops``, so that pointer
    may be stale. Confirm the intended alternative.
    """
    if scheme not in ("CKKS", "BFV"):
        raise ValueError("Unsupported scheme. Choose either 'CKKS' or 'BFV'.")
    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would silently accept an ignored plain_modulus.
    if scheme == "CKKS" and plain_modulus is not None:
        raise ValueError("plain_modulus is not used in CKKS scheme.")
    context_spec = TensorType(UINT8, (-1, 0))
    return context_spec, context_spec, context_spec
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@_fhe_MOD.simple_op()
def encrypt(plaintext: TensorType, context: TensorType) -> TensorType:
    """Declare FHE (Vector backend) encryption, preserving the semantic type.

    Args:
        plaintext: Data to encrypt; only scalars (shape ``()``) and 1D
            vectors (shape ``(n,)``) are accepted.
        context: FHE context (private or public). It does not affect the
            result type.

    Returns:
        The type of ``plaintext`` unchanged; the ciphertext carries the same
        semantic type.

    Raises:
        ValueError: If ``plaintext`` has more than one dimension.
    """
    del context  # only the plaintext type determines the result
    if len(plaintext.shape) <= 1:
        return plaintext
    raise ValueError(
        f"FHE Vector backend only supports 1D data. Got shape {plaintext.shape}. "
        "Use mplang.ops.fhe for multi-dimensional tensors."
    )
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@_fhe_MOD.simple_op()
def decrypt(ciphertext: TensorType, context: TensorType) -> TensorType:
    """Declare FHE (Vector backend) decryption, preserving the semantic type.

    Args:
        ciphertext: Encrypted data (scalar or 1D vector). Its rank is not
            re-checked here.
        context: FHE context; must be the private context holding the secret
            key. It does not affect the result type.

    Returns:
        The type of ``ciphertext`` unchanged.

    Note: data encrypted with a public context is decrypted with the
    corresponding private context.
    """
    del context  # only the ciphertext type determines the result
    return ciphertext
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@_fhe_MOD.simple_op()
def add(operand1: TensorType, operand2: TensorType) -> TensorType:
    """Add two FHE operands (ciphertext + ciphertext or ciphertext + plaintext).

    Args:
        operand1: First operand (ciphertext or plaintext, scalar or 1D vector).
        operand2: Second operand (ciphertext or plaintext, scalar or 1D vector).

    Returns:
        The type of ``operand1``, which equals ``operand2``'s after validation.

    Raises:
        ValueError: If the operands' dtypes or shapes differ.

    Note: at least one operand must be ciphertext (this is not checked at the
    type level here). There is no broadcasting in the Vector backend, so the
    shapes must match exactly.
    """
    # Raise explicitly rather than `assert`, so callers get the documented
    # ValueError and the check survives `python -O`.
    if operand1.dtype != operand2.dtype:
        raise ValueError(
            f"Operand dtypes must match, got {operand1.dtype} and {operand2.dtype}."
        )
    if operand1.shape != operand2.shape:
        raise ValueError(
            f"Operand shapes must match, got {operand1.shape} and {operand2.shape}."
        )
    return operand1
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@_fhe_MOD.simple_op()
def sub(operand1: TensorType, operand2: TensorType) -> TensorType:
    """Subtract two FHE operands (ciphertext - ciphertext or ciphertext - plaintext).

    Args:
        operand1: First operand (ciphertext or plaintext, scalar or 1D vector).
        operand2: Second operand (ciphertext or plaintext, scalar or 1D vector).

    Returns:
        The type of ``operand1``, which equals ``operand2``'s after validation.

    Raises:
        ValueError: If the operands' dtypes or shapes differ.

    Note: at least one operand must be ciphertext (this is not checked at the
    type level here). There is no broadcasting in the Vector backend, so the
    shapes must match exactly.
    """
    # Raise explicitly rather than `assert`, so callers get the documented
    # ValueError and the check survives `python -O`.
    if operand1.dtype != operand2.dtype:
        raise ValueError(
            f"Operand dtypes must match, got {operand1.dtype} and {operand2.dtype}."
        )
    if operand1.shape != operand2.shape:
        raise ValueError(
            f"Operand shapes must match, got {operand1.shape} and {operand2.shape}."
        )
    return operand1
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
@_fhe_MOD.simple_op()
def mul(operand1: TensorType, operand2: TensorType) -> TensorType:
    """Multiply two FHE operands (ciphertext * ciphertext or ciphertext * plaintext).

    Args:
        operand1: First operand (ciphertext or plaintext, scalar or 1D vector).
        operand2: Second operand (ciphertext or plaintext, scalar or 1D vector).

    Returns:
        The type of ``operand1``, which equals ``operand2``'s after validation.

    Raises:
        ValueError: If the operands' dtypes or shapes differ.

    Note: at least one operand must be ciphertext, and for the BFV scheme
    plaintext operands must be integers. Neither is checked at the type level
    here. There is no broadcasting in the Vector backend, so the shapes must
    match exactly.
    """
    # Raise explicitly rather than `assert`, so callers get the documented
    # ValueError and the check survives `python -O`.
    if operand1.dtype != operand2.dtype:
        raise ValueError(
            f"Operand dtypes must match, got {operand1.dtype} and {operand2.dtype}."
        )
    if operand1.shape != operand2.shape:
        raise ValueError(
            f"Operand shapes must match, got {operand1.shape} and {operand2.shape}."
        )
    return operand1
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
@_fhe_MOD.simple_op()
def dot(operand1: TensorType, operand2: TensorType) -> TensorType:
    """Infer the result type of an FHE dot product.

    Supports ciphertext · ciphertext and ciphertext · plaintext.

    Args:
        operand1: Left vector (ciphertext or plaintext); must have rank 1.
        operand2: Right vector (ciphertext or plaintext); must have rank 1.

    Returns:
        A rank-0 (scalar, shape ``()``) type carrying operand1's dtype.

    Raises:
        ValueError: If either operand is not rank 1, or if the two vectors
            differ in length.

    Note: Scalars are rejected here; use mul() for scalar multiplication.
    The result is always a scalar.
    """
    # Rank checks run in the same order as before: operand1 first, then operand2.
    for label, operand in (("operand1", operand1), ("operand2", operand2)):
        if len(operand.shape) != 1:
            raise ValueError(
                f"Dot product requires 1D vectors, got shape {operand.shape} for {label}"
            )

    len1, len2 = operand1.shape[0], operand2.shape[0]
    if len1 != len2:
        raise ValueError(f"Dot product dimension mismatch: {len1} vs {len2}")

    # Inner product of two vectors collapses to a single value.
    return TensorType(operand1.dtype, ())
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
@_fhe_MOD.simple_op()
def polyval(ciphertext: TensorType, coeffs: TensorType) -> TensorType:
    """Evaluate polynomial on encrypted data with plaintext coefficients.

    Args:
        ciphertext: Encrypted data (scalar or 1D vector)
        coeffs: Plaintext polynomial coefficients as 1D array [c0, c1, c2, ...]
            representing c0 + c1*x + c2*x^2 + ...

    Returns:
        Result of polynomial evaluation with same shape and dtype as ciphertext

    Raises:
        ValueError: If coefficients array is not 1D or has fewer than 2 elements

    Note: Polynomial must have degree >= 1 (at least 2 coefficients required).
    Constant polynomials (degree 0, single coefficient) are NOT supported due to
    TenSEAL limitation. For constant values, use: ct * 0 + constant instead.
    For BFV scheme, coefficients must be integers.

    Common use case - Sigmoid approximation:
        sigmoid_coeffs = [0.5, 0.15012, 0.0, -0.0018027]
        result = polyval(ciphertext, sigmoid_coeffs)
    """
    if len(coeffs.shape) != 1:
        raise ValueError(
            f"Polynomial coefficients must be 1D array, got shape {coeffs.shape}"
        )

    num_coeffs = coeffs.shape[0]
    # Enforce the documented degree >= 1 requirement at type-inference time
    # rather than letting it surface later from the backend.
    # NOTE(review): the `0 <=` guard skips the check for a negative extent, in
    # case one is ever used to mark an unknown/dynamic size. TODO confirm
    # whether TensorType allows that.
    if 0 <= num_coeffs < 2:
        raise ValueError(
            "Polynomial must have at least 2 coefficients (degree >= 1), "
            f"got {num_coeffs}"
        )

    # Polynomial evaluation is element-wise, so the result type mirrors the input.
    return ciphertext
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
@_fhe_MOD.simple_op()
def negate(ciphertext: TensorType) -> TensorType:
    """Infer the result type of encrypted unary minus.

    Args:
        ciphertext: Encrypted value (scalar or 1D vector).

    Returns:
        The negated ciphertext type, which has the same shape and dtype as
        the input.

    Note: Semantically the same as multiplying by -1.
    """
    # Negation neither reshapes nor recasts, so the input type is the output type.
    negated_type = ciphertext
    return negated_type
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
@_fhe_MOD.simple_op()
def square(ciphertext: TensorType) -> TensorType:
    """Infer the result type of element-wise squaring of encrypted data.

    Args:
        ciphertext: Encrypted value (scalar or 1D vector).

    Returns:
        The squared ciphertext type, which has the same shape and dtype as
        the input.

    Note: In some FHE schemes this is cheaper than mul(ciphertext, ciphertext).
    """
    # Element-wise squaring is shape- and dtype-preserving.
    squared_type = ciphertext
    return squared_type
|