mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,723 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""PHE (Partially Homomorphic Encryption) dialect for the EDSL.
|
|
16
|
+
|
|
17
|
+
Design principles:
|
|
18
|
+
- Separate encoding from encryption for semantic clarity
|
|
19
|
+
- Element-level primitives operate on encoded integers
|
|
20
|
+
- Reuse `tensor.elementwise` to lift primitives across tensors
|
|
21
|
+
- Provide ergonomic wrappers for common workflows
|
|
22
|
+
|
|
23
|
+
Architecture:
|
|
24
|
+
Source Type (f64, i32, etc.)
|
|
25
|
+
↓ encode(encoder)
|
|
26
|
+
Encoded Integer (i64)
|
|
27
|
+
↓ encrypt(pk)
|
|
28
|
+
Ciphertext (CiphertextType)
|
|
29
|
+
↓ homomorphic operations
|
|
30
|
+
Ciphertext (CiphertextType)
|
|
31
|
+
↓ decrypt(sk)
|
|
32
|
+
Encoded Integer (i64)
|
|
33
|
+
↓ decode(encoder)
|
|
34
|
+
Source Type (f64, i32, etc.)
|
|
35
|
+
|
|
36
|
+
Example:
|
|
37
|
+
```python
|
|
38
|
+
from mplang.v2.dialects import tensor, phe
|
|
39
|
+
import mplang.v2.edsl.typing as elt
|
|
40
|
+
import numpy as np
|
|
41
|
+
|
|
42
|
+
# 1. Generate keys (cryptographic only)
|
|
43
|
+
pk, sk = phe.keygen()
|
|
44
|
+
|
|
45
|
+
# 2. Create encoder (encoding parameters)
|
|
46
|
+
encoder = phe.create_encoder(dtype=elt.f64, fxp_bits=16)
|
|
47
|
+
|
|
48
|
+
# 3. Encode data
|
|
49
|
+
x = tensor.constant(np.array([1.0, 2.0, 3.0]))
|
|
50
|
+
y = tensor.constant(np.array([4.0, 5.0, 6.0]))
|
|
51
|
+
x_enc = phe.encode(x, encoder) # f64 → i64
|
|
52
|
+
y_enc = phe.encode(y, encoder) # f64 → i64
|
|
53
|
+
|
|
54
|
+
# 4. Encrypt
|
|
55
|
+
ct_x = phe.encrypt(x_enc, pk) # i64 → CiphertextType
|
|
56
|
+
ct_y = phe.encrypt(y_enc, pk) # i64 → CiphertextType
|
|
57
|
+
|
|
58
|
+
# 5. Homomorphic operations
|
|
59
|
+
ct_sum = phe.add(ct_x, ct_y) # CiphertextType + CiphertextType
|
|
60
|
+
|
|
61
|
+
# 6. Decrypt and decode
|
|
62
|
+
sum_enc = phe.decrypt(ct_sum, sk) # CiphertextType → i64
|
|
63
|
+
result = phe.decode(sum_enc, encoder) # i64 → f64
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
For convenience, auto wrappers combine encode+encrypt and decrypt+decode:
|
|
67
|
+
```python
|
|
68
|
+
ct = phe.encrypt_auto(x, encoder, pk)
|
|
69
|
+
result = phe.decrypt_auto(ct, encoder, sk)
|
|
70
|
+
```
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
from __future__ import annotations
|
|
74
|
+
|
|
75
|
+
from collections.abc import Callable
|
|
76
|
+
from typing import Any, NamedTuple
|
|
77
|
+
|
|
78
|
+
import mplang.v2.edsl as el
|
|
79
|
+
import mplang.v2.edsl.typing as elt
|
|
80
|
+
from mplang.v2.dialects import tensor
|
|
81
|
+
|
|
82
|
+
# ==============================================================================
|
|
83
|
+
# --- Type Definitions
|
|
84
|
+
# ==============================================================================
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class KeyType(elt.BaseType):
    """Type of a PHE key: remembers the scheme name and public/private role.

    Two key types are equal (and hash equally) exactly when both the scheme
    name and the public/private flag match.
    """

    def __init__(self, scheme: str, is_public: bool):
        # Scheme identifier, e.g. "paillier".
        self.scheme = scheme
        # True for a public key, False for a private (secret) key.
        self.is_public = is_public

    def __str__(self) -> str:
        prefix = "P" if self.is_public else "S"
        return f"{prefix}Key[{self.scheme}]"

    def __eq__(self, other: object) -> bool:
        if isinstance(other, KeyType):
            return (self.scheme, self.is_public) == (other.scheme, other.is_public)
        return False

    def __hash__(self) -> int:
        return hash((self.scheme, self.is_public))
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class PlaintextType(elt.ScalarType):
    """Scalar type of a value that has been PHE-encoded but not encrypted.

    Giving encoded integers their own type, instead of reusing a plain
    integer dtype, lets this dialect's type checks reject raw integers where
    an encoded value is required (for example as the input to encryption or
    as the plaintext operand of ciphertext arithmetic).
    """

    def __init__(self, bitwidth: int = 64):
        # Width in bits of the underlying integer representation; within this
        # class it drives display, equality and hashing.
        self.bitwidth = bitwidth

    def __str__(self) -> str:
        width = self.bitwidth
        return f"PT[i{width}]"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, PlaintextType) and self.bitwidth == other.bitwidth

    def __hash__(self) -> int:
        return hash(("PlaintextType", self.bitwidth))
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class CiphertextType(elt.ScalarType, elt.EncryptedTrait):
    """Scalar type of a single value encrypted under a PHE scheme.

    Because it subclasses ScalarType it may be used as the element type of a
    tensor; tensors of ciphertexts are how PHE works on whole arrays.
    Equality and hashing depend only on the scheme name.
    """

    def __init__(self, scheme: str):
        self._scheme = scheme

    @property
    def scheme(self) -> str:
        """Name of the PHE scheme (read-only)."""
        return self._scheme

    def __str__(self) -> str:
        return f"CT[{self._scheme}]"

    def __eq__(self, other: object) -> bool:
        if isinstance(other, CiphertextType):
            return self._scheme == other._scheme
        return False

    def __hash__(self) -> int:
        return hash(("CiphertextType", self._scheme))
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
# Opaque PHE types. EncoderType is one shared CustomType instance: every
# encoder object produced by phe.create_encoder has this type, and the
# abstract-eval rules below compare against it.
EncoderType: elt.CustomType = elt.CustomType(
    "Encoder"
)
|
|
157
|
+
|
|
158
|
+
# ==============================================================================
|
|
159
|
+
# --- Key Management Operations
|
|
160
|
+
# ==============================================================================
|
|
161
|
+
|
|
162
|
+
# Primitive yielding a (public_key, private_key) pair of objects.
keygen_p = el.Primitive[tuple[el.Object, el.Object]](
    "phe.keygen"
)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
@keygen_p.def_abstract_eval
def _keygen_ae(
    *,
    scheme: str = "paillier",
    key_size: int = 2048,
) -> tuple[KeyType, KeyType]:
    """Abstract evaluation for ``phe.keygen``.

    Only ``scheme`` is reflected in the resulting types; ``key_size`` is
    accepted as an attribute but does not influence them.

    Args:
        scheme: PHE scheme name (default: "paillier").
        key_size: Key size in bits (default: 2048).

    Returns:
        ``(public_key_type, private_key_type)``, both tagged with ``scheme``.
    """
    public_t = KeyType(scheme, is_public=True)
    private_t = KeyType(scheme, is_public=False)
    return public_t, private_t
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
# ==============================================================================
|
|
184
|
+
# --- Encoder Operations
|
|
185
|
+
# ==============================================================================
|
|
186
|
+
|
|
187
|
+
# Encoder primitives: build an encoder, then use it to move values into and
# out of the encoded-integer (PlaintextType) domain.
create_encoder_p = el.Primitive[el.Object]("phe.create_encoder")
encode_p = el.Primitive[el.Object]("phe.encode")  # raw scalar -> PlaintextType
decode_p = el.Primitive[el.Object]("phe.decode")  # PlaintextType -> scalar
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
@create_encoder_p.def_abstract_eval
def _create_encoder_ae(
    *,
    dtype: elt.ScalarType,
    fxp_bits: int = 16,
    max_value: int | None = None,
) -> elt.CustomType:
    """Create PHE encoder for type conversion and fixed-point encoding.

    Args:
        dtype: Source data type (f32, f64, i32, i64, etc.)
        fxp_bits: Fixed-point fractional bits for float types (default: 16).
            Must be non-negative.
        max_value: Optional maximum value for range checking. If given, it
            must be positive.

    Returns:
        EncoderType (the shared encoder type; dtype is not recorded in it)

    Raises:
        TypeError: If dtype is not a ScalarType.
        ValueError: If fxp_bits is negative or max_value is not positive.
    """
    if not isinstance(dtype, elt.ScalarType):
        raise TypeError(f"dtype must be ScalarType, got {type(dtype).__name__}")
    # A negative number of fractional bits has no meaning for fixed-point
    # encoding; reject it at trace time instead of failing in a backend.
    if fxp_bits < 0:
        raise ValueError(f"fxp_bits must be non-negative, got {fxp_bits}")
    # max_value bounds magnitudes, so a zero/negative bound cannot be met.
    if max_value is not None and max_value <= 0:
        raise ValueError(f"max_value must be positive, got {max_value}")
    return EncoderType
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
@encode_p.def_abstract_eval
def _encode_ae(value: elt.ScalarType, encoder: elt.CustomType) -> PlaintextType:
    """Encode scalar value to fixed-point integer representation.

    Args:
        value: Source value type (f32, f64, i32, etc.). Must be a raw scalar;
            already-encoded (PlaintextType) or encrypted values are rejected.
        encoder: PHE encoder type (must be EncoderType)

    Returns:
        Encoded integer type (64-bit PlaintextType)

    Raises:
        TypeError: If encoder is not EncoderType, if value is not a
            ScalarType, or if value is already encoded or encrypted.
    """
    if encoder != EncoderType:
        raise TypeError(f"Expected Encoder, got {encoder}")
    if not isinstance(value, elt.ScalarType):
        raise TypeError(f"Can only encode ScalarType, got {value}")
    # PlaintextType and CiphertextType both subclass ScalarType, so the check
    # above lets them through. Re-encoding an encoded value would rescale an
    # already fixed-point integer, and "encoding" a ciphertext is meaningless.
    if isinstance(value, (PlaintextType, elt.EncryptedTrait)):
        raise TypeError(
            f"Can only encode raw (unencoded, unencrypted) scalars, got {value}"
        )
    # Return sufficient integer type for encoded values
    return PlaintextType(bitwidth=64)
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
@decode_p.def_abstract_eval
def _decode_ae(encoded: PlaintextType, encoder: elt.CustomType) -> elt.ScalarType:
    """Decode fixed-point integer back to a scalar type.

    Args:
        encoded: Encoded integer type (PlaintextType)
        encoder: PHE encoder type (must be EncoderType)

    Returns:
        Currently always ``elt.f64``. EncoderType is a single shared
        CustomType that carries no dtype, so the dtype the encoder was
        created with is not visible here. For example, decoding with an i32
        encoder still yields an f64 result type.

    Raises:
        TypeError: If encoder is not EncoderType or encoded is not PlaintextType
    """
    if encoder != EncoderType:
        raise TypeError(f"Expected Encoder, got {encoder}")
    if not isinstance(encoded, PlaintextType):
        raise TypeError(f"Can only decode PlaintextType, got {encoded}")
    # TODO: return the encoder's configured dtype once the encoder type (or
    # attr introspection) exposes it; until then every decode is typed f64.
    return elt.f64
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
# ==============================================================================
|
|
260
|
+
# --- Encryption/Decryption Operations (Integer only)
|
|
261
|
+
# ==============================================================================
|
|
262
|
+
|
|
263
|
+
# Encryption primitives operate on encoded integers only:
#   encrypt: PlaintextType -> CiphertextType
#   decrypt: CiphertextType -> PlaintextType
encrypt_p = el.Primitive[el.Object]("phe.encrypt")
decrypt_p = el.Primitive[el.Object]("phe.decrypt")
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
@encrypt_p.def_abstract_eval
def _encrypt_ae(encoded: PlaintextType, pk: KeyType) -> CiphertextType:
    """Abstract evaluation for ``phe.encrypt``.

    The key is validated first, then the operand.

    Args:
        encoded: Encoded integer type (as produced by phe.encode).
        pk: Public key type.

    Returns:
        CiphertextType tagged with the public key's scheme.

    Raises:
        TypeError: If pk is not a public KeyType, or encoded is not a
            PlaintextType.
    """
    key_ok = isinstance(pk, KeyType) and pk.is_public
    if not key_ok:
        raise TypeError(f"Expected PublicKey, got {pk}")
    if isinstance(encoded, PlaintextType):
        return CiphertextType(pk.scheme)
    raise TypeError(f"Can only encrypt PlaintextType, got {encoded}")
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
@decrypt_p.def_abstract_eval
def _decrypt_ae(ct: CiphertextType, sk: KeyType) -> PlaintextType:
    """Decrypt ciphertext to encoded integer using PHE private key.

    Args:
        ct: Encrypted integer type
        sk: PHE private key type; must use the same scheme as ``ct``

    Returns:
        Decrypted encoded integer type (64-bit PlaintextType)

    Raises:
        TypeError: If sk is not a private KeyType, ct is not CiphertextType,
            or the ciphertext and key schemes differ.
    """
    if not isinstance(sk, KeyType) or sk.is_public:
        raise TypeError(f"Expected PrivateKey, got {sk}")
    if not isinstance(ct, CiphertextType):
        raise TypeError(f"Expected CiphertextType, got {ct}")
    # A ciphertext produced under one scheme cannot be decrypted with a key
    # of another; catch this at trace time, as _add_cc_ae does for operands.
    if ct.scheme != sk.scheme:
        raise TypeError(f"Scheme mismatch: cannot decrypt {ct} with {sk}")
    # We assume it decrypts to i64 (standard encoded integer)
    return PlaintextType(bitwidth=64)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
# ==============================================================================
|
|
311
|
+
# --- Element-level Homomorphic Operations
|
|
312
|
+
# ==============================================================================
|
|
313
|
+
|
|
314
|
+
# Element-level homomorphic arithmetic. Suffix convention:
#   cc = ciphertext (op) ciphertext, cp = ciphertext (op) plaintext.
add_cc_p = el.Primitive[el.Object]("phe.add_cc")
add_cp_p = el.Primitive[el.Object]("phe.add_cp")
mul_cp_p = el.Primitive[el.Object]("phe.mul_cp")
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
@add_cc_p.def_abstract_eval
def _add_cc_ae(operand1: CiphertextType, operand2: CiphertextType) -> CiphertextType:
    """Abstract eval for ciphertext + ciphertext.

    Both operands must be ciphertexts of the same scheme (CiphertextType
    equality compares only the scheme). The result is the first operand's
    type.
    """
    both_ciphertexts = isinstance(operand1, CiphertextType) and isinstance(
        operand2, CiphertextType
    )
    if not both_ciphertexts:
        raise TypeError(f"Expected CiphertextType operands, got {operand1}, {operand2}")
    if operand1 != operand2:
        raise TypeError(f"Scheme mismatch: {operand1} vs {operand2}")
    return operand1
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
@add_cp_p.def_abstract_eval
def _add_cp_ae(ciphertext: CiphertextType, plaintext: PlaintextType) -> CiphertextType:
    """Abstract eval for ciphertext + encoded plaintext.

    The result has exactly the ciphertext operand's type.

    Raises:
        TypeError: If ``ciphertext`` is not a CiphertextType, or
            ``plaintext`` is not an encoded PlaintextType.
    """
    if isinstance(ciphertext, CiphertextType):
        if isinstance(plaintext, PlaintextType):
            return ciphertext
        raise TypeError(
            f"Plaintext operand must be PlaintextType (encoded), got {plaintext}"
        )
    raise TypeError(f"Expected CiphertextType ciphertext, got {ciphertext}")
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
@mul_cp_p.def_abstract_eval
def _mul_cp_ae(ciphertext: CiphertextType, plaintext: PlaintextType) -> CiphertextType:
    """Abstract eval for ciphertext * encoded plaintext (scalar multiply).

    Args:
        ciphertext: Encrypted operand type.
        plaintext: Encoded integer operand type.

    Returns:
        The ciphertext operand's type, unchanged.

    Raises:
        TypeError: If either operand has the wrong type (ciphertext is
            checked first).
    """
    ciphertext_ok = isinstance(ciphertext, CiphertextType)
    plaintext_ok = isinstance(plaintext, PlaintextType)
    if not ciphertext_ok:
        raise TypeError(f"Expected CiphertextType ciphertext, got {ciphertext}")
    if not plaintext_ok:
        raise TypeError(
            f"Plaintext operand must be PlaintextType (encoded), got {plaintext}"
        )
    return ciphertext
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
# ==============================================================================
|
|
364
|
+
# --- User-facing API
|
|
365
|
+
# ==============================================================================
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def keygen(
    scheme: str = "paillier",
    key_size: int = 2048,
) -> tuple[el.Object, el.Object]:
    """Generate a PHE key pair (cryptographic parameters only).

    Encoding parameters (fxp_bits, max_value) are configured separately via
    create_encoder().

    Args:
        scheme: PHE scheme name (default: "paillier"). Which schemes are
            actually available depends on the backend implementation.
        key_size: Key size in bits (default: 2048). Larger keys trade speed
            for security.

    Returns:
        Tuple of (public_key, private_key).

    Example:
        >>> pk, sk = phe.keygen()
        >>> pk, sk = phe.keygen(key_size=4096)  # higher security
    """
    attrs = {"scheme": scheme, "key_size": key_size}
    return keygen_p.bind(**attrs)
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def create_encoder(
    dtype: elt.ScalarType,
    fxp_bits: int = 16,
    max_value: int | None = None,
) -> el.Object:
    """Create a PHE encoder for value encoding/decoding.

    Encoders are independent of keys. They describe how source values are
    converted to and from the fixed-point integers that PHE operates on.

    Args:
        dtype: Source data type (e.g. elt.f64, elt.i32).
        fxp_bits: Fixed-point fractional bits for float types (default: 16).
            More bits give finer precision (16 bits ≈ 1/65536) at the cost
            of representable range.
        max_value: Optional bound forwarded to the encoder for overflow
            checking; the attribute is omitted entirely when None.

    Returns:
        PHE encoder object.

    Example:
        >>> import mplang.v2.edsl.typing as elt
        >>> encoder_f64 = phe.create_encoder(dtype=elt.f64, fxp_bits=16)
        >>> encoder_hp = phe.create_encoder(dtype=elt.f64, fxp_bits=32)
        >>> encoder_i32 = phe.create_encoder(dtype=elt.i32)
    """
    attrs: dict[str, Any] = dict(dtype=dtype, fxp_bits=fxp_bits)
    # Only pass max_value through when the caller actually set it.
    if max_value is not None:
        attrs["max_value"] = max_value
    return create_encoder_p.bind(**attrs)
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
def _has_tensor_args(*objs: el.Object) -> bool:
    """Return True as soon as one of ``objs`` has a TensorType, else False."""
    for candidate in objs:
        if isinstance(candidate.type, elt.TensorType):
            return True
    return False
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
class OperandInfo(NamedTuple):
    """How an operand looks to PHE dispatch: layout, encryption, element type."""

    # True for a tensor operand, False for a bare scalar.
    is_tensor: bool
    # True when the (element) type is a CiphertextType.
    is_encrypted: bool
    # The plain scalar (element) type; None for encrypted operands.
    scalar_type: elt.BaseType | None
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
def _inspect_operand(obj: el.Object) -> OperandInfo:
    """Classify an operand by layout (scalar/tensor) and encryption state.

    Tensors are classified by their element type; scalars by their own type.

    Raises:
        TypeError: If the (element) type is neither a CiphertextType nor a
            ScalarType.
    """
    obj_type = obj.type
    is_tensor = isinstance(obj_type, elt.TensorType)
    base = obj_type.element_type if is_tensor else obj_type

    if isinstance(base, CiphertextType):
        return OperandInfo(is_tensor, True, None)
    if isinstance(base, elt.ScalarType):
        return OperandInfo(is_tensor, False, base)

    if is_tensor:
        raise TypeError(
            "PHE operations support Tensor[ScalarType] or Tensor[CiphertextType], "
            f"got Tensor[{base}]"
        )
    raise TypeError(f"PHE operations expect Scalar or Tensor operands, got {obj_type}")
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
# Signature of a scalar binary primitive binder, e.g. ``add_cp_p.bind``.
BinaryFn = Callable[
    [el.Object, el.Object],
    el.Object,
]
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
def _apply_binary(fn: BinaryFn, lhs: el.Object, rhs: el.Object) -> el.Object:
    """Call ``fn(lhs, rhs)``, lifting through tensor.elementwise for tensors.

    If either operand is a tensor, ``fn`` is mapped element-wise; otherwise
    it is applied directly to the two scalars.
    """
    needs_lift = _has_tensor_args(lhs, rhs)
    return tensor.elementwise(fn, lhs, rhs) if needs_lift else fn(lhs, rhs)
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def _add_cp(ciphertext: el.Object, plaintext: el.Object) -> el.Object:
    """Homomorphic ciphertext + plaintext; the ciphertext must be first."""
    binder = add_cp_p.bind
    return _apply_binary(binder, ciphertext, plaintext)
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
def _mul_cp(ciphertext: el.Object, plaintext: el.Object) -> el.Object:
    """Homomorphic ciphertext * plaintext; the ciphertext must be first."""
    binder = mul_cp_p.bind
    return _apply_binary(binder, ciphertext, plaintext)
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
def encode(value: el.Object, encoder: el.Object) -> el.Object:
    """Encode scalar value to fixed-point integer representation.

    Scalars are encoded directly; tensors are encoded element by element
    with the same encoder.

    Args:
        value: Source value (scalar or tensor)
        encoder: PHE encoder (from create_encoder)

    Returns:
        Encoded integer with same structure as input

    Example:
        >>> x = tensor.constant(3.14)  # f64
        >>> encoder = phe.create_encoder(dtype=elt.f64, fxp_bits=16)
        >>> x_enc = phe.encode(x, encoder)  # PT[i64], ≈ 3.14 * 2**16 = 205783
    """
    if _has_tensor_args(value):
        return tensor.elementwise(encode_p.bind, value, encoder)
    return encode_p.bind(value, encoder)
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
def decode(encoded: el.Object, encoder: el.Object) -> el.Object:
    """Decode fixed-point integer(s) back to a regular scalar type.

    Scalars are decoded directly; tensors are decoded element by element.
    Note that the abstract evaluation of ``phe.decode`` currently reports
    f64 as the result element type regardless of the encoder's dtype.

    Args:
        encoded: Encoded integer (from encode or decrypt).
        encoder: PHE encoder (the same one used for encoding).

    Returns:
        Decoded value with the same structure as the input.

    Example:
        >>> encoded = phe.encode(x, encoder)
        >>> result = phe.decode(encoded, encoder)  # back to f64
    """
    if not _has_tensor_args(encoded):
        return decode_p.bind(encoded, encoder)
    return tensor.elementwise(decode_p.bind, encoded, encoder)
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
def encrypt(encoded: el.Object, public_key: el.Object) -> el.Object:
    """Encrypt encoded integer(s) with a PHE public key.

    The input must already be encoded via phe.encode(); raw values are
    rejected by the type rules. Tensors are encrypted element by element.

    Args:
        encoded: Encoded integer (from phe.encode).
        public_key: PHE public key.

    Returns:
        Ciphertext(s) with the same structure as the input.

    Example:
        >>> x_enc = phe.encode(x, encoder)
        >>> ct = phe.encrypt(x_enc, pk)  # PlaintextType -> CiphertextType
    """
    if not _has_tensor_args(encoded):
        return encrypt_p.bind(encoded, public_key)
    return tensor.elementwise(encrypt_p.bind, encoded, public_key)
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
def decrypt(ciphertext: el.Object, private_key: el.Object) -> el.Object:
    """Decrypt ciphertext(s) to encoded integer(s) with a PHE private key.

    The output is still encoded; call phe.decode() to recover the source
    type. Tensors are decrypted element by element.

    Args:
        ciphertext: Encrypted value.
        private_key: PHE private key.

    Returns:
        Encoded integer(s) with the same structure as the input.

    Example:
        >>> ct_sum = phe.add(ct1, ct2)
        >>> sum_enc = phe.decrypt(ct_sum, sk)  # CiphertextType -> PlaintextType
        >>> result = phe.decode(sum_enc, encoder)
    """
    if not _has_tensor_args(ciphertext):
        return decrypt_p.bind(ciphertext, private_key)
    return tensor.elementwise(decrypt_p.bind, ciphertext, private_key)
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
def encrypt_auto(
|
|
573
|
+
value: el.Object, encoder: el.Object, public_key: el.Object
|
|
574
|
+
) -> el.Object:
|
|
575
|
+
"""Convenience: encode + encrypt in one step.
|
|
576
|
+
|
|
577
|
+
Args:
|
|
578
|
+
value: Source value (any scalar type)
|
|
579
|
+
encoder: PHE encoder
|
|
580
|
+
public_key: PHE public key
|
|
581
|
+
|
|
582
|
+
Returns:
|
|
583
|
+
Encrypted value
|
|
584
|
+
|
|
585
|
+
Example:
|
|
586
|
+
>>> ct = phe.encrypt_auto(x, encoder, pk)
|
|
587
|
+
>>> # Equivalent to:
|
|
588
|
+
>>> # ct = phe.encrypt(phe.encode(x, encoder), pk)
|
|
589
|
+
"""
|
|
590
|
+
encoded = encode(value, encoder)
|
|
591
|
+
return encrypt(encoded, public_key)
|
|
592
|
+
|
|
593
|
+
|
|
594
|
+
def decrypt_auto(
|
|
595
|
+
ciphertext: el.Object, encoder: el.Object, private_key: el.Object
|
|
596
|
+
) -> el.Object:
|
|
597
|
+
"""Convenience: decrypt + decode in one step.
|
|
598
|
+
|
|
599
|
+
Args:
|
|
600
|
+
ciphertext: Encrypted value
|
|
601
|
+
encoder: PHE encoder (same as used for encoding)
|
|
602
|
+
private_key: PHE private key
|
|
603
|
+
|
|
604
|
+
Returns:
|
|
605
|
+
Decrypted value in original type
|
|
606
|
+
|
|
607
|
+
Example:
|
|
608
|
+
>>> result = phe.decrypt_auto(ct, encoder, sk)
|
|
609
|
+
>>> # Equivalent to:
|
|
610
|
+
>>> # result = phe.decode(phe.decrypt(ct, sk), encoder)
|
|
611
|
+
"""
|
|
612
|
+
decoded = decrypt(ciphertext, private_key)
|
|
613
|
+
return decode(decoded, encoder)
|
|
614
|
+
|
|
615
|
+
|
|
616
|
+
def add(lhs: el.Object, rhs: el.Object) -> el.Object:
|
|
617
|
+
"""Homomorphic addition.
|
|
618
|
+
|
|
619
|
+
Supports:
|
|
620
|
+
Ciphertext + Ciphertext → Ciphertext (ciphertext + ciphertext)
|
|
621
|
+
Ciphertext + T → Ciphertext (ciphertext + plaintext)
|
|
622
|
+
T + Ciphertext → Ciphertext (plaintext + ciphertext)
|
|
623
|
+
|
|
624
|
+
Args:
|
|
625
|
+
lhs: Left operand (encrypted or plaintext)
|
|
626
|
+
rhs: Right operand (encrypted or plaintext)
|
|
627
|
+
|
|
628
|
+
Returns:
|
|
629
|
+
Encrypted sum
|
|
630
|
+
|
|
631
|
+
Raises:
|
|
632
|
+
TypeError: If no operand is encrypted or types mismatch
|
|
633
|
+
"""
|
|
634
|
+
lhs_info = _inspect_operand(lhs)
|
|
635
|
+
rhs_info = _inspect_operand(rhs)
|
|
636
|
+
|
|
637
|
+
if not (lhs_info.is_encrypted or rhs_info.is_encrypted):
|
|
638
|
+
raise TypeError("phe.add requires at least one ciphertext operand")
|
|
639
|
+
|
|
640
|
+
# CT + CT
|
|
641
|
+
if lhs_info.is_encrypted and rhs_info.is_encrypted:
|
|
642
|
+
return _apply_binary(add_cc_p.bind, lhs, rhs)
|
|
643
|
+
|
|
644
|
+
# CT + PT or PT + CT
|
|
645
|
+
if lhs_info.is_encrypted:
|
|
646
|
+
return _add_cp(lhs, rhs)
|
|
647
|
+
return _add_cp(rhs, lhs)
|
|
648
|
+
|
|
649
|
+
|
|
650
|
+
def mul_plain(lhs: el.Object, rhs: el.Object) -> el.Object:
|
|
651
|
+
"""Homomorphic multiplication: ciphertext × plaintext (encoded integer).
|
|
652
|
+
|
|
653
|
+
Supports:
|
|
654
|
+
Ciphertext × i64 → Ciphertext (ciphertext × encoded plaintext)
|
|
655
|
+
i64 × Ciphertext → Ciphertext (encoded plaintext × ciphertext)
|
|
656
|
+
|
|
657
|
+
Args:
|
|
658
|
+
lhs: Left operand (one must be encrypted, other must be encoded integer)
|
|
659
|
+
rhs: Right operand
|
|
660
|
+
|
|
661
|
+
Returns:
|
|
662
|
+
Encrypted product
|
|
663
|
+
|
|
664
|
+
Raises:
|
|
665
|
+
TypeError: If both operands are encrypted or both are plaintext
|
|
666
|
+
|
|
667
|
+
Note:
|
|
668
|
+
- Ciphertext × ciphertext is not supported (would require FHE)
|
|
669
|
+
- Plaintext must be encoded integer (use phe.encode first)
|
|
670
|
+
- For float multiplication, may need truncation to maintain precision
|
|
671
|
+
|
|
672
|
+
Example:
|
|
673
|
+
>>> ct = phe.encrypt(phe.encode(x, encoder), pk)
|
|
674
|
+
>>> y_enc = phe.encode(y, encoder)
|
|
675
|
+
>>> ct_prod = phe.mul_plain(ct, y_enc)
|
|
676
|
+
"""
|
|
677
|
+
lhs_info = _inspect_operand(lhs)
|
|
678
|
+
rhs_info = _inspect_operand(rhs)
|
|
679
|
+
|
|
680
|
+
# CT * PT
|
|
681
|
+
if lhs_info.is_encrypted and not rhs_info.is_encrypted:
|
|
682
|
+
return _mul_cp(lhs, rhs)
|
|
683
|
+
# PT * CT
|
|
684
|
+
if rhs_info.is_encrypted and not lhs_info.is_encrypted:
|
|
685
|
+
return _mul_cp(rhs, lhs)
|
|
686
|
+
# CT * CT (not supported)
|
|
687
|
+
if lhs_info.is_encrypted and rhs_info.is_encrypted:
|
|
688
|
+
raise TypeError(
|
|
689
|
+
"phe.mul_plain supports ciphertext * plaintext only, not CT * CT. "
|
|
690
|
+
"Ciphertext * ciphertext requires FHE."
|
|
691
|
+
)
|
|
692
|
+
# PT * PT (invalid)
|
|
693
|
+
raise TypeError("phe.mul_plain requires at least one ciphertext operand")
|
|
694
|
+
|
|
695
|
+
|
|
696
|
+
__all__ = [
|
|
697
|
+
"CiphertextType",
|
|
698
|
+
# Types
|
|
699
|
+
"EncoderType",
|
|
700
|
+
"KeyType",
|
|
701
|
+
"PlaintextType",
|
|
702
|
+
# User API
|
|
703
|
+
"add",
|
|
704
|
+
# Primitives
|
|
705
|
+
"add_cc_p",
|
|
706
|
+
"add_cp_p",
|
|
707
|
+
"create_encoder",
|
|
708
|
+
"create_encoder_p",
|
|
709
|
+
"decode",
|
|
710
|
+
"decode_p",
|
|
711
|
+
"decrypt",
|
|
712
|
+
"decrypt_auto",
|
|
713
|
+
"decrypt_p",
|
|
714
|
+
"encode",
|
|
715
|
+
"encode_p",
|
|
716
|
+
"encrypt",
|
|
717
|
+
"encrypt_auto",
|
|
718
|
+
"encrypt_p",
|
|
719
|
+
"keygen",
|
|
720
|
+
"keygen_p",
|
|
721
|
+
"mul_cp_p",
|
|
722
|
+
"mul_plain",
|
|
723
|
+
]
|