mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,705 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """BFV Runtime Implementation.
16
+
17
+ Implements execution logic for BFV primitives using TenSEAL low-level API (sealapi).
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import base64
23
+ import os
24
+ import uuid
25
+ from dataclasses import dataclass
26
+ from typing import Any, ClassVar, cast
27
+
28
+ import numpy as np
29
+ import tenseal as ts
30
+ import tenseal.sealapi as sealapi
31
+
32
+ from mplang.v2.backends.tensor_impl import TensorValue
33
+ from mplang.v2.dialects import bfv
34
+ from mplang.v2.edsl import serde
35
+ from mplang.v2.edsl.graph import Operation
36
+ from mplang.v2.runtime.interpreter import Interpreter
37
+ from mplang.v2.runtime.value import Value, WrapValue
38
+
39
+ # =============================================================================
40
+ # Helper for SEAL serialization
41
+ # =============================================================================
42
+
43
+
44
def _get_seal_temp_path() -> str:
    """Create a private, empty scratch file for SEAL serialization; return its path.

    The SEAL ``save``/``load`` bindings used in this module work through file
    paths, so (de)serialization round-trips through a scratch file.
    ``/dev/shm`` (a RAM-backed tmpfs on Linux) is preferred when it exists and
    is writable; otherwise the platform's default temp directory is used.

    The file is created atomically (``O_EXCL``) with owner-only permissions
    (0600) by :func:`tempfile.mkstemp`. This matters because the file may
    briefly hold serialized plaintexts or ciphertexts, and ``/dev/shm`` is
    shared by every local user. The caller owns the file and must delete it.

    Returns:
        Absolute path of the newly created, empty file.
    """
    import tempfile

    shm_dir = "/dev/shm"
    use_shm = os.path.isdir(shm_dir) and os.access(shm_dir, os.W_OK)
    # dir=None makes tempfile fall back to tempfile.gettempdir().
    fd, path = tempfile.mkstemp(
        prefix="seal_", suffix=".bin", dir=shm_dir if use_shm else None
    )
    # Only the path is needed: SEAL / the caller reopen the file by name.
    os.close(fd)
    return path
59
+
60
+
61
@serde.register_class
class BFVParamContextValue(WrapValue[ts.Context]):
    """TenSEAL BFV context used for its encryption parameters only.

    Exposes the raw SEAL context plus an evaluator and a batch encoder.
    ``to_json`` never writes any key material (public, secret, Galois or
    relinearization).
    """

    _serde_kind: ClassVar[str] = "bfv_impl.BFVParamContextValue"

    def __init__(self, data: Any):
        super().__init__(data)
        wrapped = self._data
        self.ts_ctx = wrapped

        # Low-level SEAL handles behind the TenSEAL context.
        self.seal_ctx = wrapped.seal_context()
        self.cpp_ctx = self.seal_ctx.data
        self.evaluator = sealapi.Evaluator(self.cpp_ctx)
        self.batch_encoder = sealapi.BatchEncoder(self.cpp_ctx)

    def _convert(self, data: Any) -> ts.Context:
        """Accept a raw ``ts.Context`` or another ``BFVParamContextValue``."""
        if isinstance(data, ts.Context):
            return data
        if isinstance(data, BFVParamContextValue):
            return data.unwrap()
        raise TypeError(f"Expected ts.Context, got {type(data)}")

    def to_json(self) -> dict[str, Any]:
        """Serialize the parameters only, as base64 text under ``ctx_bytes``."""
        no_keys = {
            "save_public_key": False,
            "save_secret_key": False,
            "save_galois_keys": False,
            "save_relin_keys": False,
        }
        raw = self.ts_ctx.serialize(**no_keys)
        return {"ctx_bytes": base64.b64encode(raw).decode("ascii")}

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> BFVParamContextValue:
        """Rebuild from the ``ctx_bytes`` payload written by ``to_json``."""
        raw = base64.b64decode(data["ctx_bytes"])
        return cls(ts.context_from(raw))
100
+
101
+
102
@serde.register_class
class BFVPublicContextValue(WrapValue[ts.Context]):
    """TenSEAL BFV context plus the public low-level SEAL objects derived from it.

    Provides the raw SEAL context, an evaluator, a batch encoder, the public /
    relinearization / Galois keys and an encryptor. Serialization via
    ``to_json`` omits the secret key.
    """

    _serde_kind: ClassVar[str] = "bfv_impl.BFVPublicContextValue"

    def __init__(self, data: Any):
        super().__init__(data)
        wrapped = self._data
        self.ts_ctx = wrapped

        # Low-level SEAL handles behind the TenSEAL context.
        self.seal_ctx = wrapped.seal_context()
        self.cpp_ctx = self.seal_ctx.data
        self.evaluator = sealapi.Evaluator(self.cpp_ctx)
        self.batch_encoder = sealapi.BatchEncoder(self.cpp_ctx)

        # Public key material; the wrapped context must carry all three.
        self.public_key = wrapped.public_key().data
        self.relin_keys = wrapped.relin_keys().data
        self.galois_keys = wrapped.galois_keys().data

        self.encryptor = sealapi.Encryptor(self.cpp_ctx, self.public_key)

    def _convert(self, data: Any) -> ts.Context:
        """Accept a raw ``ts.Context`` or a ``BFVPublicContextValue`` (or subclass)."""
        if isinstance(data, ts.Context):
            return data
        if isinstance(data, BFVPublicContextValue):
            return data.unwrap()
        raise TypeError(f"Expected ts.Context, got {type(data)}")

    def to_json(self) -> dict[str, Any]:
        """Serialize the context without its secret key (base64 ``ctx_bytes``)."""
        raw = self.ts_ctx.serialize(save_secret_key=False)
        return {"ctx_bytes": base64.b64encode(raw).decode("ascii")}

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> BFVPublicContextValue:
        """Rebuild from the ``ctx_bytes`` payload written by ``to_json``."""
        raw = base64.b64decode(data["ctx_bytes"])
        return cls(ts.context_from(raw))
143
+
144
+
145
@serde.register_class
class BFVSecretContextValue(BFVPublicContextValue):
    """BFV context that also exposes the secret key and a SEAL decryptor."""

    _serde_kind: ClassVar[str] = "bfv_impl.BFVSecretContextValue"

    def __init__(self, data: Any):
        # The parent constructor goes through WrapValue (which runs
        # ``_convert``) and sets up every public handle first.
        super().__init__(data)

        if self.ts_ctx.has_secret_key():
            self.secret_key = self.ts_ctx.secret_key().data
            self.decryptor = sealapi.Decryptor(self.cpp_ctx, self.secret_key)
        else:
            raise ValueError("Context does not have a secret key")

    def make_public(self) -> BFVPublicContextValue:
        """Return a new, independent context with the secret key stripped."""
        # Round-trip through bytes so the copy never contains the secret key.
        public_bytes = self.ts_ctx.serialize(save_secret_key=False)
        return BFVPublicContextValue(ts.context_from(public_bytes))

    def to_json(self) -> dict[str, Any]:
        """Serialize the full context including the secret key (base64 ``ctx_bytes``)."""
        raw = self.ts_ctx.serialize(save_secret_key=True)
        return {"ctx_bytes": base64.b64encode(raw).decode("ascii")}

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> BFVSecretContextValue:
        """Rebuild from the ``ctx_bytes`` payload written by ``to_json``."""
        raw = base64.b64decode(data["ctx_bytes"])
        return cls(ts.context_from(raw))
180
+
181
+
182
@serde.register_class
@dataclass
class BFVValue(Value):
    """Runtime value holding a SEAL ``Ciphertext`` or ``Plaintext``.

    Fields:
        data: a ``sealapi.Ciphertext`` when ``is_cipher`` is true, otherwise a
            ``sealapi.Plaintext``. The arithmetic kernels in this module use an
            empty, default-constructed ciphertext as a cheap encrypted zero.
        ctx: context wrapper the value belongs to.
        is_cipher: whether ``data`` is a ciphertext.

    NOTE(review): ``from_json`` restores ``ctx`` as a ``BFVParamContextValue``.
    That wrapper holds parameters only: an evaluator and a batch encoder, but
    no keys and no encryptor. It is not a ``BFVPublicContextValue``. Kernels
    that use ``ctx.encryptor`` will fail on deserialized values, so confirm
    callers never need that.
    """

    _serde_kind: ClassVar[str] = "bfv_impl.BFVValue"

    data: Any  # sealapi.Ciphertext | sealapi.Plaintext
    ctx: BFVPublicContextValue
    is_cipher: bool = True

    def to_json(self) -> dict[str, Any]:
        """Serialize the SEAL object plus the context's parameters (no keys)."""
        # An empty (size-0, default-constructed) ciphertext carries no
        # parameter metadata, so SEAL's context-validated ``load`` is not
        # expected to accept it. Record it with a marker instead of its bytes.
        is_empty = bool(self.is_cipher and self.data.size() == 0)

        data_bytes = b""
        if not is_empty:
            # SEAL's bindings (de)serialize via a file path, not a buffer.
            fname = _get_seal_temp_path()
            try:
                self.data.save(fname)
                with open(fname, "rb") as f:
                    data_bytes = f.read()
            finally:
                if os.path.exists(fname):
                    os.unlink(fname)

        # Ship only the encryption parameters to save bandwidth.
        param_ctx = BFVParamContextValue(self.ctx.ts_ctx)
        payload: dict[str, Any] = {
            "data_bytes": base64.b64encode(data_bytes).decode("ascii"),
            "is_cipher": self.is_cipher,
            "ctx": serde.to_json(param_ctx),
        }
        if is_empty:
            payload["is_empty"] = True
        return payload

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> BFVValue:
        """Rebuild a value from the payload written by ``to_json``."""
        ctx = serde.from_json(data["ctx"])
        data_bytes = base64.b64decode(data["data_bytes"])
        is_cipher = data["is_cipher"]

        # Empty-ciphertext marker (absent in payloads from older writers).
        if is_cipher and data.get("is_empty", False):
            return cls(data=sealapi.Ciphertext(), ctx=ctx, is_cipher=True)

        # Load via a scratch file (SEAL binding requirement).
        fname = _get_seal_temp_path()
        try:
            with open(fname, "wb") as f:
                f.write(data_bytes)

            if is_cipher:
                ct = sealapi.Ciphertext()
                ct.load(ctx.cpp_ctx, fname)
                return cls(data=ct, ctx=ctx, is_cipher=True)
            pt = sealapi.Plaintext()
            pt.load(ctx.cpp_ctx, fname)
            return cls(data=pt, ctx=ctx, is_cipher=False)
        finally:
            if os.path.exists(fname):
                os.unlink(fname)
240
+
241
+
242
+ # =============================================================================
243
+ # Keygen Cache (Optimization: avoid regenerating keys for same parameters)
244
+ # =============================================================================
245
# Process-wide memo of keygen results. Keys are
# (poly_modulus_degree, plain_modulus); values are (public, secret) contexts.
_KEYGEN_CACHE: dict[
    tuple[int, int], tuple[BFVPublicContextValue, BFVSecretContextValue]
] = {}


def clear_keygen_cache() -> None:
    """Forget every memoized key pair so the next keygen builds fresh keys.

    The dict is emptied in place, so existing references to it stay valid.
    """
    _KEYGEN_CACHE.clear()
253
+
254
+
255
@bfv.keygen_p.def_impl
def keygen_impl(
    interpreter: Interpreter, op: Operation, *args: Any
) -> tuple[BFVPublicContextValue, BFVSecretContextValue]:
    """Generate a BFV key pair and return ``(public_ctx, secret_ctx)``.

    Reads ``poly_modulus_degree`` (default 4096) and ``plain_modulus``
    (default 1032193) from ``op.attrs``. The secret context also receives
    Galois and relinearization keys.

    Results are memoized in ``_KEYGEN_CACHE`` keyed on those two parameters.

    NOTE(review): because the cache is module-wide and keyed only on the
    parameters, every keygen with equal parameters in this process returns the
    *same* secret key. If several parties are ever run in one process (e.g. a
    simulation), they would share key material. Confirm that is acceptable, or
    call ``clear_keygen_cache()`` to force fresh keys.
    """
    attrs = op.attrs
    degree = attrs.get("poly_modulus_degree", 4096)
    plain_mod = attrs.get("plain_modulus", 1032193)

    cache_key = (degree, plain_mod)
    cached = _KEYGEN_CACHE.get(cache_key)
    if cached is not None:
        return cached

    # Full context (with secret key) plus evaluation keys.
    ts_ctx = ts.context(
        ts.SCHEME_TYPE.BFV,
        poly_modulus_degree=degree,
        plain_modulus=plain_mod,
    )
    ts_ctx.generate_galois_keys()
    ts_ctx.generate_relin_keys()

    secret_ctx = BFVSecretContextValue(ts_ctx)
    key_pair = (secret_ctx.make_public(), secret_ctx)
    _KEYGEN_CACHE[cache_key] = key_pair
    return key_pair
286
+
287
+
288
+ @bfv.make_relin_keys_p.def_impl
289
+ def make_relin_keys_impl(
290
+ interpreter: Interpreter, op: Operation, sk: BFVSecretContextValue
291
+ ) -> BFVSecretContextValue:
292
+ return sk
293
+
294
+
295
+ @bfv.make_galois_keys_p.def_impl
296
+ def make_galois_keys_impl(
297
+ interpreter: Interpreter, op: Operation, sk: BFVSecretContextValue
298
+ ) -> BFVSecretContextValue:
299
+ return sk
300
+
301
+
302
@bfv.create_encoder_p.def_impl
def create_encoder_impl(interpreter: Interpreter, op: Operation) -> dict[str, Any]:
    """Return a lightweight encoder descriptor instead of a SEAL object.

    Only ``poly_modulus_degree`` (default 4096) is recorded. The real batch
    encoders live on the context values.
    """
    degree = op.attrs.get("poly_modulus_degree", 4096)
    descriptor = {"poly_modulus_degree": degree}
    return descriptor
305
+
306
+
307
@bfv.encode_p.def_impl
def encode_impl(
    interpreter: Interpreter,
    op: Operation,
    data: TensorValue,
    encoder: dict[str, Any],
) -> TensorValue:
    """Logical encode: return the input as a NumPy-backed ``TensorValue``.

    SEAL batch encoding is deferred to encryption/arithmetic time, when a
    context is available. ``encoder`` is not used.
    """
    host_array = np.asarray(data.unwrap())
    return TensorValue.wrap(host_array)
316
+
317
+
318
@bfv.batch_encode_p.def_impl
def batch_encode_impl(
    interpreter: Interpreter,
    op: Operation,
    *args: Value,
) -> tuple[BFVValue | TensorValue, ...]:
    """Eagerly batch-encode each row (axis 0) of a tensor into a SEAL plaintext.

    ``args`` is laid out as ``(tensor, ..., encoder, key)``. The tensor comes
    first, and the context comes last (public or secret). The encoder
    descriptor at ``args[-2]`` is not used.

    NOTE(review): each row is passed to the batch encoder as a list, so the
    tensor presumably is 2-D ``(rows, slots)`` with integer entries. TODO
    confirm.

    Returns:
        One plaintext ``BFVValue`` per row, bound to the key's context.
    """
    ctx = cast(BFVPublicContextValue, args[-1])
    _encoder = args[-2]  # unused; still indexed so short ``args`` fail as before
    # Materialize once on the host: avoids per-row JAX dispatch and makes a
    # single device-to-host transfer.
    rows = np.asarray(cast(TensorValue, args[0]).unwrap())

    def encode_row(row: Any) -> BFVValue:
        encoded = sealapi.Plaintext()
        ctx.batch_encoder.encode(row.tolist(), encoded)
        return BFVValue(encoded, ctx, is_cipher=False)

    return tuple(encode_row(rows[idx]) for idx in range(rows.shape[0]))
347
+
348
+
349
@bfv.encrypt_p.def_impl
def encrypt_impl(
    interpreter: Interpreter,
    op: Operation,
    plaintext: TensorValue,
    pk: BFVPublicContextValue,
) -> BFVValue:
    """Batch-encode a logical plaintext tensor and encrypt it under ``pk``.

    The tensor (as produced by ``encode_impl``) is flattened into a single slot
    vector. NOTE(review): this presumably requires integer values and at most
    one slot vector's worth of elements. TODO confirm against the dialect.
    """
    slot_values = plaintext.unwrap().flatten().tolist()

    encoded = sealapi.Plaintext()
    pk.batch_encoder.encode(slot_values, encoded)

    cipher = sealapi.Ciphertext()
    pk.encryptor.encrypt(encoded, cipher)
    return BFVValue(cipher, pk, is_cipher=True)
374
+
375
+
376
@bfv.decrypt_p.def_impl
def decrypt_impl(
    interpreter: Interpreter,
    op: Operation,
    ciphertext: BFVValue,
    sk: BFVSecretContextValue,
) -> BFVValue:
    """Decrypt ``ciphertext`` with the secret context ``sk``.

    Other kernels in this module (e.g. ``mul_impl``) represent an encrypted
    zero as an empty, default-constructed ciphertext. SEAL's decryptor refuses
    such an input, so it is mapped directly to an all-zero plaintext.

    Returns:
        A plaintext ``BFVValue`` bound to ``sk``.
    """
    pt = sealapi.Plaintext()
    if ciphertext.data.size() == 0:
        # Encoding a single zero pads the remaining slots with zeros.
        sk.batch_encoder.encode([0], pt)
    else:
        sk.decryptor.decrypt(ciphertext.data, pt)
    return BFVValue(pt, sk, is_cipher=False)
390
+
391
+
392
@bfv.decode_p.def_impl
def decode_impl(
    interpreter: Interpreter, op: Operation, plaintext: BFVValue, encoder: Any
) -> TensorValue:
    """Batch-decode a plaintext ``BFVValue`` into a 1-D NumPy tensor.

    Decoding uses the plaintext's own context (``decode_int64``). ``encoder``
    is only the dummy descriptor from ``create_encoder_impl`` and is ignored.
    """
    batch_encoder = plaintext.ctx.batch_encoder
    slot_values = batch_encoder.decode_int64(plaintext.data)
    return TensorValue.wrap(np.array(slot_values))
401
+
402
+
403
def _ensure_plaintext(ctx: BFVPublicContextValue, data: BFVValue | TensorValue) -> Any:
    """Return a ``sealapi.Plaintext`` for ``data``.

    A plaintext ``BFVValue`` is returned as-is (its SEAL object). A
    ``TensorValue`` is flattened and batch-encoded with ``ctx``.

    Raises:
        TypeError: if ``data`` is a ciphertext or is of an unsupported type.
    """
    if isinstance(data, BFVValue):
        if not data.is_cipher:
            return data.data
        raise TypeError("Expected Plaintext, got Ciphertext")

    if isinstance(data, TensorValue):
        encoded = sealapi.Plaintext()
        slot_values = data.unwrap().flatten().tolist()
        ctx.batch_encoder.encode(slot_values, encoded)
        return encoded

    raise TypeError(f"Expected BFVValue or TensorValue, got {type(data)}")
418
+
419
+
420
def _add_plain_into_cipher(cipher: BFVValue, pt: Any) -> BFVValue:
    """Return ``cipher + pt`` as a new ciphertext ``BFVValue`` in ``cipher.ctx``.

    A transparent ``cipher`` stands for an encrypted zero. In that case the
    result is a fresh encryption of ``pt`` rather than an ``add_plain`` call,
    so the output is always a ciphertext.
    """
    ctx = cipher.ctx
    out = sealapi.Ciphertext()
    if cipher.data.is_transparent():
        ctx.encryptor.encrypt(pt, out)
    else:
        ctx.evaluator.add_plain(cipher.data, pt, out)
    return BFVValue(out, ctx, is_cipher=True)


@bfv.add_p.def_impl
def add_impl(
    interpreter: Interpreter,
    op: Operation,
    lhs: BFVValue | TensorValue,
    rhs: BFVValue | TensorValue,
) -> BFVValue | TensorValue:
    """Homomorphic addition ``lhs + rhs``.

    Supported combinations:
      * ciphertext + ciphertext: a transparent (zero) side returns the other
        operand unchanged.
      * ciphertext + plaintext, in either order, where the plaintext is a
        ``BFVValue`` or a ``TensorValue`` encoded on the fly.
      * tensor + tensor: plain NumPy/JAX addition.

    Raises:
        NotImplementedError: for plaintext ``BFVValue`` + plaintext ``BFVValue``.
        TypeError: for any other combination.
    """
    if isinstance(lhs, BFVValue) and isinstance(rhs, BFVValue):
        if lhs.is_cipher and rhs.is_cipher:
            # Transparent ciphertexts stand for zero: 0 + x == x.
            if lhs.data.is_transparent():
                return rhs
            if rhs.data.is_transparent():
                return lhs
            summed = sealapi.Ciphertext()
            lhs.ctx.evaluator.add(lhs.data, rhs.data, summed)
            return BFVValue(summed, lhs.ctx, is_cipher=True)
        if lhs.is_cipher:
            return _add_plain_into_cipher(lhs, rhs.data)
        if rhs.is_cipher:
            return _add_plain_into_cipher(rhs, lhs.data)
        raise NotImplementedError(
            "BFV Plaintext + Plaintext addition not implemented yet"
        )

    # One ciphertext BFVValue with a raw operand that is encoded on the fly.
    if isinstance(lhs, BFVValue) and lhs.is_cipher:
        return _add_plain_into_cipher(lhs, _ensure_plaintext(lhs.ctx, rhs))
    if isinstance(rhs, BFVValue) and rhs.is_cipher:
        return _add_plain_into_cipher(rhs, _ensure_plaintext(rhs.ctx, lhs))

    if isinstance(lhs, TensorValue) and isinstance(rhs, TensorValue):
        return TensorValue.wrap(lhs.unwrap() + rhs.unwrap())
    raise TypeError(f"Unsupported types for add: {type(lhs)}, {type(rhs)}")
500
+
501
+
502
@bfv.sub_p.def_impl
def sub_impl(
    interpreter: Interpreter,
    op: Operation,
    lhs: BFVValue | TensorValue,
    rhs: BFVValue | TensorValue,
) -> BFVValue | TensorValue:
    """Homomorphic subtraction ``lhs - rhs``.

    Supported combinations:
      * ciphertext - ciphertext
      * ciphertext - plaintext, and plaintext - ciphertext, where the
        plaintext is a ``BFVValue`` or a ``TensorValue`` encoded on the fly
      * tensor - tensor: plain NumPy/JAX subtraction

    A transparent (default-constructed) ciphertext is treated as an encrypted
    zero, as in ``add_impl``. This is needed because SEAL's evaluator does not
    accept it as an operand.

    Raises:
        NotImplementedError: for plaintext ``BFVValue`` - plaintext ``BFVValue``.
        TypeError: for any other combination.
    """

    def cipher_minus_plain(cipher: BFVValue, pt: Any) -> BFVValue:
        # c - p; for a transparent c this is 0 - p == -Enc(p).
        ctx = cipher.ctx
        out = sealapi.Ciphertext()
        if cipher.data.is_transparent():
            fresh = sealapi.Ciphertext()
            ctx.encryptor.encrypt(pt, fresh)
            ctx.evaluator.negate(fresh, out)
        else:
            ctx.evaluator.sub_plain(cipher.data, pt, out)
        return BFVValue(out, ctx, is_cipher=True)

    def plain_minus_cipher(pt: Any, cipher: BFVValue) -> BFVValue:
        # p - c == (-c) + p; for a transparent c this is p - 0 == Enc(p).
        ctx = cipher.ctx
        out = sealapi.Ciphertext()
        if cipher.data.is_transparent():
            ctx.encryptor.encrypt(pt, out)
        else:
            negated = sealapi.Ciphertext()
            ctx.evaluator.negate(cipher.data, negated)
            ctx.evaluator.add_plain(negated, pt, out)
        return BFVValue(out, ctx, is_cipher=True)

    # Case 1: both operands are BFVValues.
    if isinstance(lhs, BFVValue) and isinstance(rhs, BFVValue):
        if lhs.is_cipher and rhs.is_cipher:
            if rhs.data.is_transparent():
                return lhs  # x - 0 == x
            out = sealapi.Ciphertext()
            if lhs.data.is_transparent():
                # 0 - y == -y
                rhs.ctx.evaluator.negate(rhs.data, out)
                return BFVValue(out, rhs.ctx, is_cipher=True)
            lhs.ctx.evaluator.sub(lhs.data, rhs.data, out)
            return BFVValue(out, lhs.ctx, is_cipher=True)
        if lhs.is_cipher:
            return cipher_minus_plain(lhs, rhs.data)
        if rhs.is_cipher:
            return plain_minus_cipher(lhs.data, rhs)
        raise NotImplementedError(
            "BFV Plaintext - Plaintext subtraction not implemented yet"
        )

    # Case 2: one ciphertext BFVValue, the other a raw operand.
    if isinstance(lhs, BFVValue) and lhs.is_cipher:
        return cipher_minus_plain(lhs, _ensure_plaintext(lhs.ctx, rhs))
    if isinstance(rhs, BFVValue) and rhs.is_cipher:
        return plain_minus_cipher(_ensure_plaintext(rhs.ctx, lhs), rhs)

    # Plaintext - Plaintext (TensorValue - TensorValue).
    if isinstance(lhs, TensorValue) and isinstance(rhs, TensorValue):
        return TensorValue.wrap(lhs.unwrap() - rhs.unwrap())
    raise TypeError(f"Unsupported types for sub: {type(lhs)}, {type(rhs)}")
549
+
550
+
551
@bfv.mul_p.def_impl
def mul_impl(
    interpreter: Interpreter,
    op: Operation,
    lhs: BFVValue | TensorValue,
    rhs: BFVValue | TensorValue,
) -> BFVValue | TensorValue:
    """Homomorphic multiplication ``lhs * rhs``.

    Supported combinations:
      * ciphertext * ciphertext (no relinearization is done here)
      * ciphertext * plaintext, in either order, where the plaintext is a
        ``BFVValue`` or a ``TensorValue`` encoded on the fly
      * tensor * tensor: plain NumPy/JAX multiplication

    Zero handling (the product is zero, and SEAL rejects or complains about
    these cases):
      * a transparent (default-constructed) ciphertext operand, or a zero
        plaintext ``BFVValue``, yields a transparent ciphertext;
      * an all-zero tensor operand yields a fresh ``encrypt_zero`` ciphertext;
      * SEAL's "result ciphertext is transparent" error is caught and mapped
        to the same zero representation as the surrounding case.

    Raises:
        NotImplementedError: for plaintext ``BFVValue`` * plaintext ``BFVValue``.
        TypeError: for any other combination.
    """

    def transparent_zero(ctx: BFVPublicContextValue) -> BFVValue:
        # Default-constructed ciphertext: this module's cheap encrypted zero.
        return BFVValue(sealapi.Ciphertext(), ctx, is_cipher=True)

    def fresh_zero(ctx: BFVPublicContextValue) -> BFVValue:
        # A genuine encryption of zero.
        zero_ct = sealapi.Ciphertext()
        ctx.encryptor.encrypt_zero(zero_ct)
        return BFVValue(zero_ct, ctx, is_cipher=True)

    # Case 1: both operands are BFVValues.
    if isinstance(lhs, BFVValue) and isinstance(rhs, BFVValue):
        if lhs.is_cipher and rhs.is_cipher:
            # 0 * y == 0; SEAL's evaluator rejects transparent operands.
            if lhs.data.is_transparent() or rhs.data.is_transparent():
                return transparent_zero(lhs.ctx)
            result_ct = sealapi.Ciphertext()
            lhs.ctx.evaluator.multiply(lhs.data, rhs.data, result_ct)
            return BFVValue(result_ct, lhs.ctx, is_cipher=True)

        if not lhs.is_cipher and not rhs.is_cipher:
            raise NotImplementedError(
                "BFV Plaintext * Plaintext multiplication not implemented yet"
            )

        cipher, plain = (lhs, rhs) if lhs.is_cipher else (rhs, lhs)
        # Cheap up-front zero checks avoid SEAL's exception path.
        if plain.data.is_zero() or cipher.data.is_transparent():
            return transparent_zero(cipher.ctx)
        result_ct = sealapi.Ciphertext()
        try:
            cipher.ctx.evaluator.multiply_plain(cipher.data, plain.data, result_ct)
        except RuntimeError as e:
            if "transparent" in str(e):
                return transparent_zero(cipher.ctx)
            raise
        return BFVValue(result_ct, cipher.ctx, is_cipher=True)

    # Case 2: one ciphertext BFVValue, the other a raw operand.
    if isinstance(lhs, BFVValue) and lhs.is_cipher:
        cipher_val, raw = lhs, rhs
    elif isinstance(rhs, BFVValue) and rhs.is_cipher:
        cipher_val, raw = rhs, lhs
    else:
        cipher_val, raw = None, None

    if cipher_val is not None:
        ctx = cipher_val.ctx
        # An all-zero tensor would make SEAL produce a transparent result.
        if isinstance(raw, TensorValue) and np.all(raw.unwrap() == 0):
            return fresh_zero(ctx)
        try:
            # Encode first so unsupported raw operand types still raise.
            pt = _ensure_plaintext(ctx, raw)
            if cipher_val.data.is_transparent():
                return transparent_zero(ctx)
            result_ct = sealapi.Ciphertext()
            ctx.evaluator.multiply_plain(cipher_val.data, pt, result_ct)
        except RuntimeError as e:
            # SEAL throws "result ciphertext is transparent" when multiplying
            # by a zero plaintext. Enc(x) * 0 == Enc(0) is valid, so return an
            # explicit zero encryption instead.
            if "transparent" in str(e):
                return fresh_zero(ctx)
            raise
        return BFVValue(result_ct, ctx, is_cipher=True)

    # Plaintext * Plaintext (TensorValue * TensorValue).
    if isinstance(lhs, TensorValue) and isinstance(rhs, TensorValue):
        return TensorValue.wrap(lhs.unwrap() * rhs.unwrap())
    raise TypeError(f"Unsupported types for mul: {type(lhs)}, {type(rhs)}")
647
+
648
+
649
@bfv.relinearize_p.def_impl
def relinearize_impl(
    interpreter: Interpreter,
    op: Operation,
    ciphertext: BFVValue,
    rk: BFVPublicContextValue,
) -> BFVValue:
    """Relinearize ``ciphertext`` down to size 2 using ``rk.relin_keys``.

    The input is returned unchanged when it is transparent (the module's
    encrypted zero) or already has size <= 2.
    """
    ct = ciphertext.data
    if ct.is_transparent() or ct.size() <= 2:
        return ciphertext

    ctx = ciphertext.ctx
    relinearized = sealapi.Ciphertext()
    ctx.evaluator.relinearize(ct, rk.relin_keys, relinearized)
    return BFVValue(relinearized, ctx, is_cipher=True)
669
+
670
+
671
@bfv.rotate_p.def_impl
def rotate_impl(
    interpreter: Interpreter,
    op: Operation,
    ciphertext: BFVValue,
    gk: BFVPublicContextValue,
) -> BFVValue:
    """Rotate SIMD rows by ``op.attrs["steps"]`` (default 0) via SEAL ``rotate_rows``.

    Uses ``gk.galois_keys``. A zero shift or a transparent (zero) ciphertext is
    returned unchanged.
    """
    steps = op.attrs.get("steps", 0)
    if steps == 0 or ciphertext.data.is_transparent():
        return ciphertext

    ctx = ciphertext.ctx
    rotated = sealapi.Ciphertext()
    ctx.evaluator.rotate_rows(ciphertext.data, steps, gk.galois_keys, rotated)
    return BFVValue(rotated, ctx, is_cipher=True)
693
+
694
+
695
@bfv.rotate_columns_p.def_impl
def rotate_columns_impl(
    interpreter: Interpreter,
    op: Operation,
    ciphertext: BFVValue,
    gk: BFVPublicContextValue,
) -> BFVValue:
    """Swap the two rows in SIMD batching (row 0 <-> row 1) using ``gk.galois_keys``.

    A transparent ciphertext (the module's encrypted zero) is returned as-is,
    as in ``rotate_impl`` and ``relinearize_impl``. Swapping the rows of zero
    is zero, and SEAL's evaluator does not accept a transparent operand.
    """
    if ciphertext.data.is_transparent():
        return ciphertext

    new_ct = sealapi.Ciphertext()
    ciphertext.ctx.evaluator.rotate_columns(ciphertext.data, gk.galois_keys, new_ct)
    return BFVValue(new_ct, ciphertext.ctx, is_cipher=True)