mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,705 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""BFV Runtime Implementation.
|
|
16
|
+
|
|
17
|
+
Implements execution logic for BFV primitives using TenSEAL low-level API (sealapi).
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
import base64
|
|
23
|
+
import os
|
|
24
|
+
import uuid
|
|
25
|
+
from dataclasses import dataclass
|
|
26
|
+
from typing import Any, ClassVar, cast
|
|
27
|
+
|
|
28
|
+
import numpy as np
|
|
29
|
+
import tenseal as ts
|
|
30
|
+
import tenseal.sealapi as sealapi
|
|
31
|
+
|
|
32
|
+
from mplang.v2.backends.tensor_impl import TensorValue
|
|
33
|
+
from mplang.v2.dialects import bfv
|
|
34
|
+
from mplang.v2.edsl import serde
|
|
35
|
+
from mplang.v2.edsl.graph import Operation
|
|
36
|
+
from mplang.v2.runtime.interpreter import Interpreter
|
|
37
|
+
from mplang.v2.runtime.value import Value, WrapValue
|
|
38
|
+
|
|
39
|
+
# =============================================================================
|
|
40
|
+
# Helper for SEAL serialization
|
|
41
|
+
# =============================================================================
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _get_seal_temp_path() -> str:
|
|
45
|
+
"""Get a temp file path for SEAL serialization.
|
|
46
|
+
|
|
47
|
+
Uses /dev/shm on Linux for better performance (RAM-based tmpfs),
|
|
48
|
+
falls back to regular tempfile on other platforms.
|
|
49
|
+
"""
|
|
50
|
+
# Try /dev/shm first (Linux RAM-based tmpfs, ~30% faster)
|
|
51
|
+
shm_dir = "/dev/shm"
|
|
52
|
+
if os.path.isdir(shm_dir) and os.access(shm_dir, os.W_OK):
|
|
53
|
+
return os.path.join(shm_dir, f"seal_{uuid.uuid4().hex}.bin")
|
|
54
|
+
|
|
55
|
+
# Fallback to regular temp directory
|
|
56
|
+
import tempfile
|
|
57
|
+
|
|
58
|
+
return os.path.join(tempfile.gettempdir(), f"seal_{uuid.uuid4().hex}.bin")
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@serde.register_class
class BFVParamContextValue(WrapValue[ts.Context]):
    """TenSEAL BFV context carrying encryption parameters only (no keys).

    Besides the wrapped ``ts.Context`` (exposed as ``ts_ctx``), the instance
    pre-builds the low-level SEAL handles usable without keys:
    ``seal_ctx``/``cpp_ctx`` (the underlying SEAL context), an ``evaluator``
    and a ``batch_encoder``.
    """

    _serde_kind: ClassVar[str] = "bfv_impl.BFVParamContextValue"

    def __init__(self, data: Any):
        super().__init__(data)
        ctx = self._data
        self.ts_ctx = ctx

        # Low-level SEAL handles derived from the TenSEAL context.
        self.seal_ctx = ctx.seal_context()
        self.cpp_ctx = self.seal_ctx.data
        self.evaluator = sealapi.Evaluator(self.cpp_ctx)
        self.batch_encoder = sealapi.BatchEncoder(self.cpp_ctx)

    def _convert(self, data: Any) -> ts.Context:
        """Accept another ``BFVParamContextValue`` or a raw ``ts.Context``."""
        if isinstance(data, BFVParamContextValue):
            data = data.unwrap()
        elif not isinstance(data, ts.Context):
            raise TypeError(f"Expected ts.Context, got {type(data)}")
        return data

    def to_json(self) -> dict[str, Any]:
        """Serialize the parameters only; every key type is explicitly dropped."""
        key_flags = {
            "save_public_key": False,
            "save_secret_key": False,
            "save_galois_keys": False,
            "save_relin_keys": False,
        }
        raw = self.ts_ctx.serialize(**key_flags)
        return {"ctx_bytes": base64.b64encode(raw).decode("ascii")}

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> BFVParamContextValue:
        """Rebuild from the base64 ``ctx_bytes`` produced by :meth:`to_json`."""
        return cls(ts.context_from(base64.b64decode(data["ctx_bytes"])))
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@serde.register_class
class BFVPublicContextValue(WrapValue[ts.Context]):
    """TenSEAL BFV context plus the public-side SEAL objects derived from it.

    Exposes ``ts_ctx``, ``seal_ctx``/``cpp_ctx``, an ``evaluator``, a
    ``batch_encoder``, the raw ``public_key``/``relin_keys``/``galois_keys``
    and an ``encryptor`` bound to the public key. All three keys are read
    eagerly in ``__init__``, so the wrapped context must carry them.
    """

    _serde_kind: ClassVar[str] = "bfv_impl.BFVPublicContextValue"

    def __init__(self, data: Any):
        super().__init__(data)
        ctx = self._data
        self.ts_ctx = ctx

        # Low-level SEAL handles.
        self.seal_ctx = ctx.seal_context()
        self.cpp_ctx = self.seal_ctx.data
        self.evaluator = sealapi.Evaluator(self.cpp_ctx)
        self.batch_encoder = sealapi.BatchEncoder(self.cpp_ctx)

        # Raw SEAL key objects, unwrapped from their TenSEAL holders.
        self.public_key = ctx.public_key().data
        self.relin_keys = ctx.relin_keys().data
        self.galois_keys = ctx.galois_keys().data

        self.encryptor = sealapi.Encryptor(self.cpp_ctx, self.public_key)

    def _convert(self, data: Any) -> ts.Context:
        """Accept another ``BFVPublicContextValue`` or a raw ``ts.Context``."""
        if isinstance(data, BFVPublicContextValue):
            data = data.unwrap()
        elif not isinstance(data, ts.Context):
            raise TypeError(f"Expected ts.Context, got {type(data)}")
        return data

    def to_json(self) -> dict[str, Any]:
        """Serialize the context with the secret key stripped."""
        raw = self.ts_ctx.serialize(save_secret_key=False)
        return {"ctx_bytes": base64.b64encode(raw).decode("ascii")}

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> BFVPublicContextValue:
        """Rebuild from the base64 ``ctx_bytes`` produced by :meth:`to_json`."""
        return cls(ts.context_from(base64.b64decode(data["ctx_bytes"])))
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@serde.register_class
class BFVSecretContextValue(BFVPublicContextValue):
    """BFV context that additionally holds the secret key and a decryptor.

    Raises ``ValueError`` on construction when the wrapped TenSEAL context has
    no secret key. Note that :meth:`to_json` includes the secret key, so the
    serialized form is sensitive.
    """

    _serde_kind: ClassVar[str] = "bfv_impl.BFVSecretContextValue"

    def __init__(self, data: Any):
        # The parent constructor runs WrapValue's conversion (``_convert``)
        # and builds every public-side SEAL object before we add secrets.
        super().__init__(data)
        if not self.ts_ctx.has_secret_key():
            raise ValueError("Context does not have a secret key")
        self.secret_key = self.ts_ctx.secret_key().data
        self.decryptor = sealapi.Decryptor(self.cpp_ctx, self.secret_key)

    def make_public(self) -> BFVPublicContextValue:
        """Return an independent public-only copy of this context.

        The copy comes from a serialize/deserialize round trip that omits the
        secret key, so the returned context does not contain it.
        """
        public_bytes = self.ts_ctx.serialize(save_secret_key=False)
        return BFVPublicContextValue(ts.context_from(public_bytes))

    def to_json(self) -> dict[str, Any]:
        """Serialize the full context, secret key included."""
        raw = self.ts_ctx.serialize(save_secret_key=True)
        return {"ctx_bytes": base64.b64encode(raw).decode("ascii")}

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> BFVSecretContextValue:
        """Rebuild from the base64 ``ctx_bytes`` produced by :meth:`to_json`."""
        return cls(ts.context_from(base64.b64decode(data["ctx_bytes"])))
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def _create_private_temp_file(path: str) -> None:
    """Create ``path`` as a new, empty file readable/writable by the owner only.

    Uses exclusive creation (``"x"`` mode), so it raises ``FileExistsError``
    instead of reusing a path that already exists.
    """
    with open(path, "xb", opener=lambda p, flags: os.open(p, flags, 0o600)):
        pass


@serde.register_class
@dataclass
class BFVValue(Value):
    """Runtime value holding a SEAL Ciphertext or Plaintext.

    Fields:
        data: a ``sealapi.Ciphertext`` when ``is_cipher`` is true, otherwise a
            ``sealapi.Plaintext``.
        ctx: the BFV context the object belongs to; kernels use its
            ``cpp_ctx``, ``evaluator`` and ``batch_encoder`` (and
            ``encryptor`` when they need to encrypt).
        is_cipher: distinguishes ciphertexts from plaintexts.
    """

    _serde_kind: ClassVar[str] = "bfv_impl.BFVValue"

    data: Any  # sealapi.Ciphertext | sealapi.Plaintext
    ctx: BFVPublicContextValue
    is_cipher: bool = True

    def to_json(self) -> dict[str, Any]:
        """Serialize the SEAL object plus a parameters-only copy of its context.

        SEAL's Python API saves to files, so the object goes through a temp
        file (see ``_get_seal_temp_path``). The file is pre-created with
        owner-only permissions so serialized plaintexts, which may be
        decryption results, are not exposed through a shared directory such as
        ``/dev/shm``. NOTE(review): this relies on SEAL's ``save()`` truncating
        the existing file in place rather than replacing it; confirm for the
        TenSEAL version in use.
        """
        fname = _get_seal_temp_path()
        # Created outside the try: if creation fails we must not unlink a path
        # we did not create.
        _create_private_temp_file(fname)
        try:
            self.data.save(fname)
            with open(fname, "rb") as f:
                data_bytes = f.read()
        finally:
            if os.path.exists(fname):
                os.unlink(fname)

        # Only the encryption parameters travel with the value (no keys), to
        # keep the payload small.
        param_ctx = BFVParamContextValue(self.ctx.ts_ctx)
        ctx_json = serde.to_json(param_ctx)

        return {
            "data_bytes": base64.b64encode(data_bytes).decode("ascii"),
            "is_cipher": self.is_cipher,
            "ctx": ctx_json,
        }

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> BFVValue:
        """Inverse of :meth:`to_json`.

        NOTE(review): ``to_json`` stores a parameters-only
        ``BFVParamContextValue``, so the restored ``ctx`` is presumably that
        type. It has ``evaluator``/``batch_encoder`` but no ``encryptor``, so
        kernels that encrypt via ``value.ctx`` would fail on a deserialized
        value. Confirm this is intended.
        """
        ctx = serde.from_json(data["ctx"])
        data_bytes = base64.b64decode(data["data_bytes"])
        is_cipher = data["is_cipher"]

        fname = _get_seal_temp_path()
        # Owner-only file so the (possibly plaintext) payload is not exposed.
        _create_private_temp_file(fname)
        try:
            with open(fname, "wb") as f:
                f.write(data_bytes)

            if is_cipher:
                ct = sealapi.Ciphertext()
                ct.load(ctx.cpp_ctx, fname)
                return cls(data=ct, ctx=ctx, is_cipher=True)
            pt = sealapi.Plaintext()
            pt.load(ctx.cpp_ctx, fname)
            return cls(data=pt, ctx=ctx, is_cipher=False)
        finally:
            if os.path.exists(fname):
                os.unlink(fname)
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
# =============================================================================
|
|
243
|
+
# Keygen Cache (Optimization: avoid regenerating keys for same parameters)
|
|
244
|
+
# =============================================================================
|
|
245
|
+
_KEYGEN_CACHE: dict[
|
|
246
|
+
tuple[int, int], tuple[BFVPublicContextValue, BFVSecretContextValue]
|
|
247
|
+
] = {}
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def clear_keygen_cache() -> None:
|
|
251
|
+
"""Clear the keygen cache."""
|
|
252
|
+
_KEYGEN_CACHE.clear()
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
@bfv.keygen_p.def_impl
def keygen_impl(
    interpreter: Interpreter, op: Operation, *args: Any
) -> tuple[BFVPublicContextValue, BFVSecretContextValue]:
    """Generate (or reuse) a BFV key pair.

    Reads ``poly_modulus_degree`` (default 4096) and ``plain_modulus``
    (default 1032193) from ``op.attrs``; positional ``args`` are ignored.
    Galois and relinearization keys are generated with the secret key.

    Returns:
        ``(public_context, secret_context)``.

    WARNING: results are memoized in the module-level ``_KEYGEN_CACHE``, keyed
    only by the two parameters. Every keygen with the same parameters in this
    process therefore returns the *same* key material, including the secret
    key, until ``clear_keygen_cache()`` is called. NOTE(review): if several
    parties run in one process (e.g. in simulation) they would share keys;
    confirm this is acceptable.
    """
    degree = op.attrs.get("poly_modulus_degree", 4096)
    modulus = op.attrs.get("plain_modulus", 1032193)

    cache_key = (degree, modulus)
    cached = _KEYGEN_CACHE.get(cache_key)
    if cached is not None:
        return cached

    ts_ctx = ts.context(
        ts.SCHEME_TYPE.BFV,
        poly_modulus_degree=degree,
        plain_modulus=modulus,
    )
    ts_ctx.generate_galois_keys()
    ts_ctx.generate_relin_keys()

    secret_ctx = BFVSecretContextValue(ts_ctx)
    pair = (secret_ctx.make_public(), secret_ctx)
    _KEYGEN_CACHE[cache_key] = pair
    return pair
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
@bfv.make_relin_keys_p.def_impl
def make_relin_keys_impl(
    interpreter: Interpreter, op: Operation, sk: BFVSecretContextValue
) -> BFVSecretContextValue:
    """Relinearization-key primitive: hands back ``sk`` unchanged.

    ``keygen_impl`` already generates relinearization keys inside the
    context, so the secret context itself serves as the relin-key handle.
    """
    relin_handle = sk
    return relin_handle
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
@bfv.make_galois_keys_p.def_impl
def make_galois_keys_impl(
    interpreter: Interpreter, op: Operation, sk: BFVSecretContextValue
) -> BFVSecretContextValue:
    """Galois-key primitive: hands back ``sk`` unchanged.

    ``keygen_impl`` already generates Galois keys inside the context, so the
    secret context itself serves as the Galois-key handle.
    """
    galois_handle = sk
    return galois_handle
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
@bfv.create_encoder_p.def_impl
def create_encoder_impl(interpreter: Interpreter, op: Operation) -> dict[str, Any]:
    """Return a lightweight encoder descriptor.

    No SEAL encoder is built here; batch encoding uses the ``batch_encoder``
    owned by the key context. The descriptor only records
    ``poly_modulus_degree`` (default 4096).
    """
    degree = op.attrs.get("poly_modulus_degree", 4096)
    return {"poly_modulus_degree": degree}
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
@bfv.encode_p.def_impl
def encode_impl(
    interpreter: Interpreter,
    op: Operation,
    data: TensorValue,
    encoder: dict[str, Any],
) -> TensorValue:
    """Produce a "logical plaintext": the input converted to a NumPy array.

    No SEAL encoding happens here; it is deferred until the value is consumed
    (e.g. by ``encrypt_impl``, which batch-encodes it). ``encoder`` is unused.
    """
    raw = data.unwrap()
    return TensorValue.wrap(np.asarray(raw))
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
@bfv.batch_encode_p.def_impl
def batch_encode_impl(
    interpreter: Interpreter,
    op: Operation,
    *args: Value,
) -> tuple[BFVValue | TensorValue, ...]:
    """Eagerly batch-encode each row of a tensor into its own SEAL Plaintext.

    ``args`` is expected to be ``(tensor, encoder, key)``. The tensor comes
    from the first slot and the key (a public or secret BFV context) from the
    last; the encoder slot is not used. Each row ``arr[i]`` along axis 0 is
    turned into a Python list and encoded with the key's ``batch_encoder``.

    Returns:
        One plaintext ``BFVValue`` per row, in row order.
    """
    tensor_val, _encoder, key = args[0], args[-2], args[-1]
    ctx = cast(BFVPublicContextValue, key)

    # Materialize once as NumPy so per-row access avoids JAX dispatch overhead
    # (and any device-to-host transfer happens a single time).
    arr = np.asarray(cast(TensorValue, tensor_val).unwrap())
    encoder = ctx.batch_encoder

    def _encode_row(row: Any) -> BFVValue:
        pt = sealapi.Plaintext()
        encoder.encode(row.tolist(), pt)
        return BFVValue(pt, ctx, is_cipher=False)

    return tuple(_encode_row(arr[i]) for i in range(arr.shape[0]))
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
@bfv.encrypt_p.def_impl
def encrypt_impl(
    interpreter: Interpreter,
    op: Operation,
    plaintext: TensorValue,
    pk: BFVPublicContextValue,
) -> BFVValue:
    """Batch-encode a logical plaintext and encrypt it under ``pk``.

    ``plaintext`` is the tensor produced by ``encode_impl``. It is flattened,
    batch-encoded into a single SEAL Plaintext, and encrypted with the public
    context's ``encryptor``.
    """
    slots = plaintext.unwrap().flatten().tolist()

    encoded = sealapi.Plaintext()
    pk.batch_encoder.encode(slots, encoded)

    encrypted = sealapi.Ciphertext()
    pk.encryptor.encrypt(encoded, encrypted)
    return BFVValue(encrypted, pk, is_cipher=True)
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
@bfv.decrypt_p.def_impl
def decrypt_impl(
    interpreter: Interpreter,
    op: Operation,
    ciphertext: BFVValue,
    sk: BFVSecretContextValue,
) -> BFVValue:
    """Decrypt ``ciphertext`` with the secret context ``sk``.

    Returns a plaintext ``BFVValue`` bound to ``sk``.

    ``mul_impl`` represents a product with an all-zero plaintext as an empty,
    default-constructed ``sealapi.Ciphertext()``. SEAL's decryptor rejects
    that object because it has no parameters and no polynomials. That zero
    sentinel is therefore mapped to an all-zero batch-encoded plaintext.
    """
    ct = ciphertext.data
    pt = sealapi.Plaintext()
    # Only an *empty* transparent ciphertext is the zero sentinel; a
    # non-empty one is still decrypted normally. The cheap is_transparent()
    # test short-circuits the common path.
    if ct.is_transparent() and ct.size() == 0:
        # SEAL batch encoding pads the remaining slots with zeros.
        sk.batch_encoder.encode([0], pt)
    else:
        sk.decryptor.decrypt(ct, pt)
    return BFVValue(pt, sk, is_cipher=False)
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
@bfv.decode_p.def_impl
def decode_impl(
    interpreter: Interpreter, op: Operation, plaintext: BFVValue, encoder: Any
) -> TensorValue:
    """Batch-decode a plaintext ``BFVValue`` into a NumPy-backed tensor.

    Decoding uses ``decode_int64`` of the plaintext's own context. ``encoder``
    is only a descriptor and is not consulted.
    """
    decoder = plaintext.ctx.batch_encoder
    slots = decoder.decode_int64(plaintext.data)
    return TensorValue.wrap(np.array(slots))
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
def _ensure_plaintext(ctx: BFVPublicContextValue, data: BFVValue | TensorValue) -> Any:
    """Return a ``sealapi.Plaintext`` for ``data``.

    * A plaintext ``BFVValue`` yields its underlying SEAL object as-is.
    * A ``TensorValue`` is flattened and batch-encoded with ``ctx``.

    Raises:
        TypeError: for a ciphertext ``BFVValue`` or any other type.
    """
    if not isinstance(data, BFVValue):
        if not isinstance(data, TensorValue):
            raise TypeError(f"Expected BFVValue or TensorValue, got {type(data)}")
        encoded = sealapi.Plaintext()
        ctx.batch_encoder.encode(data.unwrap().flatten().tolist(), encoded)
        return encoded
    if data.is_cipher:
        raise TypeError("Expected Plaintext, got Ciphertext")
    return data.data
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
@bfv.add_p.def_impl
def add_impl(
    interpreter: Interpreter,
    op: Operation,
    lhs: BFVValue | TensorValue,
    rhs: BFVValue | TensorValue,
) -> BFVValue | TensorValue:
    """Homomorphic addition over ciphertexts, plaintexts and raw tensors.

    Supported combinations:
      * ciphertext + ciphertext -> ciphertext (``evaluator.add``);
      * ciphertext + plaintext/tensor, in either order -> ciphertext
        (``evaluator.add_plain``; tensors are batch-encoded first);
      * tensor + tensor -> tensor (``lhs.unwrap() + rhs.unwrap()``).

    Transparent ciphertexts are treated as zero: ``0 + c`` returns ``c``
    itself, and ``0 + p`` returns a fresh encryption of ``p``.

    Raises:
        NotImplementedError: for two plaintext ``BFVValue`` operands.
        TypeError: for any other operand combination.
    """

    def _add_to_cipher(cipher: BFVValue, other: BFVValue | TensorValue) -> BFVValue:
        # Addition commutes, so ``cipher + other`` covers both operand orders.
        ctx = cipher.ctx
        pt = _ensure_plaintext(ctx, other)
        out = sealapi.Ciphertext()
        if cipher.data.is_transparent():
            # 0 + p: encrypt p so the result is still a ciphertext.
            ctx.encryptor.encrypt(pt, out)
        else:
            ctx.evaluator.add_plain(cipher.data, pt, out)
        return BFVValue(out, ctx, is_cipher=True)

    if isinstance(lhs, BFVValue) and isinstance(rhs, BFVValue):
        if lhs.is_cipher and rhs.is_cipher:
            if lhs.data.is_transparent():
                return rhs
            if rhs.data.is_transparent():
                return lhs
            out = sealapi.Ciphertext()
            lhs.ctx.evaluator.add(lhs.data, rhs.data, out)
            return BFVValue(out, lhs.ctx, is_cipher=True)
        if not (lhs.is_cipher or rhs.is_cipher):
            raise NotImplementedError(
                "BFV Plaintext + Plaintext addition not implemented yet"
            )

    if isinstance(lhs, BFVValue) and lhs.is_cipher:
        return _add_to_cipher(lhs, rhs)
    if isinstance(rhs, BFVValue) and rhs.is_cipher:
        return _add_to_cipher(rhs, lhs)
    if isinstance(lhs, TensorValue) and isinstance(rhs, TensorValue):
        return TensorValue.wrap(lhs.unwrap() + rhs.unwrap())
    raise TypeError(f"Unsupported types for add: {type(lhs)}, {type(rhs)}")
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
@bfv.sub_p.def_impl
def sub_impl(
    interpreter: Interpreter,
    op: Operation,
    lhs: BFVValue | TensorValue,
    rhs: BFVValue | TensorValue,
) -> BFVValue | TensorValue:
    """Homomorphic subtraction ``lhs - rhs``.

    Supported combinations:
      * ciphertext - ciphertext -> ciphertext (``evaluator.sub``);
      * ciphertext - plaintext/tensor -> ciphertext (``evaluator.sub_plain``);
      * plaintext/tensor - ciphertext -> ciphertext, computed as ``-rhs + lhs``;
      * tensor - tensor -> tensor (``lhs.unwrap() - rhs.unwrap()``).

    Like ``add_impl``, transparent ciphertexts are treated as zero. This matters
    because ``mul_impl`` returns an empty ``sealapi.Ciphertext()`` for a product
    with an all-zero plaintext, and SEAL's evaluator rejects that object as an
    operand.

    Raises:
        NotImplementedError: for two plaintext ``BFVValue`` operands.
        TypeError: for any other operand combination.
    """

    def _encrypt(ctx: BFVPublicContextValue, pt: Any) -> Any:
        out = sealapi.Ciphertext()
        ctx.encryptor.encrypt(pt, out)
        return out

    def _negate(ctx: BFVPublicContextValue, ct: Any) -> Any:
        out = sealapi.Ciphertext()
        ctx.evaluator.negate(ct, out)
        return out

    if isinstance(lhs, BFVValue) and isinstance(rhs, BFVValue):
        if lhs.is_cipher and rhs.is_cipher:
            if rhs.data.is_transparent():  # x - 0 = x
                return lhs
            if lhs.data.is_transparent():  # 0 - y = -y
                return BFVValue(_negate(rhs.ctx, rhs.data), rhs.ctx, is_cipher=True)
            result_ct = sealapi.Ciphertext()
            lhs.ctx.evaluator.sub(lhs.data, rhs.data, result_ct)
            return BFVValue(result_ct, lhs.ctx, is_cipher=True)
        if not lhs.is_cipher and not rhs.is_cipher:
            raise NotImplementedError(
                "BFV Plaintext - Plaintext subtraction not implemented yet"
            )

    # Ciphertext - plaintext (BFVValue) or raw tensor.
    if isinstance(lhs, BFVValue) and lhs.is_cipher:
        ctx = lhs.ctx
        pt = _ensure_plaintext(ctx, rhs)
        if lhs.data.is_transparent():  # 0 - p = -Enc(p)
            return BFVValue(_negate(ctx, _encrypt(ctx, pt)), ctx, is_cipher=True)
        result_ct = sealapi.Ciphertext()
        ctx.evaluator.sub_plain(lhs.data, pt, result_ct)
        return BFVValue(result_ct, ctx, is_cipher=True)

    # Plaintext (BFVValue) or raw tensor - ciphertext, computed as (-rhs) + lhs.
    if isinstance(rhs, BFVValue) and rhs.is_cipher:
        ctx = rhs.ctx
        pt = _ensure_plaintext(ctx, lhs)
        if rhs.data.is_transparent():  # p - 0 = Enc(p)
            return BFVValue(_encrypt(ctx, pt), ctx, is_cipher=True)
        result_ct = sealapi.Ciphertext()
        ctx.evaluator.add_plain(_negate(ctx, rhs.data), pt, result_ct)
        return BFVValue(result_ct, ctx, is_cipher=True)

    if isinstance(lhs, TensorValue) and isinstance(rhs, TensorValue):
        return TensorValue.wrap(lhs.unwrap() - rhs.unwrap())
    raise TypeError(f"Unsupported types for sub: {type(lhs)}, {type(rhs)}")
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
def _mul_zero_ciphertext(ctx: BFVPublicContextValue) -> BFVValue:
    """Return a fresh encryption of zero under ``ctx``.

    Used for every ``ciphertext * 0`` result. A default-constructed
    ``sealapi.Ciphertext()`` is deliberately NOT used as the zero value: it is
    empty and not bound to the context's encryption parameters, so subsequent
    SEAL evaluator calls on it (e.g. ``evaluator.sub`` in ``sub_impl``) reject it.
    """
    result_ct = sealapi.Ciphertext()
    ctx.encryptor.encrypt_zero(result_ct)
    return BFVValue(result_ct, ctx, is_cipher=True)


def _mul_cipher_plain(ctx: BFVPublicContextValue, ciphertext, plaintext) -> BFVValue:
    """Multiply a SEAL ciphertext by a SEAL plaintext, mapping zero products to Enc(0).

    SEAL raises ``RuntimeError`` mentioning "transparent" when multiplying by a
    zero plaintext would yield a transparent ciphertext. Mathematically
    Enc(x) * 0 = Enc(0), so that specific error is turned into a valid zero
    ciphertext; any other error propagates unchanged.
    """
    result_ct = sealapi.Ciphertext()
    try:
        ctx.evaluator.multiply_plain(ciphertext, plaintext, result_ct)
    except RuntimeError as e:
        if "transparent" not in str(e):
            raise
        return _mul_zero_ciphertext(ctx)
    return BFVValue(result_ct, ctx, is_cipher=True)


def _mul_cipher_raw(ct_val: BFVValue, raw) -> BFVValue:
    """Multiply a BFV ciphertext value by a raw (non-BFV) operand such as a TensorValue."""
    # Cheap pre-check: an all-zero tensor always yields Enc(0), so skip encoding
    # and the SEAL "transparent" exception entirely.
    if isinstance(raw, TensorValue) and np.all(raw.unwrap() == 0):
        return _mul_zero_ciphertext(ct_val.ctx)
    pt = _ensure_plaintext(ct_val.ctx, raw)
    return _mul_cipher_plain(ct_val.ctx, ct_val.data, pt)


@bfv.mul_p.def_impl
def mul_impl(
    interpreter: Interpreter,
    op: Operation,
    lhs: BFVValue | TensorValue,
    rhs: BFVValue | TensorValue,
) -> BFVValue | TensorValue:
    """Evaluate ``lhs * rhs`` for BFV ciphertexts, BFV plaintexts and raw tensors.

    Dispatch:
      * cipher * cipher           -> ``evaluator.multiply`` (result is not relinearized)
      * cipher * BFV plaintext    -> ``evaluator.multiply_plain`` (either operand order)
      * cipher * raw value        -> value encoded via ``_ensure_plaintext``, then
                                     ``multiply_plain`` (either operand order)
      * TensorValue * TensorValue -> elementwise product of the unwrapped arrays

    Any product with a zero plaintext returns a fresh encryption of zero.

    Raises:
        NotImplementedError: both operands are BFV plaintexts.
        TypeError: the operand combination is not supported.
    """
    # Case 1: both operands are BFVValues
    if isinstance(lhs, BFVValue) and isinstance(rhs, BFVValue):
        if lhs.is_cipher and rhs.is_cipher:
            result_ct = sealapi.Ciphertext()
            lhs.ctx.evaluator.multiply(lhs.data, rhs.data, result_ct)
            return BFVValue(result_ct, lhs.ctx, is_cipher=True)
        if lhs.is_cipher or rhs.is_cipher:
            # Multiplication is commutative: normalize to (ciphertext, plaintext).
            ct_val, pt_val = (lhs, rhs) if lhs.is_cipher else (rhs, lhs)
            # Optimization: avoid the expensive SEAL exception for zero plaintexts.
            if pt_val.data.is_zero():
                return _mul_zero_ciphertext(ct_val.ctx)
            return _mul_cipher_plain(ct_val.ctx, ct_val.data, pt_val.data)
        raise NotImplementedError(
            "BFV Plaintext * Plaintext multiplication not implemented yet"
        )

    # Case 2: one operand is a BFV ciphertext, the other is a raw value
    if isinstance(lhs, BFVValue) and lhs.is_cipher:
        return _mul_cipher_raw(lhs, rhs)
    if isinstance(rhs, BFVValue) and rhs.is_cipher:
        return _mul_cipher_raw(rhs, lhs)

    # Handle Plaintext * Plaintext (TensorValue * TensorValue)
    if isinstance(lhs, TensorValue) and isinstance(rhs, TensorValue):
        return TensorValue.wrap(lhs.unwrap() * rhs.unwrap())
    raise TypeError(f"Unsupported types for mul: {type(lhs)}, {type(rhs)}")
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
@bfv.relinearize_p.def_impl
def relinearize_impl(
    interpreter: Interpreter,
    op: Operation,
    ciphertext: BFVValue,
    rk: BFVPublicContextValue,
) -> BFVValue:
    """Relinearize ``ciphertext`` using the relinearization keys held by ``rk``.

    ``rk`` is a BFVPublicContextValue (presumably the same context as
    ``ciphertext.ctx``). Transparent (zero) ciphertexts and ciphertexts whose
    size is already at most 2 are returned unchanged. Otherwise a new value is
    produced via SEAL's ``evaluator.relinearize``.
    """
    raw_ct = ciphertext.data
    # Nothing to do for a transparent/zero ciphertext or an already-linear one.
    if raw_ct.is_transparent() or raw_ct.size() <= 2:
        return ciphertext

    relinearized = sealapi.Ciphertext()
    ciphertext.ctx.evaluator.relinearize(raw_ct, rk.relin_keys, relinearized)
    return BFVValue(relinearized, ciphertext.ctx, is_cipher=True)
|
|
669
|
+
|
|
670
|
+
|
|
671
|
+
@bfv.rotate_p.def_impl
def rotate_impl(
    interpreter: Interpreter,
    op: Operation,
    ciphertext: BFVValue,
    gk: BFVPublicContextValue,
) -> BFVValue:
    """Rotate the batched rows of ``ciphertext`` using the low-level SEAL API.

    The step count comes from ``op.attrs["steps"]`` (default 0) and is passed
    straight to SEAL's ``evaluator.rotate_rows`` together with ``gk.galois_keys``.
    A zero step count or a transparent (zero) ciphertext is returned as-is.
    """
    steps = op.attrs.get("steps", 0)
    # Short-circuit: no rotation requested, or a transparent/zero ciphertext.
    if steps == 0 or ciphertext.data.is_transparent():
        return ciphertext

    rotated = sealapi.Ciphertext()
    ciphertext.ctx.evaluator.rotate_rows(
        ciphertext.data, steps, gk.galois_keys, rotated
    )
    return BFVValue(rotated, ciphertext.ctx, is_cipher=True)
|
|
693
|
+
|
|
694
|
+
|
|
695
|
+
@bfv.rotate_columns_p.def_impl
def rotate_columns_impl(
    interpreter: Interpreter,
    op: Operation,
    ciphertext: BFVValue,
    gk: BFVPublicContextValue,
) -> BFVValue:
    """Swap the two rows in SIMD batching (row 0 <-> row 1).

    Uses SEAL's ``evaluator.rotate_columns`` with ``gk.galois_keys``. A
    transparent (zero) ciphertext is returned unchanged, matching
    ``rotate_impl`` and ``relinearize_impl``; swapping rows of zero is still
    zero, and SEAL's evaluator would otherwise be handed an input it may reject.
    """
    # Optimization / robustness: handle transparent ciphertext (zero)
    if ciphertext.data.is_transparent():
        return ciphertext

    new_ct = sealapi.Ciphertext()
    ciphertext.ctx.evaluator.rotate_columns(ciphertext.data, gk.galois_keys, new_ct)
    return BFVValue(new_ct, ciphertext.ctx, is_cipher=True)
|