mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev270__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/kernels/fhe.py DELETED
@@ -1,858 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """FHE Vector backend implementation using TenSEAL CKKSVector/BFVVector.
16
-
17
- This module provides FHE operations using TenSEAL's vector-based encryption,
18
- which only supports 1D data. All operations enforce 1D shape constraints.
19
- """
20
-
21
- from typing import Any
22
-
23
- import numpy as np
24
- import tenseal as ts
25
-
26
- from mplang.v1.core import DType, PFunction, TensorLike
27
- from mplang.v1.kernels.base import kernel_def
28
- from mplang.v1.kernels.value import TensorValue
29
-
30
-
31
class FHEContext:
    """FHE context wrapper for TenSEAL vector operations.

    Pairs an underlying TenSEAL context with the name of the scheme it was
    created for. Intended for vector-based encryption (CKKSVector/BFVVector),
    which only supports 1D data.

    Args:
        context: The wrapped TenSEAL context object.
        scheme: ``"CKKS"`` (default) or ``"BFV"``.

    Raises:
        ValueError: If ``scheme`` is not one of the supported schemes.
    """

    def __init__(self, context: Any, scheme: str = "CKKS"):
        # Validate with an explicit raise rather than ``assert``: asserts are
        # stripped under ``python -O``, which would let an unknown scheme in.
        if scheme not in ("CKKS", "BFV"):
            raise ValueError(f"Unsupported scheme type for TenSEAL: {scheme}")
        self.context = context
        self._scheme = scheme

    @property
    def dtype(self) -> Any:
        """Object dtype; the context is carried as opaque binary data."""
        return np.dtype("O")  # Use object dtype for binary data

    @property
    def shape(self) -> tuple[int, ...]:
        """Always ``()``."""
        return ()

    @property
    def scheme(self) -> str:
        """Name of the FHE scheme (``"CKKS"`` or ``"BFV"``)."""
        return self._scheme

    @property
    def global_scale(self) -> Any:
        """Encoding scale of the wrapped CKKS context.

        Raises:
            ValueError: If the scheme is not CKKS.
        """
        if self._scheme != "CKKS":
            raise ValueError("global_scale is only applicable for CKKS scheme.")
        return self.context.global_scale

    @property
    def is_private(self) -> Any:
        """Result of the wrapped context's ``is_private()``."""
        return self.context.is_private()

    @property
    def is_public(self) -> Any:
        """Result of the wrapped context's ``is_public()``."""
        return self.context.is_public()

    def make_context_public(self) -> None:
        """Remove the secret key from the wrapped TenSEAL context."""
        self.context.make_context_public()

    def serialize(self, save_secret_key: bool = True) -> Any:
        """Serialize the wrapped context.

        Public, Galois and relinearization keys are always requested; the
        secret key is requested only when ``save_secret_key`` is true.
        """
        return self.context.serialize(
            save_public_key=True,
            save_secret_key=save_secret_key,
            save_galois_keys=True,
            save_relin_keys=True,
        )

    def drop_secret_key(self) -> "FHEContext":
        """Return a new public-only ``FHEContext`` with the same scheme.

        The copy is built by serializing without the secret key and loading
        the result back with ``ts.context_from``.
        """
        proto = self.serialize(save_secret_key=False)
        return FHEContext(ts.context_from(proto), self._scheme)

    def __repr__(self) -> str:
        return f"FHEContext(scheme={self.scheme}, is_private={self.is_private}, is_public={self.is_public})"
93
-
94
-
95
class CipherText:
    """TenSEAL vector ciphertext plus the metadata needed to interpret it.

    Holds the raw TenSEAL object, the semantic dtype/shape of the plaintext
    it encrypts, the scheme name and, optionally, the ``FHEContext`` it was
    produced under. Only scalars (``()``) and 1D vectors are accepted; a
    ``semantic_shape`` with more than one dimension raises ``ValueError``.
    """

    def __init__(
        self,
        ct_data: Any,
        semantic_dtype: DType,
        semantic_shape: tuple[int, ...],
        scheme: str,
        context: FHEContext | None = None,
    ):
        rank = len(semantic_shape)
        if rank > 1:
            # The vector backend only handles 1D data.
            raise ValueError(
                f"FHE Vector backend only supports 1D data (scalars or vectors). "
                f"Got shape {semantic_shape}. Use fhe.py (tensor backend) for multi-dimensional data."
            )

        self.ct_data = ct_data
        self.semantic_dtype = semantic_dtype
        self.semantic_shape = semantic_shape
        self._scheme = scheme
        self._context = context

    @property
    def scheme(self) -> str:
        """Scheme name the ciphertext was encrypted under."""
        return self._scheme

    @property
    def context(self) -> FHEContext | None:
        """The ``FHEContext`` attached at construction, if any."""
        return self._context

    @property
    def dtype(self) -> Any:
        """Numpy dtype derived from ``semantic_dtype``."""
        return self.semantic_dtype.to_numpy()

    @property
    def shape(self) -> tuple[int, ...]:
        """Semantic shape: ``()`` for a scalar, ``(n,)`` for a vector."""
        return self.semantic_shape

    def __repr__(self) -> str:
        dtype, shape = self.semantic_dtype, self.semantic_shape
        return f"CipherText(dtype={dtype}, shape={shape}, scheme={self.scheme})"
140
-
141
-
142
def _convert_to_numpy(obj: TensorLike) -> np.ndarray:
    """Return ``obj`` as a numpy array.

    Numpy arrays are returned unchanged. If ``obj`` exposes a callable
    ``numpy`` attribute, its result is passed through ``np.asarray``. Should
    that call raise, the error is ignored and ``np.asarray(obj)`` is used as
    the fallback.
    """
    if isinstance(obj, np.ndarray):
        return obj

    to_numpy = getattr(obj, "numpy", None)
    if callable(to_numpy):
        try:
            return np.asarray(to_numpy())
        except Exception:
            # Deliberate best-effort: fall back to the generic conversion.
            pass

    return np.asarray(obj)
157
-
158
-
159
- def _validate_1d_shape(shape: tuple[int, ...], operation: str) -> None:
160
- """Validate that shape is 1D (scalar or vector) for vector backend operations."""
161
- if len(shape) > 1:
162
- raise ValueError(
163
- f"FHE Vector backend operation '{operation}' only supports 1D data. "
164
- f"Got shape {shape}. Use fhe.py (tensor backend) for multi-dimensional data."
165
- )
166
-
167
-
168
@kernel_def("fhe.keygen")
def _fhe_keygen(pfunc: PFunction) -> Any:
    """Generate TenSEAL contexts for vector FHE operations.

    Recognized ``pfunc.attrs`` keys (defaults in parentheses):
      - ``scheme`` ("CKKS"): either "CKKS" or "BFV".
      - ``poly_modulus_degree`` (8192).
      - ``coeff_mod_bit_sizes`` ([60, 40, 40, 60]).
      - ``global_scale`` (2**40): CKKS only.
      - ``plain_modulus`` (1032193): BFV only.

    Galois and relinearization keys are generated for both schemes.

    Returns:
        A tuple ``(private_context, public_context, eval_context)`` of
        ``FHEContext`` objects. The first holds the secret key, the second is
        a copy with the secret key dropped, and the third is the very same
        object as the second.

    Raises:
        ValueError: If ``scheme`` is neither "CKKS" nor "BFV".
        RuntimeError: If building the TenSEAL context fails.
    """
    attrs = pfunc.attrs
    scheme = attrs.get("scheme", "CKKS")
    poly_modulus_degree = attrs.get("poly_modulus_degree", 8192)

    if scheme == "CKKS":
        # CKKS: parameters for floating point operations.
        coeff_mod_bit_sizes = attrs.get("coeff_mod_bit_sizes", [60, 40, 40, 60])
        global_scale = attrs.get("global_scale", 2**40)
        try:
            tenseal_ctx = ts.context(
                ts.SCHEME_TYPE.CKKS,
                poly_modulus_degree=poly_modulus_degree,
                coeff_mod_bit_sizes=coeff_mod_bit_sizes,
            )
            tenseal_ctx.generate_galois_keys()
            tenseal_ctx.generate_relin_keys()
            tenseal_ctx.global_scale = global_scale
            return _fhe_context_triple(tenseal_ctx, "CKKS")
        except Exception as e:
            raise RuntimeError(f"Failed to create CKKS context: {e}") from e

    if scheme == "BFV":
        # BFV: parameters for integer operations.
        plain_modulus = attrs.get("plain_modulus", 1032193)
        coeff_mod_bit_sizes = attrs.get("coeff_mod_bit_sizes", [60, 40, 40, 60])
        try:
            tenseal_ctx = ts.context(
                ts.SCHEME_TYPE.BFV,
                poly_modulus_degree=poly_modulus_degree,
                plain_modulus=plain_modulus,
                coeff_mod_bit_sizes=coeff_mod_bit_sizes,
            )
            tenseal_ctx.generate_galois_keys()
            tenseal_ctx.generate_relin_keys()
            return _fhe_context_triple(tenseal_ctx, "BFV")
        except Exception as e:
            raise RuntimeError(f"Failed to create BFV context: {e}") from e

    raise ValueError(f"Unsupported FHE scheme: {scheme}. Use 'CKKS' or 'BFV'.")


def _fhe_context_triple(tenseal_ctx: Any, scheme: str) -> tuple:
    """Wrap a keyed TenSEAL context as ``(private, public, eval)`` FHEContexts."""
    private_ctx = FHEContext(tenseal_ctx, scheme=scheme)
    public_ctx = private_ctx.drop_secret_key()
    # For TenSEAL the evaluation context is the public context itself.
    return (private_ctx, public_ctx, public_ctx)
230
-
231
-
232
@kernel_def("fhe.encrypt")
def _fhe_encrypt(pfunc: PFunction, plaintext: Any, context: FHEContext) -> Any:
    """Encrypt 1D plaintext (scalar or vector) under an FHE vector context.

    The semantic dtype recorded on the ciphertext is:
      - CKKS: the input's floating dtype, or float32 for non-floating input.
      - BFV: always int64; non-integer input is rejected.

    A scalar input (``shape == ()``) is encrypted as a one-element vector,
    while the ciphertext keeps the semantic shape ``()``.

    Returns:
        A one-element tuple holding the resulting ``CipherText``.

    Raises:
        TypeError: If ``context`` is not an ``FHEContext``.
        RuntimeError: Wrapping any failure during conversion, validation
            (including multi-dimensional input) or encryption.
    """
    if not isinstance(context, FHEContext):
        raise TypeError(f"Expected FHEContext, got {type(context)}")

    try:
        data = _convert_to_numpy(plaintext)
        _validate_1d_shape(data.shape, "encrypt")

        scheme = context.scheme
        if scheme == "CKKS":
            # Keep float32/float64 inputs as-is; everything else is tracked as float32.
            src_dtype = (
                data.dtype
                if np.issubdtype(data.dtype, np.floating)
                else np.dtype("float32")
            )
            semantic_dtype = DType.from_numpy(src_dtype)
        else:  # BFV
            if not np.issubdtype(data.dtype, np.integer):
                raise RuntimeError("BFV scheme requires integer semantic_dtype")
            semantic_dtype = DType.from_numpy(np.dtype("int64"))

        # A scalar is lifted to a one-element vector for TenSEAL; the
        # ciphertext still records the semantic shape ().
        semantic_shape: tuple = data.shape
        if semantic_shape == ():
            data = np.array([data.item()])

        if scheme == "CKKS":
            ct_data = ts.ckks_vector(context.context, data.astype(np.float64).tolist())
        elif scheme == "BFV":
            ct_data = ts.bfv_vector(context.context, data.astype(np.int64).tolist())
        else:
            raise ValueError(f"Unsupported scheme: {context.scheme}")

        wrapped = CipherText(
            ct_data=ct_data,
            semantic_dtype=semantic_dtype,
            semantic_shape=semantic_shape,
            scheme=scheme,
            context=context,
        )
        return (wrapped,)

    except Exception as e:
        raise RuntimeError(f"FHE vector encryption failed: {e}") from e
290
-
291
-
292
@kernel_def("fhe.decrypt")
def _fhe_decrypt(pfunc: PFunction, ciphertext: CipherText, context: FHEContext) -> Any:
    """Decrypt a vector-backend ciphertext with a private FHE context.

    If the TenSEAL ciphertext reports no context, or a context without a
    secret key, ``link_context`` is called directly on ``ciphertext.ct_data``
    (no copy is made) to attach ``context.context`` before decrypting.

    The result dtype is the ciphertext's semantic dtype for CKKS and int64
    for BFV. A ciphertext whose semantic shape is ``()`` yields a 0-d array.

    Returns:
        A one-element tuple holding a ``TensorValue`` with the plaintext.

    Raises:
        TypeError: If ``ciphertext`` or ``context`` has the wrong type.
        ValueError: On scheme mismatch, or if ``context`` has no secret key.
        RuntimeError: Wrapping any failure during linking or decryption.
    """
    if not isinstance(ciphertext, CipherText):
        raise TypeError(f"Expected CipherText, got {type(ciphertext)}")
    if not isinstance(context, FHEContext):
        raise TypeError(f"Expected FHEContext, got {type(context)}")

    if ciphertext.scheme != context.scheme:
        raise ValueError(
            f"Scheme mismatch: ciphertext uses {ciphertext.scheme}, context uses {context.scheme}"
        )
    if not context.is_private:
        raise ValueError("Context must have secret key for decryption")

    try:
        ct = ciphertext.ct_data

        # Attach the private context when the ciphertext has none or only a
        # public one. If probing the attached context (or the first link)
        # raises, linking is attempted unconditionally.
        try:
            if not ct.context() or not ct.context().has_secret_key():
                ct.link_context(context.context)
        except Exception:
            ct.link_context(context.context)

        values = ct.decrypt()

        target_dtype = (
            ciphertext.semantic_dtype.to_numpy()
            if context.scheme == "CKKS"
            else np.int64
        )
        plain = np.array(values, dtype=target_dtype)

        if ciphertext.semantic_shape == ():
            # Scalars were encrypted as one-element vectors.
            plain = plain[0:1].reshape(())

        # Return TensorValue to adhere to the kernel Value I/O convention.
        return (TensorValue(np.asarray(plain)),)

    except Exception as e:
        raise RuntimeError(f"FHE vector decryption failed: {e}") from e
352
-
353
-
354
@kernel_def("fhe.add")
def _fhe_add(pfunc: PFunction, lhs: Any, rhs: Any) -> Any:
    """Homomorphic addition of two ciphertexts, or a ciphertext and a plaintext.

    Either operand may be the plaintext. Because addition is commutative, a
    plaintext on the left is handed to the helper after the ciphertext.

    Returns:
        A one-element tuple holding the resulting ``CipherText``.

    Raises:
        RuntimeError: Wrapping any failure, including the case where neither
            operand is a ``CipherText``.
    """
    try:
        lhs_is_ct = isinstance(lhs, CipherText)
        rhs_is_ct = isinstance(rhs, CipherText)
        if lhs_is_ct and rhs_is_ct:
            total = _fhe_add_ct2ct(lhs, rhs)
        elif lhs_is_ct:
            total = _fhe_add_ct2pt(lhs, rhs)
        elif rhs_is_ct:
            total = _fhe_add_ct2pt(rhs, lhs)
        else:
            raise ValueError("At least one operand must be CipherText")
        return (total,)
    except Exception as e:
        raise RuntimeError(f"FHE vector addition failed: {e}") from e
369
-
370
-
371
def _fhe_add_ct2ct(ct1: CipherText, ct2: CipherText) -> CipherText:
    """Element-wise sum of two vector-backend ciphertexts.

    Both operands must share a scheme and a semantic shape. The result takes
    its dtype, shape, scheme and context from ``ct1``.
    """
    if ct1.scheme != ct2.scheme:
        raise ValueError("CipherText operands must use same scheme")

    lhs_shape, rhs_shape = ct1.semantic_shape, ct2.semantic_shape
    if lhs_shape != rhs_shape:
        raise ValueError(
            f"CipherText operands must have same shape for vector addition. "
            f"Got {lhs_shape} and {rhs_shape}"
        )

    summed = ct1.ct_data + ct2.ct_data
    return CipherText(
        ct_data=summed,
        semantic_dtype=ct1.semantic_dtype,
        semantic_shape=lhs_shape,
        scheme=ct1.scheme,
        context=ct1.context,
    )
395
-
396
-
397
def _fhe_add_ct2pt(ciphertext: CipherText, plaintext: TensorLike) -> CipherText:
    """Add a plaintext to a vector-backend ciphertext.

    A scalar plaintext is broadcast across a vector ciphertext. A vector
    plaintext must match the ciphertext shape exactly and cannot be added to
    a scalar ciphertext. For BFV the plaintext must be integer-typed.
    """
    pt = _convert_to_numpy(plaintext)
    _validate_1d_shape(pt.shape, "add_plain")

    ct_shape = ciphertext.semantic_shape
    pt_is_scalar = pt.shape == ()
    if pt_is_scalar:
        # TenSEAL consumes lists, so lift the scalar into a one-element array.
        pt = np.array([pt.item()])

    if ct_shape == ():
        if not pt_is_scalar:
            raise ValueError(
                f"Shape mismatch: cannot add scalar ciphertext with vector plaintext {pt.shape}"
            )
    else:
        if pt_is_scalar:
            pt = np.full(ct_shape, pt[0])
        if pt.shape != ct_shape:
            raise ValueError(
                f"Shape mismatch: ciphertext shape {ct_shape} vs plaintext shape {pt.shape}"
            )

    scheme = ciphertext.scheme
    if scheme == "CKKS":
        operand = pt.astype(np.float64).tolist()
    elif scheme == "BFV":
        if not np.issubdtype(pt.dtype, np.integer):
            raise RuntimeError("BFV scheme requires integer plaintext")
        operand = pt.astype(np.int64).tolist()
    else:
        raise ValueError(f"Unsupported scheme: {ciphertext.scheme}")

    return CipherText(
        ct_data=ciphertext.ct_data + operand,
        semantic_dtype=ciphertext.semantic_dtype,
        semantic_shape=ct_shape,
        scheme=scheme,
        context=ciphertext.context,
    )
450
-
451
-
452
@kernel_def("fhe.sub")
def _fhe_sub(pfunc: PFunction, lhs: Any, rhs: Any) -> Any:
    """Homomorphic subtraction ``lhs - rhs``.

    ``lhs`` must be a ``CipherText``; ``rhs`` may be a ciphertext or a
    plaintext. Plaintext minus ciphertext is not supported.

    Returns:
        A one-element tuple holding the resulting ``CipherText``.

    Raises:
        RuntimeError: Wrapping any failure, including a non-ciphertext ``lhs``.
    """
    try:
        if not isinstance(lhs, CipherText):
            raise ValueError("Left operand must be CipherText for subtraction")
        if isinstance(rhs, CipherText):
            return (_fhe_sub_ct2ct(lhs, rhs),)
        return (_fhe_sub_ct2pt(lhs, rhs),)
    except Exception as e:
        raise RuntimeError(f"FHE vector subtraction failed: {e}") from e
465
-
466
-
467
def _fhe_sub_ct2ct(ct1: CipherText, ct2: CipherText) -> CipherText:
    """Element-wise difference ``ct1 - ct2`` of two vector-backend ciphertexts.

    Both operands must share a scheme and a semantic shape. The result takes
    its dtype, shape, scheme and context from ``ct1``.
    """
    if ct1.scheme != ct2.scheme:
        raise ValueError("CipherText operands must use same scheme")

    lhs_shape, rhs_shape = ct1.semantic_shape, ct2.semantic_shape
    if lhs_shape != rhs_shape:
        raise ValueError(
            f"CipherText operands must have same shape for vector subtraction. "
            f"Got {lhs_shape} and {rhs_shape}"
        )

    difference = ct1.ct_data - ct2.ct_data
    return CipherText(
        ct_data=difference,
        semantic_dtype=ct1.semantic_dtype,
        semantic_shape=lhs_shape,
        scheme=ct1.scheme,
        context=ct1.context,
    )
491
-
492
-
493
def _fhe_sub_ct2pt(ciphertext: CipherText, plaintext: TensorLike) -> CipherText:
    """Subtract a plaintext from a vector-backend ciphertext.

    A scalar plaintext is broadcast across a vector ciphertext. A vector
    plaintext must match the ciphertext shape exactly and cannot be
    subtracted from a scalar ciphertext. For BFV the plaintext must be
    integer-typed.
    """
    pt = _convert_to_numpy(plaintext)
    _validate_1d_shape(pt.shape, "sub_plain")

    ct_shape = ciphertext.semantic_shape
    pt_is_scalar = pt.shape == ()
    if pt_is_scalar:
        # TenSEAL consumes lists, so lift the scalar into a one-element array.
        pt = np.array([pt.item()])

    if ct_shape == ():
        if not pt_is_scalar:
            raise ValueError(
                "Shape mismatch: cannot subtract vector plaintext from scalar ciphertext"
            )
    else:
        if pt_is_scalar:
            pt = np.full(ct_shape, pt[0])
        if pt.shape != ct_shape:
            raise ValueError(
                f"Shape mismatch: ciphertext shape {ct_shape} vs plaintext shape {pt.shape}"
            )

    scheme = ciphertext.scheme
    if scheme == "CKKS":
        operand = pt.astype(np.float64).tolist()
    elif scheme == "BFV":
        if not np.issubdtype(pt.dtype, np.integer):
            raise RuntimeError("BFV scheme requires integer plaintext")
        operand = pt.astype(np.int64).tolist()
    else:
        raise ValueError(f"Unsupported scheme: {ciphertext.scheme}")

    return CipherText(
        ct_data=ciphertext.ct_data - operand,
        semantic_dtype=ciphertext.semantic_dtype,
        semantic_shape=ct_shape,
        scheme=scheme,
        context=ciphertext.context,
    )
545
-
546
-
547
- @kernel_def("fhe.mul")
548
- def _fhe_mul(pfunc: PFunction, lhs: Any, rhs: Any) -> Any:
549
- """Perform homomorphic multiplication between ciphertexts or ciphertext and plaintext."""
550
- try:
551
- if isinstance(lhs, CipherText) and isinstance(rhs, CipherText):
552
- result = _fhe_mul_ct2ct(lhs, rhs)
553
- elif isinstance(lhs, CipherText):
554
- result = _fhe_mul_ct2pt(lhs, rhs)
555
- elif isinstance(rhs, CipherText):
556
- result = _fhe_mul_ct2pt(rhs, lhs)
557
- else:
558
- raise ValueError("At least one operand must be CipherText")
559
- return (result,)
560
- except Exception as e:
561
- raise RuntimeError(f"FHE vector multiplication failed: {e}") from e
562
-
563
-
564
def _fhe_mul_ct2ct(ct1: CipherText, ct2: CipherText) -> CipherText:
    """Element-wise product of two vector-backend ciphertexts.

    Both operands must share a scheme and a semantic shape. The result takes
    its dtype, shape, scheme and context from ``ct1``.
    """
    if ct1.scheme != ct2.scheme:
        raise ValueError("CipherText operands must use same scheme")

    lhs_shape, rhs_shape = ct1.semantic_shape, ct2.semantic_shape
    if lhs_shape != rhs_shape:
        raise ValueError(
            f"CipherText operands must have same shape for vector multiplication. "
            f"Got {lhs_shape} and {rhs_shape}"
        )

    product = ct1.ct_data * ct2.ct_data
    return CipherText(
        ct_data=product,
        semantic_dtype=ct1.semantic_dtype,
        semantic_shape=lhs_shape,
        scheme=ct1.scheme,
        context=ct1.context,
    )
588
-
589
-
590
def _fhe_mul_ct2pt(ciphertext: CipherText, plaintext: TensorLike) -> CipherText:
    """Multiply a vector-backend ciphertext by a plaintext, element-wise.

    A scalar plaintext is broadcast across a vector ciphertext. A vector
    plaintext must match the ciphertext shape exactly and cannot multiply a
    scalar ciphertext. For BFV the plaintext must be integer-typed.
    """
    pt = _convert_to_numpy(plaintext)
    _validate_1d_shape(pt.shape, "mul_plain")

    ct_shape = ciphertext.semantic_shape
    pt_is_scalar = pt.shape == ()
    if pt_is_scalar:
        # TenSEAL consumes lists, so lift the scalar into a one-element array.
        pt = np.array([pt.item()])

    if ct_shape == ():
        if not pt_is_scalar:
            raise ValueError(
                "Shape mismatch: cannot multiply scalar ciphertext with vector plaintext"
            )
    else:
        if pt_is_scalar:
            pt = np.full(ct_shape, pt[0])
        if pt.shape != ct_shape:
            raise ValueError(
                f"Shape mismatch: ciphertext shape {ct_shape} vs plaintext shape {pt.shape}"
            )

    scheme = ciphertext.scheme
    if scheme == "CKKS":
        operand = pt.astype(np.float64).tolist()
    elif scheme == "BFV":
        if not np.issubdtype(pt.dtype, np.integer):
            raise RuntimeError("BFV scheme requires integer plaintext")
        operand = pt.astype(np.int64).tolist()
    else:
        raise ValueError(f"Unsupported scheme: {ciphertext.scheme}")

    return CipherText(
        ct_data=ciphertext.ct_data * operand,
        semantic_dtype=ciphertext.semantic_dtype,
        semantic_shape=ct_shape,
        scheme=scheme,
        context=ciphertext.context,
    )
642
-
643
-
644
@kernel_def("fhe.dot")
def _fhe_dot(pfunc: PFunction, lhs: Any, rhs: Any) -> Any:
    """Homomorphic dot product of two 1D operands, at least one encrypted.

    When only one side is a CipherText, the plaintext side is passed as the
    second argument of ``_fhe_dot_ct2pt``. This relies on the dot product
    being commutative.

    Returns:
        A 1-tuple holding the scalar (shape ``()``) CipherText result.

    Raises:
        RuntimeError: Wraps any failure, including the case where neither
            operand is a CipherText.
    """
    try:
        lhs_encrypted = isinstance(lhs, CipherText)
        rhs_encrypted = isinstance(rhs, CipherText)
        if not (lhs_encrypted or rhs_encrypted):
            raise ValueError("At least one operand must be CipherText")

        if lhs_encrypted and rhs_encrypted:
            out = _fhe_dot_ct2ct(lhs, rhs)
        elif lhs_encrypted:
            out = _fhe_dot_ct2pt(lhs, rhs)
        else:
            out = _fhe_dot_ct2pt(rhs, lhs)
    except Exception as exc:
        raise RuntimeError(f"FHE vector dot product failed: {exc}") from exc
    return (out,)
662
-
663
-
664
def _fhe_dot_ct2ct(ct1: CipherText, ct2: CipherText) -> CipherText:
    """Dot product of two encrypted 1D vectors of equal length.

    Checks run in this order: same scheme, neither operand scalar, both
    operands 1D, equal lengths.

    Returns:
        A scalar (shape ``()``) CipherText. Its dtype, scheme and context are
        taken from ``ct1``.

    Raises:
        ValueError: If any of the checks above fails.
    """
    if ct1.scheme != ct2.scheme:
        raise ValueError("CipherText operands must use same scheme")

    lhs_shape = ct1.semantic_shape
    rhs_shape = ct2.semantic_shape

    # Scalars are rejected explicitly before the generic rank check so the
    # caller gets a more specific message.
    if lhs_shape == () or rhs_shape == ():
        raise ValueError("Dot product requires 1D vectors, not scalars")

    if (len(lhs_shape), len(rhs_shape)) != (1, 1):
        raise ValueError(
            f"Vector backend dot product only supports 1D X 1D. "
            f"Got shapes {lhs_shape} and {rhs_shape}"
        )

    lhs_len, rhs_len = lhs_shape[0], rhs_shape[0]
    if lhs_len != rhs_len:
        raise ValueError(
            f"Dot product dimension mismatch: {lhs_len} vs {rhs_len}"
        )

    return CipherText(
        ct_data=ct1.ct_data.dot(ct2.ct_data),
        semantic_dtype=ct1.semantic_dtype,
        semantic_shape=(),  # a dot product collapses to a scalar
        scheme=ct1.scheme,
        context=ct1.context,
    )
696
-
697
-
698
def _fhe_dot_ct2pt(ciphertext: CipherText, plaintext: TensorLike) -> CipherText:
    """Dot product of an encrypted 1D vector with a plaintext 1D vector.

    Args:
        ciphertext: Encrypted 1D vector (scalars rejected).
        plaintext: Plain 1D operand, converted via ``_convert_to_numpy``.

    Returns:
        A scalar (shape ``()``) CipherText. Its dtype, scheme and context are
        copied from ``ciphertext``.

    Raises:
        ValueError: On scalar or non-1D operands, a length mismatch, or a
            scheme other than CKKS/BFV.
        RuntimeError: If the scheme is BFV and the plaintext dtype is not an
            integer type.
    """
    pt = _convert_to_numpy(plaintext)
    ct_shape = ciphertext.semantic_shape

    if ct_shape == ():
        raise ValueError("Dot product requires 1D vector ciphertext, not scalar")

    _validate_1d_shape(pt.shape, "dot_plain")

    if pt.shape == ():
        raise ValueError("Dot product requires 1D vector plaintext, not scalar")

    if (len(ct_shape), len(pt.shape)) != (1, 1):
        raise ValueError(
            f"Vector backend dot product only supports 1D X 1D. "
            f"Got shapes {ct_shape} and {pt.shape}"
        )

    ct_len, pt_len = ct_shape[0], pt.shape[0]
    if ct_len != pt_len:
        raise ValueError(
            f"Dot product dimension mismatch: {ct_len} vs {pt_len}"
        )

    # The backend takes a plain Python list; the element type depends on the scheme.
    scheme = ciphertext.scheme
    if scheme == "BFV":
        if not np.issubdtype(pt.dtype, np.integer):
            raise RuntimeError("BFV scheme requires integer plaintext")
        weights = pt.astype(np.int64).tolist()
    elif scheme == "CKKS":
        weights = pt.astype(np.float64).tolist()
    else:
        raise ValueError(f"Unsupported scheme: {ciphertext.scheme}")

    return CipherText(
        ct_data=ciphertext.ct_data.dot(weights),
        semantic_dtype=ciphertext.semantic_dtype,
        semantic_shape=(),  # a dot product collapses to a scalar
        scheme=ciphertext.scheme,
        context=ciphertext.context,
    )
743
-
744
-
745
@kernel_def("fhe.polyval")
def _fhe_polyval(pfunc: PFunction, ciphertext: CipherText, coeffs: TensorLike) -> Any:
    """Evaluate polynomial on encrypted vector data with plaintext coefficients.

    Args:
        pfunc: Kernel descriptor (unused by this implementation).
        ciphertext: Encrypted data (CipherText, scalar or 1D vector)
        coeffs: Plaintext polynomial coefficients as 1D array [c0, c1, c2, ...]
            representing c0 + c1*x + c2*x^2 + ...

    Returns:
        A 1-tuple holding a CipherText with the polynomial evaluation result.
        The result has the same dtype/shape/scheme/context as the input.

    Raises:
        TypeError: If ``ciphertext`` is not a CipherText. This is raised before
            the try block, so it is not wrapped.
        RuntimeError: Wraps any failure during validation or evaluation. This
            covers non-1D, empty or single-element coefficients, non-integer
            coefficients under BFV, and schemes other than CKKS/BFV.

    Note:
        TenSEAL has a known issue with constant polynomials (degree 0, single coefficient).
        For constants, consider using scalar multiplication instead: ct * 0 + constant.
    """
    if not isinstance(ciphertext, CipherText):
        raise TypeError(f"Expected CipherText, got {type(ciphertext)}")

    try:
        # Convert and validate coefficients
        coeffs_np = _convert_to_numpy(coeffs)

        if coeffs_np.ndim != 1:
            raise ValueError(
                f"Polynomial coefficients must be 1D array, got shape {coeffs_np.shape}"
            )

        if len(coeffs_np) == 0:
            raise ValueError("Polynomial coefficients cannot be empty")

        # Check for constant polynomial (TenSEAL limitation)
        if len(coeffs_np) == 1:
            raise ValueError(
                "TenSEAL does not support constant polynomials (degree 0). "
                "For constant values, use scalar multiplication instead: ct * 0 + constant"
            )

        # Validate scheme-specific requirements. Unknown schemes are rejected
        # explicitly rather than silently treated as CKKS, consistent with
        # the mul/dot kernels.
        if ciphertext.scheme == "BFV":
            if not np.issubdtype(coeffs_np.dtype, np.integer):
                raise RuntimeError(
                    "BFV scheme requires integer polynomial coefficients"
                )
            coeffs_list = coeffs_np.astype(np.int64).tolist()
        elif ciphertext.scheme == "CKKS":
            coeffs_list = coeffs_np.astype(np.float64).tolist()
        else:
            raise ValueError(f"Unsupported scheme: {ciphertext.scheme}")

        # Perform polynomial evaluation
        result_ct_data = ciphertext.ct_data.polyval(coeffs_list)

        # Create result CipherText (same shape as input)
        return (
            CipherText(
                ct_data=result_ct_data,
                semantic_dtype=ciphertext.semantic_dtype,
                semantic_shape=ciphertext.semantic_shape,
                scheme=ciphertext.scheme,
                context=ciphertext.context,
            ),
        )

    except Exception as e:
        raise RuntimeError(f"FHE vector polyval failed: {e}") from e
809
-
810
-
811
@kernel_def("fhe.negate")
def _fhe_negate(pfunc: PFunction, ciphertext: CipherText) -> Any:
    """Homomorphically negate encrypted data (element-wise ``-x``).

    Returns:
        A 1-tuple holding a new CipherText. Its dtype, shape, scheme and
        context are copied from the input.

    Raises:
        TypeError: If ``ciphertext`` is not a CipherText (not wrapped).
        RuntimeError: Wraps any failure raised while negating.
    """
    if not isinstance(ciphertext, CipherText):
        raise TypeError(f"Expected CipherText, got {type(ciphertext)}")

    try:
        negated = CipherText(
            ct_data=-ciphertext.ct_data,
            semantic_dtype=ciphertext.semantic_dtype,
            semantic_shape=ciphertext.semantic_shape,
            scheme=ciphertext.scheme,
            context=ciphertext.context,
        )
    except Exception as e:
        raise RuntimeError(f"FHE vector negation failed: {e}") from e
    return (negated,)
834
-
835
-
836
@kernel_def("fhe.square")
def _fhe_square(pfunc: PFunction, ciphertext: CipherText) -> Any:
    """Homomorphically square encrypted data (element-wise ``x**2``).

    Returns:
        A 1-tuple holding a new CipherText. Its dtype, shape, scheme and
        context are copied from the input.

    Raises:
        TypeError: If ``ciphertext`` is not a CipherText (not wrapped).
        RuntimeError: Wraps any failure raised while squaring.
    """
    if not isinstance(ciphertext, CipherText):
        raise TypeError(f"Expected CipherText, got {type(ciphertext)}")

    try:
        # Uses the backend's ``**`` operator rather than ``x * x``, matching
        # the original kernel.
        squared = CipherText(
            ct_data=ciphertext.ct_data ** 2,
            semantic_dtype=ciphertext.semantic_dtype,
            semantic_shape=ciphertext.semantic_shape,
            scheme=ciphertext.scheme,
            context=ciphertext.context,
        )
    except Exception as e:
        raise RuntimeError(f"FHE vector square failed: {e}") from e
    return (squared,)