mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
mplang/v1/kernels/phe.py DELETED
@@ -1,1864 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """PHE (Partially Homomorphic Encryption) backend implementation using lightPHE."""
16
-
17
- from __future__ import annotations
18
-
19
- import json
20
- from typing import Any, ClassVar
21
-
22
- import numpy as np
23
- from lightphe import LightPHE
24
- from lightphe.models.Ciphertext import Ciphertext
25
-
26
- from mplang.v1.core import DType, PFunction
27
- from mplang.v1.kernels.base import kernel_def
28
- from mplang.v1.kernels.value import (
29
- TensorValue,
30
- Value,
31
- ValueDecodeError,
32
- ValueProtoBuilder,
33
- ValueProtoReader,
34
- register_value,
35
- )
36
- from mplang.v1.protos.v1alpha1 import value_pb2 as _value_pb2
37
-
38
- # This controls the decimal precision used in lightPHE for float operations
39
- # we force it to 0 to only support integer operations
40
- # we will support negative and floating-point with our own encoding/decoding
41
- PRECISION = 0
42
-
43
-
44
@register_value
class PublicKey(Value):
    """Wire-serializable PHE public key.

    Bundles the raw public key material with the range-encoding parameters
    (``max_value``, ``fxp_bits``, ``modulus``) that are stored next to it.
    """

    KIND: ClassVar[str] = "mplang.phe.PublicKey"
    WIRE_VERSION: ClassVar[int] = 1

    def __init__(
        self,
        key_data: Any,
        scheme: str,
        key_size: int,
        max_value: int = 2**100,
        fxp_bits: int = 12,
        modulus: int | None = None,
    ):
        self.key_data = key_data
        self.scheme = scheme
        self.key_size = key_size
        # Range-encoding parameters.
        self.max_value = max_value  # bound B on the absolute encoded value
        self.fxp_bits = fxp_bits  # fractional bits used for fixed-point floats
        self.modulus = modulus  # Paillier modulus N (may be None)

    @property
    def dtype(self) -> Any:
        """Numpy object dtype; the key is carried as an opaque payload."""
        return np.dtype("O")

    @property
    def shape(self) -> tuple[int, ...]:
        """Keys are scalar values, so the shape is empty."""
        return ()

    @property
    def max_float_value(self) -> float:
        """Largest float magnitude representable: ``max_value / 2**fxp_bits``."""
        scale = 2**self.fxp_bits
        return float(self.max_value / scale)

    def to_proto(self) -> _value_pb2.ValueProto:
        """Serialize the key: metadata as attrs, key material as a JSON payload."""
        # The modulus travels as a decimal string; "" stands for None.
        # NOTE(review): max_value is passed as a raw int although its default
        # is 2**100 - confirm that set_attr accepts integers of that size.
        modulus_text = "" if self.modulus is None else str(self.modulus)
        attrs = (
            ("scheme", self.scheme),
            ("key_size", self.key_size),
            ("max_value", self.max_value),
            ("fxp_bits", self.fxp_bits),
            ("modulus", modulus_text),
        )

        builder = ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
        for attr_name, attr_value in attrs:
            builder = builder.set_attr(attr_name, attr_value)

        payload = json.dumps(self.key_data).encode("utf-8")
        return builder.set_payload(payload).build()

    @classmethod
    def from_proto(cls, proto: _value_pb2.ValueProto) -> PublicKey:
        """Rebuild a PublicKey from its wire format.

        Raises:
            ValueDecodeError: if the wire version is not ``WIRE_VERSION``.
        """
        reader = ValueProtoReader(proto)
        if reader.version != cls.WIRE_VERSION:
            raise ValueDecodeError(f"Unsupported PublicKey version {reader.version}")

        # Metadata attributes, read in the order they were written.
        meta = {
            attr_name: reader.get_attr(attr_name)
            for attr_name in ("scheme", "key_size", "max_value", "fxp_bits")
        }
        modulus_text = reader.get_attr("modulus")
        modulus = int(modulus_text) if modulus_text != "" else None

        # The payload is the JSON-encoded public key dict.
        key_data = json.loads(reader.payload.decode("utf-8"))

        return cls(key_data=key_data, modulus=modulus, **meta)

    def __repr__(self) -> str:
        described = ", ".join((
            f"scheme={self.scheme}",
            f"key_size={self.key_size}",
            f"max_value={self.max_value}",
            f"fxp_bits={self.fxp_bits}",
        ))
        return f"PublicKey({described})"
122
-
123
-
124
@register_value
class PrivateKey(Value):
    """Wire-serializable PHE private key.

    Holds the private key material together with the matching public key
    material and the range-encoding parameters (``max_value``, ``fxp_bits``,
    ``modulus``).
    """

    KIND: ClassVar[str] = "mplang.phe.PrivateKey"
    WIRE_VERSION: ClassVar[int] = 1

    def __init__(
        self,
        sk_data: Any,
        pk_data: Any,
        scheme: str,
        key_size: int,
        max_value: int = 2**100,
        fxp_bits: int = 12,
        modulus: int | None = None,
    ):
        # Both halves of the key pair are kept on the private key object.
        self.sk_data = sk_data
        self.pk_data = pk_data
        self.scheme = scheme
        self.key_size = key_size
        # Range-encoding parameters.
        self.max_value = max_value  # bound B on the absolute encoded value
        self.fxp_bits = fxp_bits  # fractional bits used for fixed-point floats
        self.modulus = modulus  # Paillier modulus N (may be None)

    @property
    def dtype(self) -> Any:
        """Numpy object dtype; the key is carried as an opaque payload."""
        return np.dtype("O")

    @property
    def shape(self) -> tuple[int, ...]:
        """Keys are scalar values, so the shape is empty."""
        return ()

    @property
    def max_float_value(self) -> float:
        """Largest float magnitude representable: ``max_value / 2**fxp_bits``."""
        scale = 2**self.fxp_bits
        return float(self.max_value / scale)

    def to_proto(self) -> _value_pb2.ValueProto:
        """Serialize the key: metadata as attrs, both key dicts as one JSON payload."""
        # The modulus travels as a decimal string; "" stands for None.
        modulus_text = "" if self.modulus is None else str(self.modulus)
        attrs = (
            ("scheme", self.scheme),
            ("key_size", self.key_size),
            ("max_value", self.max_value),
            ("fxp_bits", self.fxp_bits),
            ("modulus", modulus_text),
        )

        builder = ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
        for attr_name, attr_value in attrs:
            builder = builder.set_attr(attr_name, attr_value)

        # A single dict holds both keys, so no length metadata is needed to
        # split the payload again.
        payload = json.dumps({"sk": self.sk_data, "pk": self.pk_data})
        return builder.set_payload(payload.encode("utf-8")).build()

    @classmethod
    def from_proto(cls, proto: _value_pb2.ValueProto) -> PrivateKey:
        """Rebuild a PrivateKey from its wire format.

        Raises:
            ValueDecodeError: if the wire version is not ``WIRE_VERSION``.
        """
        reader = ValueProtoReader(proto)
        if reader.version != cls.WIRE_VERSION:
            raise ValueDecodeError(f"Unsupported PrivateKey version {reader.version}")

        # Metadata attributes, read in the order they were written.
        meta = {
            attr_name: reader.get_attr(attr_name)
            for attr_name in ("scheme", "key_size", "max_value", "fxp_bits")
        }
        modulus_text = reader.get_attr("modulus")
        modulus = int(modulus_text) if modulus_text != "" else None

        # The payload is a JSON dict with the "sk" and "pk" key dicts.
        key_pair = json.loads(reader.payload.decode("utf-8"))
        sk_data = key_pair["sk"]
        pk_data = key_pair["pk"]

        return cls(sk_data=sk_data, pk_data=pk_data, modulus=modulus, **meta)

    def __repr__(self) -> str:
        described = ", ".join((
            f"scheme={self.scheme}",
            f"key_size={self.key_size}",
            f"max_value={self.max_value}",
            f"fxp_bits={self.fxp_bits}",
        ))
        return f"PrivateKey({described})"
211
-
212
-
213
@register_value
class CipherText(Value):
    """Wire-serializable PHE ciphertext.

    ``ct_data`` is a flat list of ``lightphe`` ``Ciphertext`` objects, one per
    element; ``semantic_dtype`` and ``semantic_shape`` describe the plaintext
    tensor they stand for. The range-encoding parameters (``max_value``,
    ``fxp_bits``, ``modulus``) and the public key material travel with it.
    """

    KIND: ClassVar[str] = "mplang.phe.CipherText"
    WIRE_VERSION: ClassVar[int] = 1

    def __init__(
        self,
        ct_data: Any,
        semantic_dtype: DType,
        semantic_shape: tuple[int, ...],
        scheme: str,
        key_size: int,
        pk_data: Any = None,  # Store public key for operations
        max_value: int = 2**100,
        fxp_bits: int = 12,
        modulus: int | None = None,
    ):
        self.ct_data = ct_data
        self.semantic_dtype = semantic_dtype
        self.semantic_shape = semantic_shape
        self.scheme = scheme
        self.key_size = key_size
        self.pk_data = pk_data
        self.max_value = max_value
        self.fxp_bits = fxp_bits
        self.modulus = modulus

    @property
    def dtype(self) -> Any:
        """Numpy dtype of the plaintext tensor this ciphertext represents."""
        return self.semantic_dtype.to_numpy()

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the plaintext tensor this ciphertext represents."""
        return self.semantic_shape

    @property
    def max_float_value(self) -> float:
        """Maximum float value that can be encoded."""
        return float(self.max_value / (2**self.fxp_bits))

    def to_proto(self) -> _value_pb2.ValueProto:
        """Serialize CipherText to wire format.

        WARNING: This serialization is tightly coupled to lightphe.Ciphertext
        internal attributes (value, algorithm_name, keys). Any changes to these
        attributes in future lightphe versions will break serialization.

        TODO: Check if lightphe provides official serialization methods and
        migrate to them if available. Consider adding version compatibility checks.

        Raises:
            ValueError: if ``ct_data`` is not a list.
            TypeError: if an element of ``ct_data`` is not a lightphe Ciphertext.
        """
        if not isinstance(self.ct_data, list):
            raise ValueError(f"ct_data should be a list, got {type(self.ct_data)}")

        # Each lightphe Ciphertext is flattened to the three attributes needed
        # to rebuild it in from_proto.
        ct_list = []
        for ct in self.ct_data:
            if not isinstance(ct, Ciphertext):
                raise TypeError(
                    f"ct_data must contain lightphe.Ciphertext objects, got {type(ct).__name__}"
                )
            ct_list.append({
                "value": ct.value,
                "algorithm_name": ct.algorithm_name,
                "keys": ct.keys,
            })

        # Ciphertexts and the (possibly None) public key share one JSON payload.
        payload_dict = {"ct_list": ct_list, "pk": self.pk_data}

        return (
            ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
            .set_attr("semantic_dtype", str(self.semantic_dtype))
            .set_attr("semantic_shape", list(self.semantic_shape))
            .set_attr("scheme", self.scheme)
            .set_attr("key_size", self.key_size)
            .set_attr("max_value", self.max_value)
            .set_attr("fxp_bits", self.fxp_bits)
            # The modulus travels as a decimal string; "" stands for None.
            .set_attr("modulus", str(self.modulus) if self.modulus is not None else "")
            .set_payload(json.dumps(payload_dict).encode("utf-8"))
            .build()
        )

    @classmethod
    def from_proto(cls, proto: _value_pb2.ValueProto) -> CipherText:
        """Deserialize CipherText from wire format.

        Raises:
            ValueDecodeError: if the wire version is not ``WIRE_VERSION``, or a
                serialized ciphertext entry lacks ``keys`` or ``algorithm_name``
                (absent or null).
        """
        reader = ValueProtoReader(proto)
        if reader.version != cls.WIRE_VERSION:
            raise ValueDecodeError(f"Unsupported CipherText version {reader.version}")

        # Read metadata from runtime_attrs
        semantic_dtype_str = reader.get_attr("semantic_dtype")
        semantic_shape = reader.get_attr("semantic_shape")
        scheme = reader.get_attr("scheme")
        key_size = reader.get_attr("key_size")
        max_value = reader.get_attr("max_value")
        fxp_bits = reader.get_attr("fxp_bits")
        modulus_str = reader.get_attr("modulus")
        modulus = None if modulus_str == "" else int(modulus_str)

        # JSON deserialize ciphertext and public key
        payload_dict = json.loads(reader.payload.decode("utf-8"))
        ct_list = payload_dict["ct_list"]
        pk_data = payload_dict["pk"]

        # Reconstruct ct_data: list of Ciphertext objects
        ct_data = []
        for ct_dict in ct_list:
            # .get() so that an absent entry is reported the same way as a
            # null one, instead of escaping as a bare KeyError.
            keys = ct_dict.get("keys")
            algorithm_name = ct_dict.get("algorithm_name")
            if keys is None or algorithm_name is None:
                raise ValueDecodeError(
                    "Invalid CipherText: missing keys or algorithm_name in serialized data"
                )
            ct_data.append(
                Ciphertext(
                    algorithm_name=algorithm_name,
                    keys=keys,
                    value=ct_dict["value"],
                )
            )

        # Parse dtype string back to DType
        dtype = DType.from_any(semantic_dtype_str)

        return cls(
            ct_data=ct_data,
            semantic_dtype=dtype,
            semantic_shape=tuple(semantic_shape),
            scheme=scheme,
            key_size=key_size,
            pk_data=pk_data,
            max_value=max_value,
            fxp_bits=fxp_bits,
            modulus=modulus,
        )

    def __repr__(self) -> str:
        return f"CipherText(dtype={self.semantic_dtype}, shape={self.semantic_shape}, scheme={self.scheme})"
357
-
358
-
359
- # Range-based encoding functions for negative numbers and floats
360
- def _range_encode_integer(value: int, max_value: int, modulus: int) -> int:
361
- """
362
- Range encoding function for integers.
363
- - Positive numbers: encode(m) = m
364
- - Negative numbers: encode(m) = N + m
365
- """
366
- if not (-max_value <= value <= max_value):
367
- raise ValueError(
368
- f"Integer value {value} out of range [-{max_value}, {max_value}]"
369
- )
370
-
371
- if value >= 0:
372
- encoded = value
373
- else:
374
- encoded = modulus + value
375
-
376
- return encoded
377
-
378
-
379
- def _range_encode_float(
380
- value: float, max_value: int, fxp_bits: int, modulus: int
381
- ) -> int:
382
- """
383
- Range encoding function for floats.
384
- 1. Fixed-point conversion: scaled_int = round(value * 2^fxp_bits)
385
- 2. Integer encoding rules
386
- """
387
- max_float = max_value / (2**fxp_bits)
388
- if not (-max_float <= value <= max_float):
389
- raise ValueError(
390
- f"Float value {value} out of range [-{max_float}, {max_float}]"
391
- )
392
-
393
- # Fixed-point encoding: float → scaled integer
394
- scaled_int = round(value * (2**fxp_bits))
395
-
396
- # Use integer encoding rules
397
- return _range_encode_integer(scaled_int, max_value, modulus)
398
-
399
-
400
def _range_encode_mixed(
    value: Any, max_value: int, fxp_bits: int, modulus: int, semantic_dtype: DType
) -> int:
    """Range-encode ``value`` according to its semantic dtype.

    Floating semantic dtypes go through the fixed-point float encoder (after
    ``float(value)``); every other dtype goes through the integer encoder
    (after ``int(value)``).
    """
    if not semantic_dtype.is_floating:
        return _range_encode_integer(int(value), max_value, modulus)
    return _range_encode_float(float(value), max_value, fxp_bits, modulus)
413
-
414
-
415
- def _range_decode_integer(encoded_value: int, max_value: int, modulus: int) -> int:
416
- """
417
- Range decoding function for integers.
418
- - If r <= max_value: decode(r) = r
419
- - If r >= N - max_value: decode(r) = r - N
420
- - If max_value < r < N - max_value: overflow error
421
- """
422
-
423
- # Ensure handling integer
424
- if isinstance(encoded_value, (list, tuple)):
425
- encoded_value = encoded_value[0]
426
- encoded_value = int(encoded_value) % modulus
427
-
428
- if encoded_value <= max_value:
429
- return encoded_value
430
- elif encoded_value >= modulus - max_value:
431
- return encoded_value - modulus
432
- else:
433
- raise ValueError(f"Decoded value {encoded_value} is in overflow region")
434
-
435
-
436
def _range_decode_float(
    encoded_value: int, max_value: int, fxp_bits: int, modulus: int
) -> float:
    """Invert the float range encoding.

    The value is first decoded as a signed integer, then divided by
    ``2**fxp_bits`` to undo the fixed-point scaling.
    """
    signed_fixed_point = _range_decode_integer(encoded_value, max_value, modulus)
    scale = 2**fxp_bits
    return float(signed_fixed_point / scale)
449
-
450
-
451
def _range_decode_mixed(
    encoded_value: int,
    max_value: int,
    fxp_bits: int,
    modulus: int,
    semantic_dtype: DType,
) -> Any:
    """Range-decode ``encoded_value`` according to its semantic dtype.

    Floating semantic dtypes are decoded as fixed-point floats; every other
    dtype is decoded as a signed integer.
    """
    if not semantic_dtype.is_floating:
        return _range_decode_integer(encoded_value, max_value, modulus)
    return _range_decode_float(encoded_value, max_value, fxp_bits, modulus)
468
-
469
-
470
@kernel_def("phe.keygen")
def _phe_keygen(pfunc: PFunction) -> Any:
    """Generate a PHE key pair and return ``[public_key, private_key]``.

    Attributes read from ``pfunc.attrs`` (with defaults): ``scheme``
    ("paillier"), ``key_size`` (2048), ``max_value`` (2**32) and ``fxp_bits``
    (12). ``max_value`` and ``fxp_bits`` are passed through ``int()``.

    Raises:
        ValueError: if the scheme is not Paillier.
        RuntimeError: if anything fails during key generation, including a
            modulus that is not larger than ``3 * max_value``.
    """
    attrs = pfunc.attrs
    scheme = attrs.get("scheme", "paillier")
    key_size = attrs.get("key_size", 2048)
    max_value = int(attrs.get("max_value", 2**32))
    fxp_bits = int(attrs.get("fxp_bits", 12))

    # Paillier is the only scheme accepted here.
    if scheme.lower() != "paillier":
        raise ValueError(f"Unsupported PHE scheme: {scheme}")

    scheme = scheme.capitalize()

    try:
        # PRECISION is forwarded unchanged as lightPHE's precision setting.
        phe = LightPHE(
            algorithm_name=scheme,
            key_size=key_size,
            precision=PRECISION,
        )

        generated = phe.cs.keys
        pk_data = generated["public_key"]
        sk_data = generated["private_key"]
        modulus = phe.cs.plaintext_modulo  # Paillier modulus N

        # Safety margin for the range encoding: N must exceed 3 * max_value.
        if modulus <= 3 * max_value:
            raise ValueError(
                f"Modulus {modulus} is too small for max_value {max_value}. Require N >> 3*B"
            )

        # Both key objects carry the same scheme and encoding parameters.
        shared = {
            "scheme": scheme,
            "key_size": key_size,
            "max_value": max_value,
            "fxp_bits": fxp_bits,
            "modulus": modulus,
        }
        public_key = PublicKey(key_data=pk_data, **shared)
        private_key = PrivateKey(sk_data=sk_data, pk_data=pk_data, **shared)

        return [public_key, private_key]

    except Exception as e:
        raise RuntimeError(f"Failed to generate PHE keys: {e}") from e
526
-
527
-
528
@kernel_def("phe.encrypt")
def _phe_encrypt(
    pfunc: PFunction, plaintext: TensorValue, public_key: PublicKey
) -> Any:
    """Encrypt a plaintext tensor element-wise and return ``[CipherText]``.

    Every element is range-encoded (integer or fixed-point float, chosen by
    the tensor's dtype) and then encrypted on its own. The resulting
    CipherText records the plaintext dtype/shape and copies the encoding
    parameters from ``public_key``.

    Raises:
        ValueError: if ``public_key`` is not a PublicKey.
        RuntimeError: if the key has no modulus, a value is outside the
            encodable range, or encryption fails for any other reason.
    """
    # Validate public_key type
    if not isinstance(public_key, PublicKey):
        raise ValueError("Second argument must be a PublicKey instance")

    try:
        # Range encoding needs the modulus N. Check it once, before any other
        # work, instead of once per element inside the encoding loop.
        modulus = public_key.modulus
        if modulus is None:
            raise ValueError(
                "Public key modulus is None, key generation may have failed"
            )

        # Convert plaintext to numpy to get semantic type info
        plaintext_np = plaintext.to_numpy()
        semantic_dtype = DType.from_numpy(plaintext_np.dtype)
        semantic_shape = plaintext_np.shape

        # Create lightPHE instance with the same scheme/key_size as the key
        # NOTE(review): this is the same constructor call keygen uses, so it
        # presumably generates a throwaway key pair on every encrypt - confirm
        # whether lightPHE can be initialised directly from existing keys.
        phe = LightPHE(
            algorithm_name=public_key.scheme,
            key_size=public_key.key_size,
            precision=PRECISION,
        )

        # CRITICAL: Set the same modulus as the key to ensure consistency
        phe.cs.plaintext_modulo = modulus
        phe.cs.ciphertext_modulo = modulus * modulus

        # Set the public key
        phe.cs.keys["public_key"] = public_key.key_data

        # Encode every element first, so that an out-of-range value is
        # reported before any encryption is performed.
        encoded_data_list = [
            _range_encode_mixed(
                val,
                public_key.max_value,
                public_key.fxp_bits,
                modulus,
                semantic_dtype,
            )
            for val in plaintext_np.flatten()
        ]

        # Encrypt the encoded values (note: not passing as list, just the value)
        lightphe_ciphertext = [phe.encrypt(val) for val in encoded_data_list]

        # Create CipherText object with encoding parameters
        ciphertext = CipherText(
            ct_data=lightphe_ciphertext,
            semantic_dtype=semantic_dtype,
            semantic_shape=semantic_shape,
            scheme=public_key.scheme,
            key_size=public_key.key_size,
            pk_data=public_key.key_data,
            max_value=public_key.max_value,
            fxp_bits=public_key.fxp_bits,
            modulus=modulus,
        )

        return [ciphertext]

    except Exception as e:
        raise RuntimeError(f"Failed to encrypt data: {e}") from e
597
-
598
-
599
@kernel_def("phe.mul")
def _phe_mul(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorValue) -> Any:
    """Multiply a ciphertext element-wise by a plaintext tensor.

    The two operands are broadcast against each other with numpy rules. The
    plaintext multipliers are used raw (not range-encoded). Returns a
    one-element list holding the result CipherText, which keeps the input
    ciphertext's dtype and encoding parameters and takes the broadcast shape.

    Raises:
        ValueError: if ``ciphertext`` is not a CipherText, the plaintext has a
            floating dtype, or the shapes cannot be broadcast together.
        RuntimeError: for any other failure during the multiplication.
    """
    # Validate that first argument is a CipherText
    if not isinstance(ciphertext, CipherText):
        raise ValueError("First argument must be a CipherText instance")

    try:
        # Convert plaintext to numpy
        plaintext_np = plaintext.to_numpy()

        # Check if plaintext is floating point type - multiplication not supported
        if np.issubdtype(plaintext_np.dtype, np.floating):
            raise ValueError(
                f"Homomorphic multiplication with floating point plaintext is not supported. "
                f"Got plaintext dtype: {plaintext_np.dtype}"
            )

        # Compute the broadcast shape from the shapes alone; no operand-sized
        # scratch arrays are needed for this.
        try:
            result_shape = np.broadcast_shapes(
                ciphertext.semantic_shape, plaintext_np.shape
            )
        except ValueError as e:
            raise ValueError(
                f"Operands cannot be broadcast together: CipherText shape {ciphertext.semantic_shape} "
                f"vs plaintext shape {plaintext_np.shape}: {e}"
            ) from e

        # Broadcast plaintext to the result shape (a no-op when it already matches)
        plaintext_broadcasted = np.broadcast_to(plaintext_np, result_shape)

        # If ciphertext needs broadcasting, we need to replicate its encrypted values
        if ciphertext.semantic_shape != result_shape:
            # Broadcast an array of flat positions with the ciphertext's shape,
            # then use the broadcast positions to pick the encrypted elements.
            positions = np.arange(
                int(np.prod(ciphertext.semantic_shape)), dtype=np.int64
            ).reshape(ciphertext.semantic_shape)
            broadcasted_indices = np.broadcast_to(positions, result_shape).flatten()

            raw_ct: list[Any] = ciphertext.ct_data
            broadcasted_ct_data = [raw_ct[int(idx)] for idx in broadcasted_indices]
        else:
            # No broadcasting needed for ciphertext
            broadcasted_ct_data = ciphertext.ct_data

        # For multiplication, plaintext multipliers should NOT be encoded
        # The ciphertext already contains the encoded value, multiplying by raw plaintext preserves semantics
        # Only the Python numeric type is chosen, from the ciphertext's semantic dtype.
        to_number = float if ciphertext.semantic_dtype.is_floating else int
        raw_multipliers = [to_number(val) for val in plaintext_broadcasted.flatten()]

        # Perform homomorphic multiplication
        # In Paillier, ciphertext * plaintext is supported
        # Indexing (rather than zip) keeps a too-short ct_data an error.
        result_ciphertext = [
            broadcasted_ct_data[i] * multiplier
            for i, multiplier in enumerate(raw_multipliers)
        ]

        # Create result CipherText with the broadcasted shape and encoding parameters
        return [
            CipherText(
                ct_data=result_ciphertext,
                semantic_dtype=ciphertext.semantic_dtype,
                semantic_shape=result_shape,
                scheme=ciphertext.scheme,
                key_size=ciphertext.key_size,
                pk_data=ciphertext.pk_data,
                max_value=ciphertext.max_value,
                fxp_bits=ciphertext.fxp_bits,
                modulus=ciphertext.modulus,
            )
        ]

    except ValueError:
        # Re-raise ValueError directly (validation errors)
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform multiplication: {e}") from e
697
-
698
-
699
@kernel_def("phe.add")
def _phe_add(pfunc: PFunction, lhs: Any, rhs: Any) -> Any:
    """Add two operands where either, both, or neither may be encrypted.

    Dispatches to the ciphertext+ciphertext or ciphertext+plaintext helper
    depending on which operands are ``CipherText`` instances. When neither
    is encrypted the NumPy sum is wrapped in a ``TensorValue``.

    Args:
        pfunc: Kernel descriptor; not read by this implementation.
        lhs: Left operand (``CipherText`` or a value exposing ``to_numpy``).
        rhs: Right operand (``CipherText`` or a value exposing ``to_numpy``).

    Raises:
        ValueError: Propagated unchanged from the helpers.
        RuntimeError: For any other failure during the addition.
    """
    try:
        lhs_encrypted = isinstance(lhs, CipherText)
        rhs_encrypted = isinstance(rhs, CipherText)

        if lhs_encrypted and rhs_encrypted:
            return _phe_add_ct2ct(lhs, rhs)
        if lhs_encrypted:
            return _phe_add_ct2pt(lhs, rhs)
        if rhs_encrypted:
            # The helper expects the ciphertext first, so swap the operands.
            return _phe_add_ct2pt(rhs, lhs)
        return TensorValue(lhs.to_numpy() + rhs.to_numpy())
    except ValueError:
        raise
    except Exception as e:  # pragma: no cover
        raise RuntimeError(f"Failed to perform addition: {e}") from e
714
-
715
-
716
def _phe_add_ct2ct(ct1: CipherText, ct2: CipherText) -> CipherText:
    """Add two encrypted tensors element-wise with NumPy broadcasting.

    Args:
        ct1: First encrypted operand; its metadata is copied to the result.
        ct2: Second encrypted operand.

    Returns:
        A ``CipherText`` with the broadcast shape, holding the per-element
        sums.

    Raises:
        ValueError: If the operands differ in scheme, key size or public key,
            mix floating and non-floating semantic dtypes, or have shapes
            that cannot be broadcast together.
    """
    if ct1.scheme != ct2.scheme or ct1.key_size != ct2.key_size:
        raise ValueError("CipherText operands must use same scheme and key size")

    if ct1.pk_data != ct2.pk_data:
        raise ValueError("CipherText operands must be encrypted with same key")

    # Floating and integer ciphertexts use different fixed-point scales, so
    # their sum could not be decoded.
    if ct1.semantic_dtype.is_floating != ct2.semantic_dtype.is_floating:
        raise ValueError(
            f"Cannot add ciphertexts with different numeric types due to fixed-point encoding. "
            f"First CipherText dtype: {ct1.semantic_dtype}, second CipherText dtype: {ct2.semantic_dtype}. "
            f"Both operands must have the same numeric type (both floating or both integer)."
        )

    # Let NumPy decide the broadcast result shape using placeholder arrays.
    try:
        result_shape = (np.zeros(ct1.semantic_shape) + np.zeros(ct2.semantic_shape)).shape
    except ValueError as e:
        raise ValueError(
            f"CipherText operands cannot be broadcast together: shape {ct1.semantic_shape} "
            f"vs shape {ct2.semantic_shape}: {e}"
        ) from e

    def expand(operand: CipherText) -> list[Any]:
        """Return the operand's encrypted elements replicated to result_shape."""
        if operand.semantic_shape == result_shape:
            return operand.ct_data
        # Broadcast an array of flat slot ids, then pick elements by id.
        slot_ids = (
            np.arange(np.prod(operand.semantic_shape))
            .reshape(operand.semantic_shape)
            .astype(np.int64)
        )
        flat_ids = np.broadcast_to(slot_ids, result_shape).flatten()
        source_items = operand.ct_data
        return [source_items[int(slot)] for slot in flat_ids]

    left_items = expand(ct1)
    right_items = expand(ct2)

    # Homomorphic addition, one element at a time.
    sums = [item + right_items[pos] for pos, item in enumerate(left_items)]

    return CipherText(
        ct_data=sums,
        semantic_dtype=ct1.semantic_dtype,
        semantic_shape=result_shape,
        scheme=ct1.scheme,
        key_size=ct1.key_size,
        pk_data=ct1.pk_data,
        max_value=ct1.max_value,
        fxp_bits=ct1.fxp_bits,
        modulus=ct1.modulus,
    )
791
-
792
-
793
def _phe_add_ct2pt(ciphertext: CipherText, plaintext: TensorValue) -> CipherText:
    """Add a plaintext tensor to an encrypted tensor with NumPy broadcasting.

    The plaintext is range-encoded with the ciphertext's encoding parameters,
    encrypted under the ciphertext's public key, and then added
    homomorphically element by element.

    Args:
        ciphertext: Encrypted operand; its metadata is copied to the result.
        plaintext: Plaintext operand exposing ``to_numpy``.

    Returns:
        A ``CipherText`` with the broadcast shape, holding the per-element
        sums.

    Raises:
        ValueError: If exactly one operand is floating point, the shapes
            cannot be broadcast, the ciphertext carries no public key, or its
            modulus is missing while there are values to encode.
    """
    plaintext_np = plaintext.to_numpy()
    plaintext_dtype = DType.from_numpy(plaintext_np.dtype)
    ct_is_float = ciphertext.semantic_dtype.is_floating
    pt_is_float = plaintext_dtype.is_floating

    # Mixed float/integer operands would be encoded at different scales.
    if ct_is_float and not pt_is_float:
        raise ValueError(
            f"Cannot add integer plaintext to floating point ciphertext due to fixed-point encoding. "
            f"CipherText dtype: {ciphertext.semantic_dtype}, plaintext dtype: {plaintext_dtype}. "
            f"Both operands must have the same numeric type (both floating or both integer)."
        )
    if not ct_is_float and pt_is_float:
        raise ValueError(
            f"Cannot add floating point plaintext to integer ciphertext due to fixed-point encoding. "
            f"CipherText dtype: {ciphertext.semantic_dtype}, plaintext dtype: {plaintext_dtype}. "
            f"Both operands must have the same numeric type (both floating or both integer)."
        )

    # Let NumPy decide the broadcast result shape using placeholder arrays.
    try:
        result_shape = (
            np.zeros(ciphertext.semantic_shape) + np.zeros(plaintext_np.shape)
        ).shape
    except ValueError as e:
        raise ValueError(
            f"Operands cannot be broadcast together: CipherText shape {ciphertext.semantic_shape} "
            f"vs plaintext shape {plaintext_np.shape}: {e}"
        ) from e

    # Expand the plaintext to the result shape when it is smaller.
    if plaintext_np.shape != result_shape:
        plaintext_broadcasted = np.broadcast_to(plaintext_np, result_shape)
    else:
        plaintext_broadcasted = plaintext_np

    # Expand the ciphertext by broadcasting flat slot ids and replicating the
    # encrypted elements those ids point to.
    if ciphertext.semantic_shape != result_shape:
        slot_ids = (
            np.arange(np.prod(ciphertext.semantic_shape))
            .reshape(ciphertext.semantic_shape)
            .astype(np.int64)
        )
        flat_ids = np.broadcast_to(slot_ids, result_shape).flatten()
        raw_ct: list[Any] = ciphertext.ct_data
        broadcasted_ct_data = [raw_ct[int(slot)] for slot in flat_ids]
    else:
        broadcasted_ct_data = ciphertext.ct_data

    if ciphertext.pk_data is None:
        raise ValueError(
            "CipherText must contain public key data for plaintext addition"
        )

    phe = LightPHE(
        algorithm_name=ciphertext.scheme,
        key_size=ciphertext.key_size,
        precision=PRECISION,
    )

    # Align the instance's moduli with the ciphertext's modulus before
    # installing the public key, exactly as _create_encrypted_zero and
    # _phe_decrypt do. Without this the instance keeps whatever moduli it was
    # constructed with rather than those of the target key.
    modulus = ciphertext.modulus
    if modulus is not None:
        phe.cs.plaintext_modulo = modulus
        phe.cs.ciphertext_modulo = modulus * modulus

    phe.cs.keys["public_key"] = ciphertext.pk_data

    flat_values = plaintext_broadcasted.flatten()

    # The modulus check is loop-invariant, so it runs once here. It still
    # only fires when there is at least one value to encode, as before.
    if modulus is None and flat_values.size > 0:
        raise ValueError("Ciphertext modulus is None, encryption may have failed")

    # Encode with the same range encoding used at encryption time.
    target_dtype = ciphertext.semantic_dtype
    encoded_values = [
        _range_encode_mixed(
            value,
            ciphertext.max_value,
            ciphertext.fxp_bits,
            modulus,
            target_dtype,
        )
        for value in flat_values
    ]
    encrypted_plaintext = [phe.encrypt(encoded) for encoded in encoded_values]

    # Homomorphic addition, one element at a time (encrypted plaintext on the
    # left, matching the original operand order).
    sums = [
        encrypted + broadcasted_ct_data[pos]
        for pos, encrypted in enumerate(encrypted_plaintext)
    ]

    return CipherText(
        ct_data=sums,
        semantic_dtype=ciphertext.semantic_dtype,
        semantic_shape=result_shape,
        scheme=ciphertext.scheme,
        key_size=ciphertext.key_size,
        pk_data=ciphertext.pk_data,
        max_value=ciphertext.max_value,
        fxp_bits=ciphertext.fxp_bits,
        modulus=ciphertext.modulus,
    )
900
-
901
-
902
def _create_encrypted_zero(ciphertext: CipherText) -> Any:
    """Encrypt the value zero so it is compatible with ``ciphertext``.

    A ``LightPHE`` instance is built with the ciphertext's scheme and key
    size, its moduli are overridden with the ciphertext's modulus, and the
    ciphertext's public key is installed before encrypting a range-encoded
    zero.

    Args:
        ciphertext: Supplies scheme, key size, public key and encoding
            parameters.

    Returns:
        Whatever ``LightPHE.encrypt`` returns for the encoded zero.

    Raises:
        ValueError: If ``ciphertext.modulus`` is ``None``.
    """
    phe = LightPHE(
        algorithm_name=ciphertext.scheme,
        key_size=ciphertext.key_size,
        precision=PRECISION,
    )

    modulus = ciphertext.modulus
    if modulus is None:
        raise ValueError("Ciphertext modulus is None, encryption may have failed")

    # The instance must use the same modulus as the original ciphertext.
    phe.cs.plaintext_modulo = modulus
    phe.cs.ciphertext_modulo = modulus * modulus
    phe.cs.keys["public_key"] = ciphertext.pk_data

    # Zero goes through the same range encoding as ordinary values.
    encoded_zero = _range_encode_mixed(
        0,
        ciphertext.max_value,
        ciphertext.fxp_bits,
        modulus,
        ciphertext.semantic_dtype,
    )
    return phe.encrypt(encoded_zero)
930
-
931
-
932
@kernel_def("phe.decrypt")
def _phe_decrypt(
    pfunc: PFunction, ciphertext: CipherText, private_key: PrivateKey
) -> Any:
    """Decrypt an encrypted tensor back into a plaintext ``TensorValue``.

    Each encrypted element is decrypted, range-decoded with the ciphertext's
    encoding parameters, converted to the semantic dtype and reshaped to the
    semantic shape.

    Args:
        pfunc: Kernel descriptor; not read by this implementation.
        ciphertext: The encrypted tensor.
        private_key: Key whose scheme and key size must match the ciphertext.

    Returns:
        A one-element list holding the decrypted ``TensorValue``.

    Raises:
        ValueError: If an argument has the wrong type or scheme/key size
            differ between ciphertext and key.
        RuntimeError: For every failure after those checks, including
            validation errors raised while decrypting.
    """
    if not isinstance(ciphertext, CipherText):
        raise ValueError("First argument must be a CipherText instance")
    if not isinstance(private_key, PrivateKey):
        raise ValueError("Second argument must be a PrivateKey instance")

    same_scheme = ciphertext.scheme == private_key.scheme
    same_key_size = ciphertext.key_size == private_key.key_size
    if not (same_scheme and same_key_size):
        raise ValueError("CipherText and PrivateKey must use same scheme and key size")

    def first_number(value: Any) -> Any:
        """Return value itself if numeric, else its first item."""
        if isinstance(value, (int, float)):
            return value
        if hasattr(value, "__getitem__") and len(value) > 0:
            return value[0]
        raise ValueError(f"Cannot extract numeric value from {value}")

    try:
        phe = LightPHE(
            algorithm_name=private_key.scheme,
            key_size=private_key.key_size,
            precision=PRECISION,
        )

        # Force the instance onto the modulus used at encryption time; the
        # ciphertext modulus is set to its square.
        if ciphertext.modulus is not None:
            phe.cs.plaintext_modulo = ciphertext.modulus
            phe.cs.ciphertext_modulo = ciphertext.modulus * ciphertext.modulus

        # Both halves of the key pair are installed on the instance.
        phe.cs.keys["private_key"] = private_key.sk_data
        phe.cs.keys["public_key"] = private_key.pk_data

        np_dtype = ciphertext.semantic_dtype.to_numpy()
        raw_plain = [phe.decrypt(item) for item in ciphertext.ct_data]

        if ciphertext.modulus is None:
            raise ValueError("Ciphertext modulus is None, encryption may have failed")

        # Undo the range encoding element by element.
        decoded = [
            _range_decode_mixed(
                int(first_number(value)),
                ciphertext.max_value,
                ciphertext.fxp_bits,
                ciphertext.modulus,
                ciphertext.semantic_dtype,
            )
            for value in raw_plain
        ]

        if np_dtype.kind in "iu":
            # Integer semantic types: round any float results first.
            whole = [round(v) if isinstance(v, float) else v for v in decoded]
            if np.issubdtype(np_dtype, np.unsignedinteger):
                # Unsigned: wrap modulo 2**bits.
                wrap = 1 << np.iinfo(np_dtype).bits
                final_values = [v % wrap for v in whole]
            else:
                # Signed: clamp into the representable range.
                limits = np.iinfo(np_dtype)
                final_values = [max(limits.min, min(limits.max, v)) for v in whole]
        else:
            final_values = decoded

        result_array = np.array(final_values, dtype=np_dtype).reshape(
            ciphertext.semantic_shape
        )
        return [TensorValue(result_array)]

    except Exception as e:
        raise RuntimeError(f"Failed to decrypt data: {e}") from e
1026
-
1027
-
1028
@kernel_def("phe.dot")
def _phe_dot(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorValue) -> Any:
    """Execute a homomorphic dot product of a ciphertext with a plaintext.

    The contraction follows ``numpy.dot``:

    - scalar . scalar -> scalar
    - N-D . 1-D -> sum over the last ciphertext axis and the plaintext axis
    - N-D . M-D (M >= 2) -> sum over the last ciphertext axis and the
      second-to-last plaintext axis

    The result shape is whatever ``numpy.dot`` yields for arrays of the same
    shapes, and the encrypted elements are emitted in row-major order of that
    shape. Scalar-by-tensor products are rejected.

    Terms whose plaintext weight is zero are skipped. When every weight of
    an output element is zero, a freshly encrypted zero is used instead.

    Args:
        pfunc: Kernel descriptor; not read by this implementation.
        ciphertext: Encrypted left operand.
        plaintext: Plaintext right operand; must not have a floating dtype.

    Returns:
        A one-element list holding the result ``CipherText``.

    Raises:
        ValueError: If the argument types are wrong, the plaintext is
            floating point, or the shapes are incompatible or unsupported.
        RuntimeError: For any other failure during the computation.
    """
    if not isinstance(ciphertext, CipherText):
        raise ValueError("First argument must be a CipherText instance")
    if isinstance(plaintext, CipherText):
        raise ValueError("Second argument must be a plaintext TensorLike")

    try:
        plaintext_np = plaintext.to_numpy()

        if np.issubdtype(plaintext_np.dtype, np.floating):
            raise ValueError(
                f"Homomorphic dot product with floating point plaintext is not supported. "
                f"Got plaintext dtype: {plaintext_np.dtype}"
            )

        ct_shape = ciphertext.semantic_shape
        pt_shape = plaintext_np.shape

        # Let numpy.dot validate the shapes and fix the result shape.
        try:
            result_shape = np.dot(np.zeros(ct_shape), np.zeros(pt_shape)).shape
        except ValueError as e:
            raise ValueError(
                f"Shapes are not compatible for dot product: CipherText shape {ct_shape} "
                f"vs plaintext shape {pt_shape}: {e}"
            ) from e

        use_float = ciphertext.semantic_dtype.is_floating
        if use_float:
            pt_data = plaintext_np.astype(float)
            to_scalar = float

            def is_zero(x: Any) -> bool:
                # Small epsilon for floating point zero comparison.
                return abs(x) < 1e-15

        else:
            pt_data = plaintext_np.astype(int)
            to_scalar = int

            def is_zero(x: Any) -> bool:
                return x == 0

        ct_flat = ciphertext.ct_data

        def weighted_sum(ct_start: int, weights: Any) -> Any:
            """Sum ct_flat[ct_start + k] * weights[k], skipping zero weights."""
            total = None
            for offset, weight in enumerate(weights):
                if is_zero(weight):
                    continue
                term = ct_flat[ct_start + offset] * to_scalar(weight)
                total = term if total is None else total + term
            if total is None:
                # Every weight was zero: the result is an encryption of zero.
                return _create_encrypted_zero(ciphertext)
            return total

        if len(ct_shape) == 0 and len(pt_shape) == 0:
            # Scalar . scalar
            result_ct_data = [weighted_sum(0, [pt_data.item()])]

        elif len(ct_shape) >= 1 and len(pt_shape) >= 1:
            inner = ct_shape[-1]
            if len(pt_shape) == 1:
                contracted = pt_shape[0]
                cols = 1
                pt_matrix = pt_data.reshape(contracted, 1)
            else:
                # numpy.dot contracts over the SECOND-TO-LAST plaintext axis.
                # Move it to the front and flatten the remaining axes, so the
                # columns enumerate pt_shape[:-2] + pt_shape[-1:] row-major.
                contracted = pt_shape[-2]
                cols = int(np.prod(pt_shape[:-2] + pt_shape[-1:]))
                pt_matrix = np.moveaxis(pt_data, -2, 0).reshape(contracted, cols)

            if inner != contracted:
                # Defensive: numpy.dot above already rejects this.
                raise ValueError(
                    f"Tensor dimension mismatch: CipherText last dimension {inner} "
                    f"vs plaintext contracted dimension {contracted}"
                )

            # Treat the ciphertext as a (rows, inner) matrix in row-major order.
            rows = int(np.prod(ct_shape[:-1]))
            result_ct_data = [
                weighted_sum(row * inner, pt_matrix[:, col])
                for row in range(rows)
                for col in range(cols)
            ]

        else:
            # Scalar-by-tensor products are not supported.
            raise ValueError(
                f"Unsupported tensor shapes for dot product: "
                f"CipherText shape {ct_shape}, plaintext shape {pt_shape}"
            )

        return [
            CipherText(
                ct_data=result_ct_data,
                semantic_dtype=ciphertext.semantic_dtype,
                semantic_shape=result_shape,
                scheme=ciphertext.scheme,
                key_size=ciphertext.key_size,
                pk_data=ciphertext.pk_data,
                max_value=ciphertext.max_value,
                fxp_bits=ciphertext.fxp_bits,
                modulus=ciphertext.modulus,
            )
        ]

    except ValueError:
        # Validation failures propagate unchanged.
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform dot product: {e}") from e
1312
-
1313
-
1314
@kernel_def("phe.gather")
def _phe_gather(pfunc: PFunction, ciphertext: CipherText, indices: TensorValue) -> Any:
    """Select slices of an encrypted tensor along one axis.

    The axis comes from ``pfunc.attrs["axis"]`` (default 0; negative values
    count from the end). The result's semantic shape is
    ``indices.shape + ciphertext.shape[:axis] + ciphertext.shape[axis+1:]``.

    Args:
        pfunc: Kernel descriptor; only ``attrs["axis"]`` is read.
        ciphertext: Encrypted tensor to gather from (at least 1-D).
        indices: Integer tensor of positions along the axis.

    Returns:
        A one-element list holding the gathered ``CipherText``.

    Raises:
        ValueError: If ``ciphertext`` is not a ``CipherText`` or is scalar,
            the indices are not integers or are out of range, or the axis is
            out of bounds.
        RuntimeError: For any other failure, including an element-count
            mismatch detected after gathering.
    """
    if not isinstance(ciphertext, CipherText):
        raise ValueError("First argument must be a CipherText instance")

    requested_axis = pfunc.attrs.get("axis", 0)

    try:
        index_array = indices.to_numpy()
        if not np.issubdtype(index_array.dtype, np.integer):
            raise ValueError("Indices must be of integer type")

        source_shape = ciphertext.semantic_shape
        rank = len(source_shape)
        if rank == 0:
            raise ValueError("Cannot gather from scalar CipherText")

        # Resolve a negative axis, then range-check it.
        axis = requested_axis + rank if requested_axis < 0 else requested_axis
        if not 0 <= axis < rank:
            raise ValueError(
                f"Axis {requested_axis} is out of bounds for array of dimension {rank}"
            )

        extent = source_shape[axis]
        if np.any(index_array < 0) or np.any(index_array >= extent):
            raise ValueError(
                f"Indices are out of bounds for axis {axis} with size {extent}. "
                f"Got indices in range [{np.min(index_array)}, {np.max(index_array)}]"
            )

        result_shape = (
            index_array.shape + source_shape[:axis] + source_shape[axis + 1 :]
        )

        # block_count: number of index combinations before the axis.
        # run_length: number of contiguous elements after the axis.
        block_count = int(np.prod(source_shape[:axis]))
        run_length = int(np.prod(source_shape[axis + 1 :]))
        block_span = extent * run_length

        # NOTE(review): elements are emitted block by block, i.e. in
        # (leading dims, indices, trailing dims) order, while result_shape
        # puts indices.shape first. For axis > 0 with more than one leading
        # block these two orders differ - confirm the intended layout
        # against the scatter kernel and the frontend type rule.
        source_items = ciphertext.ct_data
        gathered = []
        for block in range(block_count):
            base = block * block_span
            for chosen in index_array.flatten():
                start = base + int(chosen) * run_length
                gathered.extend(source_items[start : start + run_length])

        # Sanity check on the element count.
        expected_size = int(np.prod(result_shape)) if result_shape else 1
        if len(gathered) != expected_size:
            raise RuntimeError(
                f"Internal error: Expected {expected_size} elements, got {len(gathered)}"
            )

        return [
            CipherText(
                ct_data=gathered,
                semantic_dtype=ciphertext.semantic_dtype,
                semantic_shape=result_shape,
                scheme=ciphertext.scheme,
                key_size=ciphertext.key_size,
                pk_data=ciphertext.pk_data,
                max_value=ciphertext.max_value,
                fxp_bits=ciphertext.fxp_bits,
                modulus=ciphertext.modulus,
            )
        ]

    except ValueError:
        # Validation failures propagate unchanged.
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform gather: {e}") from e
1426
-
1427
-
1428
- @kernel_def("phe.scatter")
1429
- def _phe_scatter(
1430
- pfunc: PFunction,
1431
- ciphertext: CipherText,
1432
- indices: TensorValue,
1433
- updated: CipherText,
1434
- ) -> Any:
1435
- """Execute scatter operation on CipherText.
1436
-
1437
- Supports scattering into multidimensional CipherText using multidimensional indices.
1438
- The operation follows numpy scatter semantics:
1439
- - Scattering is performed along the specified axis of ciphertext
1440
- - indices.shape must equal updated.shape[:len(indices.shape)]
1441
- - updated.shape must be indices.shape + ciphertext.shape[:axis] + ciphertext.shape[axis+1:]
1442
- - Result shape is same as original ciphertext.shape
1443
-
1444
- """
1445
- # Validate that first and third arguments are CipherTexts
1446
- if not isinstance(ciphertext, CipherText) or not isinstance(updated, CipherText):
1447
- raise ValueError("First and third arguments must be CipherText instances")
1448
-
1449
- # Validate that both ciphertexts use same scheme/key_size
1450
- if ciphertext.scheme != updated.scheme or ciphertext.key_size != updated.key_size:
1451
- raise ValueError("Both CipherTexts must use same scheme and key size")
1452
-
1453
- if ciphertext.pk_data != updated.pk_data:
1454
- raise ValueError("Both CipherTexts must be encrypted with same key")
1455
-
1456
- # Get axis parameter from pfunc.attrs, default to 0
1457
- axis = pfunc.attrs.get("axis", 0)
1458
-
1459
- try:
1460
- # Convert indices to numpy
1461
- indices_np = indices.to_numpy()
1462
-
1463
- if not np.issubdtype(indices_np.dtype, np.integer):
1464
- raise ValueError("Indices must be of integer type")
1465
-
1466
- # Validate that ciphertext has at least 1 dimension for indexing
1467
- if len(ciphertext.semantic_shape) == 0:
1468
- raise ValueError("Cannot scatter into scalar CipherText")
1469
-
1470
- # Normalize axis to positive value
1471
- ndim = len(ciphertext.semantic_shape)
1472
- if axis < 0:
1473
- axis = ndim + axis
1474
- if axis < 0 or axis >= ndim:
1475
- raise ValueError(
1476
- f"Axis {pfunc.attrs.get('axis', 0)} is out of bounds for array of dimension {ndim}"
1477
- )
1478
-
1479
- # Validate indices are within bounds for the specified axis
1480
- axis_size = ciphertext.semantic_shape[axis]
1481
- if np.any(indices_np < 0) or np.any(indices_np >= axis_size):
1482
- raise ValueError(
1483
- f"Indices are out of bounds for axis {axis} with size {axis_size}. "
1484
- f"Got indices in range [{np.min(indices_np)}, {np.max(indices_np)}]"
1485
- )
1486
-
1487
- # Validate shape compatibility
1488
- # Expected updated shape: indices.shape + ciphertext.shape[:axis] + ciphertext.shape[axis+1:]
1489
- expected_updated_shape = (
1490
- indices_np.shape
1491
- + ciphertext.semantic_shape[:axis]
1492
- + ciphertext.semantic_shape[axis + 1 :]
1493
- )
1494
- if updated.semantic_shape != expected_updated_shape:
1495
- raise ValueError(
1496
- f"Updated CipherText shape mismatch. Expected {expected_updated_shape}, "
1497
- f"got {updated.semantic_shape}. "
1498
- f"Updated shape must be indices.shape + ciphertext.shape[:axis] + ciphertext.shape[axis+1:]"
1499
- )
1500
-
1501
- # Calculate strides for multi-axis scattering
1502
- ct_shape = ciphertext.semantic_shape
1503
-
1504
- # Stride calculations for arbitrary axis
1505
- # Elements before axis contribute to outer stride
1506
- outer_stride = int(np.prod(ct_shape[:axis])) if axis > 0 else 1
1507
- # Elements after axis contribute to inner stride
1508
- inner_stride = int(np.prod(ct_shape[axis + 1 :])) if axis < ndim - 1 else 1
1509
-
1510
- # Create a copy of the original ciphertext data for scattering
1511
- scattered_ct_data = ciphertext.ct_data.copy()
1512
-
1513
- # Perform scatter operation
1514
- indices_flat = indices_np.flatten()
1515
- updated_ct_data = updated.ct_data
1516
-
1517
- if axis == 0:
1518
- # Special case: scattering along axis 0 (existing behavior)
1519
- axis_stride = inner_stride
1520
- for i, idx in enumerate(indices_flat):
1521
- start_pos_updated = i * axis_stride
1522
- start_pos_original = int(idx) * axis_stride
1523
-
1524
- for j in range(axis_stride):
1525
- if start_pos_updated + j < len(updated_ct_data):
1526
- scattered_ct_data[start_pos_original + j] = updated_ct_data[
1527
- start_pos_updated + j
1528
- ]
1529
- else:
1530
- # General case: scattering along arbitrary axis
1531
- for outer_idx in range(outer_stride):
1532
- for i, scatter_idx in enumerate(indices_flat):
1533
- # Calculate position in flattened ciphertext data
1534
- start_pos_original = (
1535
- outer_idx * (ct_shape[axis] * inner_stride)
1536
- + int(scatter_idx) * inner_stride
1537
- )
1538
- start_pos_updated = (
1539
- outer_idx * len(indices_flat) + i
1540
- ) * inner_stride
1541
-
1542
- # Update the ciphertext data
1543
- for j in range(inner_stride):
1544
- if start_pos_updated + j < len(updated_ct_data):
1545
- scattered_ct_data[start_pos_original + j] = updated_ct_data[
1546
- start_pos_updated + j
1547
- ]
1548
-
1549
- # Create result CipherText with same shape as original
1550
- return [
1551
- CipherText(
1552
- ct_data=scattered_ct_data,
1553
- semantic_dtype=ciphertext.semantic_dtype,
1554
- semantic_shape=ciphertext.semantic_shape,
1555
- scheme=ciphertext.scheme,
1556
- key_size=ciphertext.key_size,
1557
- pk_data=ciphertext.pk_data,
1558
- max_value=ciphertext.max_value,
1559
- fxp_bits=ciphertext.fxp_bits,
1560
- modulus=ciphertext.modulus,
1561
- )
1562
- ]
1563
- except ValueError:
1564
- # Re-raise ValueError directly (validation errors)
1565
- raise
1566
- except Exception as e:
1567
- raise RuntimeError(f"Failed to perform scatter: {e}") from e
1568
-
1569
-
1570
@kernel_def("phe.concat")
def _phe_concat(pfunc: PFunction, c1: CipherText, c2: CipherText) -> Any:
    """Concatenate two CipherTexts along a single axis.

    The axis is read from ``pfunc.attrs["axis"]`` (default ``0``); negative
    values count from the last dimension. The encrypted elements are copied
    as-is into a new flat list; every metadata field of the result other than
    the shape (dtype, scheme, key size, public key, max_value, fxp_bits,
    modulus) is taken from ``c1``.

    Args:
        pfunc: Kernel descriptor; only ``attrs["axis"]`` is consulted.
        c1: Left operand, supplies the leading elements along ``axis``.
        c2: Right operand, supplies the trailing elements along ``axis``.

    Returns:
        A single-element list holding the concatenated CipherText.

    Raises:
        ValueError: If an operand is not a CipherText, the operands differ in
            scheme, key size, public key, dtype or rank, the operands are
            scalars, the axis is out of range, or any non-concat dimension
            differs.
        RuntimeError: If anything else fails while assembling the result,
            including the internal element-count sanity check.
    """
    # Read the axis first, before any operand validation.
    axis = pfunc.attrs.get("axis", 0)

    # --- operand compatibility checks -------------------------------------
    if not (isinstance(c1, CipherText) and isinstance(c2, CipherText)):
        raise ValueError("All arguments must be CipherText instances")

    if c1.scheme != c2.scheme or c1.key_size != c2.key_size:
        raise ValueError("All CipherTexts must use same scheme and key size")
    if c1.pk_data != c2.pk_data:
        raise ValueError("All CipherTexts must be encrypted with same key")
    if c1.semantic_dtype != c2.semantic_dtype:
        raise ValueError(
            f"All CipherTexts must have same semantic dtype, got {c1.semantic_dtype} vs {c2.semantic_dtype}"
        )

    ndim = len(c1.semantic_shape)
    if ndim != len(c2.semantic_shape):
        raise ValueError(
            f"All CipherTexts must have same number of dimensions for concat, got {len(c1.semantic_shape)} vs {len(c2.semantic_shape)}"
        )
    if ndim == 0:
        raise ValueError("Cannot concatenate scalar CipherTexts")

    # --- axis normalisation -------------------------------------------------
    if axis < 0:
        axis += ndim
    if not 0 <= axis < ndim:
        raise ValueError(
            f"axis {pfunc.attrs.get('axis', 0)} is out of bounds for array of dimension {ndim}"
        )

    # Every dimension other than the concat axis has to agree.
    for i, (lhs_dim, rhs_dim) in enumerate(
        zip(c1.semantic_shape, c2.semantic_shape)
    ):
        if i == axis or lhs_dim == rhs_dim:
            continue
        raise ValueError(
            f"All CipherTexts must have same shape except along concatenation axis {axis}. "
            f"Shape mismatch at dimension {i}: {c1.semantic_shape[i]} vs {c2.semantic_shape[i]}"
        )

    try:
        # Same shape as c1, except the concat axis grows by c2's extent.
        result_shape = tuple(
            dim + c2.semantic_shape[axis] if k == axis else dim
            for k, dim in enumerate(c1.semantic_shape)
        )

        # Number of independent slabs in front of the concat axis ...
        slab_count = 1 if axis == 0 else int(np.prod(c1.semantic_shape[:axis]))
        # ... and how many flat elements each operand contributes per slab.
        lhs_run = int(np.prod(c1.semantic_shape[axis:]))
        rhs_run = int(np.prod(c2.semantic_shape[axis:]))

        # Interleave: for each slab take c1's run, then c2's run.
        concatenated_ct_data = []
        for slab in range(slab_count):
            concatenated_ct_data.extend(
                c1.ct_data[slab * lhs_run : (slab + 1) * lhs_run]
            )
            concatenated_ct_data.extend(
                c2.ct_data[slab * rhs_run : (slab + 1) * rhs_run]
            )

        # Sanity check; raised inside the try so it is wrapped below.
        expected_size = int(np.prod(result_shape))
        if len(concatenated_ct_data) != expected_size:
            raise RuntimeError(
                f"Internal error: Expected {expected_size} elements, got {len(concatenated_ct_data)}"
            )

        merged = CipherText(
            ct_data=concatenated_ct_data,
            semantic_dtype=c1.semantic_dtype,
            semantic_shape=result_shape,
            scheme=c1.scheme,
            key_size=c1.key_size,
            pk_data=c1.pk_data,
            max_value=c1.max_value,
            fxp_bits=c1.fxp_bits,
            modulus=c1.modulus,
        )
        return [merged]

    except ValueError:
        # Validation-style errors propagate untouched.
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform concat: {e}") from e
1677
-
1678
-
1679
@kernel_def("phe.reshape")
def _phe_reshape(pfunc: PFunction, ciphertext: CipherText) -> Any:
    """Give a CipherText a new semantic shape without touching its payload.

    The target shape comes from ``pfunc.attrs["new_shape"]`` and may contain
    at most one ``-1`` entry, which is inferred from the element count. All
    other entries must be strictly positive. The result references the same
    ``ct_data`` object as the input; only ``semantic_shape`` differs.

    Args:
        pfunc: Kernel descriptor; only ``attrs["new_shape"]`` is consulted.
        ciphertext: The CipherText to reshape.

    Returns:
        A single-element list holding the reshaped CipherText.

    Raises:
        ValueError: If the argument is not a CipherText, ``new_shape`` is
            missing or not a tuple/list, more than one ``-1`` is given, a
            dimension is zero or negative (other than ``-1``), or the element
            counts do not match.
        RuntimeError: If any other exception occurs during shape resolution.
    """
    if not isinstance(ciphertext, CipherText):
        raise ValueError("Argument must be a CipherText instance")

    new_shape = pfunc.attrs.get("new_shape")
    if new_shape is None:
        raise ValueError("new_shape parameter is required for reshape operation")

    # Lists are accepted but normalised to a tuple; anything else is rejected.
    if isinstance(new_shape, list):
        new_shape = tuple(new_shape)
    if not isinstance(new_shape, tuple):
        raise ValueError("new_shape must be a tuple or list of integers")

    try:
        # An empty shape denotes a scalar, i.e. exactly one element.
        current_shape = ciphertext.semantic_shape
        old_size = int(np.prod(current_shape)) if current_shape else 1

        wildcard_slots = [pos for pos, dim in enumerate(new_shape) if dim == -1]
        if len(wildcard_slots) > 1:
            raise ValueError("can only specify one unknown dimension")

        resolved = list(new_shape)
        if wildcard_slots:
            # Multiply the explicitly given dimensions, then derive the
            # missing one from the total element count.
            known_product = 1
            for dim in new_shape:
                if dim == -1:
                    continue
                if dim <= 0:
                    raise ValueError(
                        f"negative dimensions not allowed (except -1): {dim}"
                    )
                known_product *= dim

            if old_size % known_product != 0:
                raise ValueError(
                    f"cannot reshape array of size {old_size} into shape {new_shape}"
                )
            resolved[wildcard_slots[0]] = old_size // known_product
        else:
            # Fully explicit shape: every dimension must be positive.
            for dim in new_shape:
                if dim <= 0:
                    raise ValueError(f"negative dimensions not allowed: {dim}")

        final_shape = tuple(resolved)

        # The element count must be preserved by the reshape.
        new_size = int(np.prod(final_shape)) if final_shape else 1
        if old_size != new_size:
            raise ValueError(
                f"Cannot reshape CipherText with {old_size} elements to shape {final_shape} "
                f"with {new_size} elements"
            )

        reshaped = CipherText(
            ct_data=ciphertext.ct_data,  # payload is reused, not copied
            semantic_dtype=ciphertext.semantic_dtype,
            semantic_shape=final_shape,
            scheme=ciphertext.scheme,
            key_size=ciphertext.key_size,
            pk_data=ciphertext.pk_data,
            max_value=ciphertext.max_value,
            fxp_bits=ciphertext.fxp_bits,
            modulus=ciphertext.modulus,
        )
        return [reshaped]

    except ValueError:
        # Validation-style errors propagate untouched.
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform reshape: {e}") from e
1769
-
1770
-
1771
@kernel_def("phe.transpose")
def _phe_transpose(pfunc: PFunction, ciphertext: CipherText) -> Any:
    """Execute transpose operation on CipherText.

    Permutes the dimensions of a CipherText according to
    ``pfunc.attrs["axes"]``. When ``axes`` is absent or ``None`` the dimension
    order is reversed. Negative axes count from the last dimension.

    ``ct_data`` is treated as the row-major flattening of ``semantic_shape``:
    the new element order is obtained by transposing an index array of that
    shape with NumPy and reading it back flat. The encrypted elements
    themselves are only reordered, never modified.

    Args:
        pfunc: Kernel descriptor; only ``attrs["axes"]`` is consulted.
        ciphertext: The CipherText to transpose.

    Returns:
        A single-element list holding the transposed CipherText. A scalar
        (empty-shape) input is returned as-is, i.e. the very same object.

    Raises:
        ValueError: If the argument is not a CipherText, ``axes`` is not a
            tuple/list/None, its length differs from the number of
            dimensions, an axis is out of range, or an axis is repeated.
        RuntimeError: If any other exception occurs while rearranging data.
    """
    if not isinstance(ciphertext, CipherText):
        raise ValueError("Argument must be a CipherText instance")

    old_shape = ciphertext.semantic_shape
    ndim = len(old_shape)

    # Transposing a scalar is the identity; axes are not even inspected.
    if ndim == 0:
        return [ciphertext]

    axes = pfunc.attrs.get("axes")
    if axes is None:
        # Default transpose behaviour: reverse all dimensions.
        axes = tuple(reversed(range(ndim)))
    elif isinstance(axes, list):
        axes = tuple(axes)
    elif not isinstance(axes, tuple):
        raise ValueError("axes must be a tuple or list of integers, or None")

    try:
        if len(axes) != ndim:
            raise ValueError(
                f"axes length {len(axes)} does not match tensor dimensions {ndim}"
            )

        # Normalise negative axes and validate the range. The error message
        # reports the value the caller supplied, not the shifted one, so that
        # e.g. axes=(-3, 0) on a 2-D tensor complains about -3 rather than
        # about the (valid-looking) -1.
        normalized_axes = []
        for requested_axis in axes:
            axis = ndim + requested_axis if requested_axis < 0 else requested_axis
            if axis < 0 or axis >= ndim:
                raise ValueError(
                    f"axis {requested_axis} is out of bounds for array of dimension {ndim}"
                )
            normalized_axes.append(axis)
        axes = tuple(normalized_axes)

        # A permutation must mention every axis exactly once.
        if len(set(axes)) != len(axes):
            raise ValueError("axes cannot contain duplicate values")

        new_shape = tuple(old_shape[axis] for axis in axes)

        ct_data = ciphertext.ct_data
        if ndim <= 1:
            # 1-D: the only valid permutation is the identity, so the
            # payload is reused unchanged.
            transposed_ct_data = ct_data
        else:
            # Let NumPy compute the permutation on flat indices, then gather
            # the encrypted elements in that order.
            flat_positions = np.arange(len(ct_data)).reshape(old_shape)
            gather_order = np.transpose(flat_positions, axes).flatten()
            transposed_ct_data = [ct_data[idx] for idx in gather_order]

        return [
            CipherText(
                ct_data=transposed_ct_data,
                semantic_dtype=ciphertext.semantic_dtype,
                semantic_shape=new_shape,
                scheme=ciphertext.scheme,
                key_size=ciphertext.key_size,
                pk_data=ciphertext.pk_data,
                max_value=ciphertext.max_value,
                fxp_bits=ciphertext.fxp_bits,
                modulus=ciphertext.modulus,
            )
        ]

    except ValueError:
        # Re-raise ValueError directly (validation errors)
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform transpose: {e}") from e