mplang-nightly 0.1.dev142__py3-none-any.whl → 0.1.dev143__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/backend/phe.py +1448 -91
- mplang/frontend/phe.py +140 -3
- {mplang_nightly-0.1.dev142.dist-info → mplang_nightly-0.1.dev143.dist-info}/METADATA +1 -1
- {mplang_nightly-0.1.dev142.dist-info → mplang_nightly-0.1.dev143.dist-info}/RECORD +7 -7
- {mplang_nightly-0.1.dev142.dist-info → mplang_nightly-0.1.dev143.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev142.dist-info → mplang_nightly-0.1.dev143.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev142.dist-info → mplang_nightly-0.1.dev143.dist-info}/licenses/LICENSE +0 -0
mplang/backend/phe.py
CHANGED
@@ -25,16 +25,29 @@ from mplang.core.mptype import TensorLike
|
|
25
25
|
from mplang.core.pfunc import PFunction
|
26
26
|
|
27
27
|
# Decimal precision handed to lightPHE for its built-in float handling.
# It is forced to 0 so that lightPHE only ever operates on integers;
# negative and floating-point values are supported instead by this module's
# own range / fixed-point encoding and decoding.
PRECISION = 0
|
29
31
|
|
30
32
|
|
31
33
|
class PublicKey:
|
32
34
|
"""PHE Public Key that implements TensorLike protocol."""
|
33
35
|
|
34
|
-
def __init__(
|
36
|
+
def __init__(
|
37
|
+
self,
|
38
|
+
key_data: Any,
|
39
|
+
scheme: str,
|
40
|
+
key_size: int,
|
41
|
+
max_value: int = 2**100,
|
42
|
+
fxp_bits: int = 12,
|
43
|
+
modulus: int | None = None,
|
44
|
+
):
|
35
45
|
self.key_data = key_data
|
36
46
|
self.scheme = scheme
|
37
47
|
self.key_size = key_size
|
48
|
+
self.max_value = max_value # Maximum absolute value B for range encoding
|
49
|
+
self.fxp_bits = fxp_bits # Fixed-point precision bits for float encoding
|
50
|
+
self.modulus = modulus # Paillier modulus N for range encoding
|
38
51
|
|
39
52
|
@property
|
40
53
|
def dtype(self) -> Any:
|
@@ -44,18 +57,35 @@ class PublicKey:
|
|
44
57
|
def shape(self) -> tuple[int, ...]:
|
45
58
|
return ()
|
46
59
|
|
60
|
+
@property
def max_float_value(self) -> float:
    """Maximum float value that can be encoded.

    This is the integer bound ``max_value`` divided by the fixed-point scale
    ``2**fxp_bits``.
    """
    return float(self.max_value / (2**self.fxp_bits))
|
64
|
+
|
47
65
|
def __repr__(self) -> str:
    """Return a debug string with the scheme, key size and encoding parameters."""
    return f"PublicKey(scheme={self.scheme}, key_size={self.key_size}, max_value={self.max_value}, fxp_bits={self.fxp_bits})"
|
49
67
|
|
50
68
|
|
51
69
|
class PrivateKey:
|
52
70
|
"""PHE Private Key that implements TensorLike protocol."""
|
53
71
|
|
54
|
-
def __init__(
|
72
|
+
def __init__(
|
73
|
+
self,
|
74
|
+
sk_data: Any,
|
75
|
+
pk_data: Any,
|
76
|
+
scheme: str,
|
77
|
+
key_size: int,
|
78
|
+
max_value: int = 2**100,
|
79
|
+
fxp_bits: int = 12,
|
80
|
+
modulus: int | None = None,
|
81
|
+
):
|
55
82
|
self.sk_data = sk_data # Store private key data
|
56
83
|
self.pk_data = pk_data # Store public key data as well
|
57
84
|
self.scheme = scheme
|
58
85
|
self.key_size = key_size
|
86
|
+
self.max_value = max_value # Maximum absolute value B for range encoding
|
87
|
+
self.fxp_bits = fxp_bits # Fixed-point precision bits for float encoding
|
88
|
+
self.modulus = modulus # Paillier modulus N for range encoding
|
59
89
|
|
60
90
|
@property
|
61
91
|
def dtype(self) -> Any:
|
@@ -65,8 +95,13 @@ class PrivateKey:
|
|
65
95
|
def shape(self) -> tuple[int, ...]:
|
66
96
|
return ()
|
67
97
|
|
98
|
+
@property
def max_float_value(self) -> float:
    """Maximum float value that can be encoded.

    This is the integer bound ``max_value`` divided by the fixed-point scale
    ``2**fxp_bits``.
    """
    return float(self.max_value / (2**self.fxp_bits))
|
102
|
+
|
68
103
|
def __repr__(self) -> str:
    """Return a debug string with the scheme, key size and encoding parameters."""
    return f"PrivateKey(scheme={self.scheme}, key_size={self.key_size}, max_value={self.max_value}, fxp_bits={self.fxp_bits})"
|
70
105
|
|
71
106
|
|
72
107
|
class CipherText:
|
@@ -80,6 +115,9 @@ class CipherText:
|
|
80
115
|
scheme: str,
|
81
116
|
key_size: int,
|
82
117
|
pk_data: Any = None, # Store public key for operations
|
118
|
+
max_value: int = 2**100,
|
119
|
+
fxp_bits: int = 12,
|
120
|
+
modulus: int | None = None,
|
83
121
|
):
|
84
122
|
self.ct_data = ct_data
|
85
123
|
self.semantic_dtype = semantic_dtype
|
@@ -87,6 +125,9 @@ class CipherText:
|
|
87
125
|
self.scheme = scheme
|
88
126
|
self.key_size = key_size
|
89
127
|
self.pk_data = pk_data
|
128
|
+
self.max_value = max_value
|
129
|
+
self.fxp_bits = fxp_bits
|
130
|
+
self.modulus = modulus
|
90
131
|
|
91
132
|
@property
|
92
133
|
def dtype(self) -> Any:
|
@@ -96,102 +137,367 @@ class CipherText:
|
|
96
137
|
def shape(self) -> tuple[int, ...]:
|
97
138
|
return self.semantic_shape
|
98
139
|
|
140
|
+
@property
def max_float_value(self) -> float:
    """Maximum float value that can be encoded.

    This is the integer bound ``max_value`` divided by the fixed-point scale
    ``2**fxp_bits``.
    """
    return float(self.max_value / (2**self.fxp_bits))
|
144
|
+
|
99
145
|
def __repr__(self) -> str:
    """Return a debug string with the semantic dtype, shape and scheme."""
    return f"CipherText(dtype={self.semantic_dtype}, shape={self.semantic_shape}, scheme={self.scheme})"
|
101
147
|
|
102
148
|
|
103
|
-
|
149
|
+
# Range-based encoding functions for negative numbers and floats
def _range_encode_integer(value: int, max_value: int, modulus: int) -> int:
    """Map a signed integer into the Paillier plaintext space.

    - Non-negative numbers: ``encode(m) = m``
    - Negative numbers: ``encode(m) = N + m``

    With ``max_value < modulus`` the result lies in ``[0, modulus)``.

    Args:
        value: Signed integer to encode.
        max_value: Bound B; ``value`` must lie in ``[-B, B]``.
        modulus: Paillier plaintext modulus N.

    Returns:
        The encoded non-negative integer.

    Raises:
        ValueError: If ``value`` lies outside ``[-max_value, max_value]``.
    """
    if not -max_value <= value <= max_value:
        raise ValueError(
            f"Integer value {value} out of range [-{max_value}, {max_value}]"
        )
    return value if value >= 0 else modulus + value
|
167
|
+
|
168
|
+
|
169
|
+
def _range_encode_float(
    value: float, max_value: int, fxp_bits: int, modulus: int
) -> int:
    """Range-encode a float through fixed-point scaling.

    1. Fixed-point conversion: ``scaled_int = round(value * 2**fxp_bits)``.
    2. Encode ``scaled_int`` with the signed-integer rules of
       ``_range_encode_integer`` (negative values wrap to ``modulus + x``).

    Args:
        value: Float to encode.
        max_value: Integer bound B on the scaled value.
        fxp_bits: Number of fractional bits in the fixed-point representation.
        modulus: Paillier plaintext modulus N.

    Returns:
        The encoded non-negative integer.

    Raises:
        ValueError: If ``|value| > max_value / 2**fxp_bits``. NaN and
            infinities fail the same range comparison and are rejected too.
    """
    max_float = max_value / (2**fxp_bits)
    if not -max_float <= value <= max_float:
        raise ValueError(
            f"Float value {value} out of range [-{max_float}, {max_float}]"
        )

    # Fixed-point encoding: float -> scaled integer.
    scaled_int = round(value * (2**fxp_bits))

    # Reuse the integer encoding rules.
    return _range_encode_integer(scaled_int, max_value, modulus)
|
188
|
+
|
189
|
+
|
190
|
+
def _range_encode_mixed(
    value: Any, max_value: int, fxp_bits: int, modulus: int, semantic_dtype: DType
) -> int:
    """Range-encode ``value`` according to its semantic dtype.

    Floating semantic types use fixed-point float encoding. All other types
    are converted with ``int()`` and use the integer encoding.

    Args:
        value: Scalar to encode (e.g. a NumPy scalar from a flattened array).
        max_value: Integer bound B of the encoding.
        fxp_bits: Fixed-point fractional bits; used only for floating types.
        modulus: Paillier plaintext modulus N.
        semantic_dtype: Logical dtype that selects the encoding.

    Returns:
        The encoded non-negative integer.

    Raises:
        ValueError: If the value is out of the encodable range.
    """
    if semantic_dtype.is_floating:
        return _range_encode_float(float(value), max_value, fxp_bits, modulus)
    return _range_encode_integer(int(value), max_value, modulus)
|
203
|
+
|
204
|
+
|
205
|
+
def _range_decode_integer(encoded_value: int, max_value: int, modulus: int) -> int:
    """Invert the signed-integer range encoding.

    The value is first reduced modulo N (``modulus``) to a residue ``r``:

    - If ``r <= max_value``: ``decode(r) = r``
    - If ``r >= N - max_value``: ``decode(r) = r - N``
    - Otherwise ``r`` lies in the overflow gap and is rejected.

    A one-element list or tuple is unwrapped first. Presumably this exists
    because decrypted values can come back wrapped in a list; TODO confirm.

    Args:
        encoded_value: Encoded integer, or a one-element list/tuple holding it.
        max_value: Bound B used during encoding.
        modulus: Paillier plaintext modulus N.

    Returns:
        The decoded signed integer.

    Raises:
        ValueError: If the residue lies in the overflow region, or if a
            list/tuple does not contain exactly one element.
    """
    if isinstance(encoded_value, (list, tuple)):
        # Refuse to silently drop extra elements (or fail obscurely on empty input).
        if len(encoded_value) != 1:
            raise ValueError(
                f"Expected a single encoded value, got a sequence of length {len(encoded_value)}"
            )
        (encoded_value,) = encoded_value

    residue = int(encoded_value) % modulus

    if residue <= max_value:
        return residue
    if residue >= modulus - max_value:
        return residue - modulus
    raise ValueError(f"Decoded value {residue} is in overflow region")
|
224
|
+
|
225
|
+
|
226
|
+
def _range_decode_float(
    encoded_value: int, max_value: int, fxp_bits: int, modulus: int
) -> float:
    """Invert the fixed-point float range encoding.

    1. Integer decoding: ``decoded_int = _range_decode_integer(encoded_value)``.
    2. Fixed-point conversion: ``value = decoded_int / 2**fxp_bits``.

    Args:
        encoded_value: Encoded integer (or a one-element list/tuple holding it).
        max_value: Integer bound B used during encoding.
        fxp_bits: Fractional bits used during encoding.
        modulus: Paillier plaintext modulus N.

    Returns:
        The decoded float.

    Raises:
        ValueError: If integer decoding fails (e.g. overflow region).
    """
    decoded_int = _range_decode_integer(encoded_value, max_value, modulus)

    # Fixed-point decoding: scaled integer -> float.
    return float(decoded_int / (2**fxp_bits))
|
239
|
+
|
240
|
+
|
241
|
+
def _range_decode_mixed(
    encoded_value: int,
    max_value: int,
    fxp_bits: int,
    modulus: int,
    semantic_dtype: DType,
) -> Any:
    """Decode ``encoded_value`` according to its semantic dtype.

    Floating semantic types are decoded as fixed-point floats. All other
    types are decoded as signed integers.

    Args:
        encoded_value: Encoded integer (or a one-element list/tuple holding it).
        max_value: Integer bound B used during encoding.
        fxp_bits: Fractional bits; used only for floating types.
        modulus: Paillier plaintext modulus N.
        semantic_dtype: Logical dtype that selects the decoding.

    Returns:
        A ``float`` for floating types, otherwise an ``int``.

    Raises:
        ValueError: If the encoded value lies in the overflow region.
    """
    if semantic_dtype.is_floating:
        return _range_decode_float(encoded_value, max_value, fxp_bits, modulus)
    return _range_decode_integer(encoded_value, max_value, modulus)
|
258
|
+
|
259
|
+
|
260
|
+
def _convert_to_numpy(obj: TensorLike) -> np.ndarray:
    """Convert a TensorLike object to a NumPy array.

    - ``np.ndarray`` inputs are returned as-is (no copy).
    - Objects exposing a callable ``numpy()`` method are converted through it.
    - If that call raises, or there is no such method, ``np.asarray(obj)`` is
      used instead.
    """
    if isinstance(obj, np.ndarray):
        return obj

    # getattr with a default already covers the "no such attribute" case.
    to_numpy = getattr(obj, "numpy", None)
    if callable(to_numpy):
        try:
            return np.asarray(to_numpy())
        except Exception:
            # Deliberate best effort: fall back to np.asarray(obj) below.
            pass

    return np.asarray(obj)
|
112
275
|
|
113
276
|
|
114
277
|
@kernel_def("phe.keygen")
def _phe_keygen(pfunc: PFunction) -> Any:
    """Generate a PHE key pair with range-encoding parameters attached.

    Reads these attributes from ``pfunc.attrs``:

    - ``scheme``: scheme name, case-insensitive (default ``"paillier"``;
      only Paillier is supported).
    - ``key_size``: key size in bits (default 2048).
    - ``max_value``: range-encoding bound B (default ``2**32``).
    - ``fxp_bits``: fixed-point fractional bits for floats (default 12).

    Returns:
        ``[PublicKey, PrivateKey]``. Both carry the capitalized scheme name,
        the key size, the encoding parameters and the Paillier modulus N.

    Raises:
        ValueError: If the scheme is not supported.
        RuntimeError: If key generation fails, including when the generated
            modulus N is not larger than ``3 * max_value``.
    """
    scheme = pfunc.attrs.get("scheme", "paillier")
    # Defaults to 2048 bits. Callers (e.g. tests) may request smaller keys for
    # speed; production use should keep at least 2048, preferably 3072, bits.
    key_size = pfunc.attrs.get("key_size", 2048)
    max_value = pfunc.attrs.get("max_value", 2**32)
    fxp_bits = pfunc.attrs.get("fxp_bits", 12)

    if scheme.lower() != "paillier":
        raise ValueError(f"Unsupported PHE scheme: {scheme}")

    # Canonical capitalized name, passed to lightPHE and stored on both keys.
    scheme = scheme.capitalize()

    try:
        # PRECISION (0) keeps lightPHE in integer mode. Negatives and floats
        # are range / fixed-point encoded by this module instead.
        phe = LightPHE(
            algorithm_name=scheme,
            key_size=key_size,
            precision=PRECISION,
        )

        pk_data = phe.cs.keys["public_key"]
        sk_data = phe.cs.keys["private_key"]
        modulus = phe.cs.plaintext_modulo  # Paillier modulus N

        # Decoding maps [0, B] to positives and [N - B, N) to negatives.
        # Requiring N > 3*B leaves an overflow gap of roughly B or more between
        # the two regions.
        # NOTE(review): the message asks for N >> 3*B; only N > 3*B is enforced.
        if modulus <= 3 * max_value:
            raise ValueError(
                f"Modulus {modulus} is too small for max_value {max_value}. Require N >> 3*B"
            )

        # Metadata shared verbatim by the public and private key.
        shared = {
            "scheme": scheme,
            "key_size": key_size,
            "max_value": max_value,
            "fxp_bits": fxp_bits,
            "modulus": modulus,
        }
        public_key = PublicKey(key_data=pk_data, **shared)
        private_key = PrivateKey(sk_data=sk_data, pk_data=pk_data, **shared)

        return [public_key, private_key]

    except Exception as e:
        raise RuntimeError(f"Failed to generate PHE keys: {e}") from e
|
134
334
|
|
135
335
|
|
136
336
|
@kernel_def("phe.encrypt")
def _phe_encrypt(pfunc: PFunction, plaintext: Any, public_key: PublicKey) -> Any:
    """Encrypt a plaintext tensor element-wise under ``public_key``.

    Each element is range-encoded according to the plaintext's dtype
    (fixed-point for floats) and then encrypted with lightPHE. The plaintext's
    dtype and shape are recorded on the result as its semantic dtype/shape.

    Args:
        pfunc: Kernel descriptor (unused).
        plaintext: Tensor-like value to encrypt.
        public_key: Key produced by ``phe.keygen``.

    Returns:
        ``[CipherText]`` carrying one lightPHE ciphertext per element in
        row-major (flattened) order, plus the key's encoding parameters.

    Raises:
        ValueError: If ``public_key`` is not a ``PublicKey``.
        RuntimeError: If encoding or encryption fails. This includes a key
            without a modulus and out-of-range values.
    """
    if not isinstance(public_key, PublicKey):
        raise ValueError("Second argument must be a PublicKey instance")

    try:
        # Convert plaintext to numpy to get semantic type info.
        plaintext_np = _convert_to_numpy(plaintext)
        semantic_dtype = DType.from_numpy(plaintext_np.dtype)
        semantic_shape = plaintext_np.shape

        # Check before constructing LightPHE: construction generates a
        # throwaway key pair and is comparatively expensive.
        modulus = public_key.modulus
        if modulus is None:
            raise ValueError(
                "Public key modulus is None, key generation may have failed"
            )

        phe = LightPHE(
            algorithm_name=public_key.scheme,
            key_size=public_key.key_size,
            precision=PRECISION,
        )
        # Replace the throwaway key's moduli and public key with public_key's,
        # so encryption happens in the key's plaintext space.
        phe.cs.plaintext_modulo = modulus
        phe.cs.ciphertext_modulo = modulus * modulus
        phe.cs.keys["public_key"] = public_key.key_data

        encoded = [
            _range_encode_mixed(
                val,
                public_key.max_value,
                public_key.fxp_bits,
                modulus,
                semantic_dtype,
            )
            for val in plaintext_np.flatten()
        ]

        # Encrypt each encoded scalar individually (not as a list).
        lightphe_ciphertext = [phe.encrypt(val) for val in encoded]

        ciphertext = CipherText(
            ct_data=lightphe_ciphertext,
            semantic_dtype=semantic_dtype,
            semantic_shape=semantic_shape,
            scheme=public_key.scheme,
            key_size=public_key.key_size,
            pk_data=public_key.key_data,
            max_value=public_key.max_value,
            fxp_bits=public_key.fxp_bits,
            modulus=modulus,
        )
        return [ciphertext]

    except Exception as e:
        raise RuntimeError(f"Failed to encrypt data: {e}") from e
|
167
403
|
|
168
404
|
|
169
405
|
@kernel_def("phe.mul")
def _phe_mul(pfunc: PFunction, ciphertext: CipherText, plaintext: Any) -> Any:
    """Multiply an encrypted tensor by an integer plaintext tensor, element-wise.

    The operands are broadcast against each other with NumPy rules.
    Multipliers are applied as raw Python ints and are not range-encoded,
    because the ciphertext already holds the encoded value. For a floating
    ciphertext holding ``m * 2**fxp_bits``, the product
    ``k * m * 2**fxp_bits`` still carries exactly one fixed-point scale, so it
    decodes to ``k * m`` provided the product stays within ``max_value``.

    Args:
        pfunc: Kernel descriptor (unused).
        ciphertext: Encrypted left operand.
        plaintext: Integer (or bool) tensor-like right operand.

    Returns:
        ``[CipherText]`` with the broadcast result shape and the ciphertext's
        dtype and encoding parameters.

    Raises:
        ValueError: In any of these cases:
            - ``ciphertext`` is not a ``CipherText``.
            - The plaintext is floating point. A float multiplier would need
              its own fixed-point scale, which cannot be removed under
              encryption.
            - The shapes do not broadcast.
        RuntimeError: If the homomorphic multiplication itself fails.
    """
    if not isinstance(ciphertext, CipherText):
        raise ValueError("First argument must be a CipherText instance")

    try:
        plaintext_np = _convert_to_numpy(plaintext)

        if np.issubdtype(plaintext_np.dtype, np.floating):
            raise ValueError(
                f"Homomorphic multiplication with floating point plaintext is not supported. "
                f"Got plaintext dtype: {plaintext_np.dtype}"
            )

        ct_shape = tuple(ciphertext.semantic_shape)
        try:
            result_shape = np.broadcast_shapes(ct_shape, plaintext_np.shape)
        except ValueError as e:
            raise ValueError(
                f"Operands cannot be broadcast together: CipherText shape {ciphertext.semantic_shape} "
                f"vs plaintext shape {plaintext_np.shape}: {e}"
            ) from e

        # Ciphertexts are opaque objects in a flat list, so broadcast an index
        # map over them instead of the data itself.
        if ct_shape != result_shape:
            index_map = np.arange(int(np.prod(ct_shape)), dtype=np.int64).reshape(
                ct_shape
            )
            ct_data = [
                ciphertext.ct_data[int(idx)]
                for idx in np.broadcast_to(index_map, result_shape).ravel()
            ]
        else:
            ct_data = ciphertext.ct_data

        # Floats were rejected above, so every multiplier is an exact integer.
        # Pass Python ints even for floating ciphertexts: converting to float
        # could route the scalar through lightPHE's float handling instead of
        # plain integer scalar multiplication.
        multipliers = [
            int(v) for v in np.broadcast_to(plaintext_np, result_shape).ravel()
        ]

        # NOTE(review): in textbook Paillier a multiplier of 0 yields c**0 == 1,
        # a recognisable encryption of zero; confirm whether lightPHE
        # re-randomises scalar products.
        product = [ct * k for ct, k in zip(ct_data, multipliers, strict=True)]

        return [
            CipherText(
                ct_data=product,
                semantic_dtype=ciphertext.semantic_dtype,
                semantic_shape=result_shape,
                scheme=ciphertext.scheme,
                key_size=ciphertext.key_size,
                pk_data=ciphertext.pk_data,
                max_value=ciphertext.max_value,
                fxp_bits=ciphertext.fxp_bits,
                modulus=ciphertext.modulus,
            )
        ]

    except ValueError:
        # Validation errors propagate unchanged.
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform multiplication: {e}") from e
|
196
502
|
|
197
503
|
|
@@ -205,7 +511,7 @@ def _phe_add(pfunc: PFunction, lhs: Any, rhs: Any) -> Any:
|
|
205
511
|
elif isinstance(rhs, CipherText):
|
206
512
|
return _phe_add_ct2pt(rhs, lhs)
|
207
513
|
else:
|
208
|
-
return
|
514
|
+
return _convert_to_numpy(lhs) + _convert_to_numpy(rhs)
|
209
515
|
except ValueError:
|
210
516
|
raise
|
211
517
|
except Exception as e: # pragma: no cover
|
@@ -213,89 +519,1140 @@ def _phe_add(pfunc: PFunction, lhs: Any, rhs: Any) -> Any:
|
|
213
519
|
|
214
520
|
|
215
521
|
def _phe_add_ct2ct(ct1: CipherText, ct2: CipherText) -> CipherText:
    """Add two encrypted tensors element-wise with NumPy broadcasting.

    Args:
        ct1: Left operand; its semantic dtype is used for the result.
        ct2: Right operand.

    Returns:
        A ``CipherText`` with the broadcast shape and ``ct1``'s metadata.

    Raises:
        ValueError: In any of these cases:
            - The operands differ in scheme, key size or public key.
            - One operand is floating and the other is integer.
            - The shapes do not broadcast.
    """
    if ct1.scheme != ct2.scheme or ct1.key_size != ct2.key_size:
        raise ValueError("CipherText operands must use same scheme and key size")

    if ct1.pk_data != ct2.pk_data:
        raise ValueError("CipherText operands must be encrypted with same key")

    # A floating ciphertext holds m * 2**fxp_bits while an integer one holds m;
    # their sum would mix scales and decode incorrectly.
    if ct1.semantic_dtype.is_floating != ct2.semantic_dtype.is_floating:
        raise ValueError(
            f"Cannot add ciphertexts with different numeric types due to fixed-point encoding. "
            f"First CipherText dtype: {ct1.semantic_dtype}, second CipherText dtype: {ct2.semantic_dtype}. "
            f"Both operands must have the same numeric type (both floating or both integer)."
        )

    shape1 = tuple(ct1.semantic_shape)
    shape2 = tuple(ct2.semantic_shape)
    try:
        result_shape = np.broadcast_shapes(shape1, shape2)
    except ValueError as e:
        raise ValueError(
            f"CipherText operands cannot be broadcast together: shape {ct1.semantic_shape} "
            f"vs shape {ct2.semantic_shape}: {e}"
        ) from e

    def expand(ct: CipherText, shape: tuple[int, ...]) -> list[Any]:
        """Return ct's flat ciphertext list broadcast to ``result_shape``."""
        if shape == result_shape:
            return ct.ct_data
        # Ciphertexts are opaque objects, so broadcast an index map instead.
        index_map = np.arange(int(np.prod(shape)), dtype=np.int64).reshape(shape)
        return [
            ct.ct_data[int(idx)]
            for idx in np.broadcast_to(index_map, result_shape).ravel()
        ]

    summed = [
        a + b
        for a, b in zip(expand(ct1, shape1), expand(ct2, shape2), strict=True)
    ]

    return CipherText(
        ct_data=summed,
        semantic_dtype=ct1.semantic_dtype,
        semantic_shape=result_shape,
        scheme=ct1.scheme,
        key_size=ct1.key_size,
        pk_data=ct1.pk_data,
        max_value=ct1.max_value,
        fxp_bits=ct1.fxp_bits,
        modulus=ct1.modulus,
    )
|
231
594
|
|
232
595
|
|
233
596
|
def _phe_add_ct2pt(ciphertext: CipherText, plaintext: TensorLike) -> CipherText:
    """Add a plaintext tensor to an encrypted tensor with NumPy broadcasting.

    The plaintext is range-encoded with the ciphertext's parameters, encrypted
    under the ciphertext's public key, and added homomorphically.

    Args:
        ciphertext: Encrypted operand; its metadata is used for the result.
        plaintext: Tensor-like operand. It must match the ciphertext's
            floating/integer kind.

    Returns:
        A ``CipherText`` with the broadcast shape.

    Raises:
        ValueError: In any of these cases:
            - The operands mix floating and integer types.
            - The shapes do not broadcast.
            - The ciphertext lacks public key data or a modulus.
            - A plaintext value is out of the encodable range.
    """
    plaintext_np = _convert_to_numpy(plaintext)
    plaintext_dtype = DType.from_numpy(plaintext_np.dtype)

    # Floating ciphertexts hold values scaled by 2**fxp_bits; adding an
    # unscaled integer (or a scaled float to an integer) would mix scales.
    if ciphertext.semantic_dtype.is_floating and not plaintext_dtype.is_floating:
        raise ValueError(
            f"Cannot add integer plaintext to floating point ciphertext due to fixed-point encoding. "
            f"CipherText dtype: {ciphertext.semantic_dtype}, plaintext dtype: {plaintext_dtype}. "
            f"Both operands must have the same numeric type (both floating or both integer)."
        )
    if not ciphertext.semantic_dtype.is_floating and plaintext_dtype.is_floating:
        raise ValueError(
            f"Cannot add floating point plaintext to integer ciphertext due to fixed-point encoding. "
            f"CipherText dtype: {ciphertext.semantic_dtype}, plaintext dtype: {plaintext_dtype}. "
            f"Both operands must have the same numeric type (both floating or both integer)."
        )

    ct_shape = tuple(ciphertext.semantic_shape)
    try:
        result_shape = np.broadcast_shapes(ct_shape, plaintext_np.shape)
    except ValueError as e:
        raise ValueError(
            f"Operands cannot be broadcast together: CipherText shape {ciphertext.semantic_shape} "
            f"vs plaintext shape {plaintext_np.shape}: {e}"
        ) from e

    # Ciphertexts are opaque objects in a flat list, so broadcast an index map.
    if ct_shape != result_shape:
        index_map = np.arange(int(np.prod(ct_shape)), dtype=np.int64).reshape(ct_shape)
        ct_data = [
            ciphertext.ct_data[int(idx)]
            for idx in np.broadcast_to(index_map, result_shape).ravel()
        ]
    else:
        ct_data = ciphertext.ct_data

    if ciphertext.pk_data is None:
        raise ValueError(
            "CipherText must contain public key data for plaintext addition"
        )
    # Check before constructing LightPHE, which generates a throwaway key pair.
    modulus = ciphertext.modulus
    if modulus is None:
        raise ValueError("Ciphertext modulus is None, encryption may have failed")

    phe = LightPHE(
        algorithm_name=ciphertext.scheme,
        key_size=ciphertext.key_size,
        precision=PRECISION,
    )
    # Replace the throwaway key's moduli and public key with the ciphertext's,
    # exactly as _phe_encrypt does. Without this, the plaintext is encoded
    # against an unrelated modulus, which can corrupt negative values
    # (encoded as N + m).
    phe.cs.plaintext_modulo = modulus
    phe.cs.ciphertext_modulo = modulus * modulus
    phe.cs.keys["public_key"] = ciphertext.pk_data

    # Range-encode with the ciphertext's parameters, consistent with encryption.
    encoded = [
        _range_encode_mixed(
            val,
            ciphertext.max_value,
            ciphertext.fxp_bits,
            modulus,
            ciphertext.semantic_dtype,
        )
        for val in np.broadcast_to(plaintext_np, result_shape).ravel()
    ]
    encrypted_plaintext = [phe.encrypt(val) for val in encoded]

    summed = [
        pt_ct + ct for pt_ct, ct in zip(encrypted_plaintext, ct_data, strict=True)
    ]

    return CipherText(
        ct_data=summed,
        semantic_dtype=ciphertext.semantic_dtype,
        semantic_shape=result_shape,
        scheme=ciphertext.scheme,
        key_size=ciphertext.key_size,
        pk_data=ciphertext.pk_data,
        max_value=ciphertext.max_value,
        fxp_bits=ciphertext.fxp_bits,
        modulus=modulus,
    )
|
702
|
+
|
703
|
+
|
704
|
+
def _create_encrypted_zero(ciphertext: CipherText) -> Any:
    """Encrypt a fresh zero under the same key and encoding as ``ciphertext``.

    Args:
        ciphertext: Source of the scheme, key size, public key, modulus and
            semantic dtype used for the zero.

    Returns:
        The raw lightPHE ciphertext of the encoded zero (not a ``CipherText``).

    Raises:
        ValueError: If ``ciphertext.modulus`` is None.
    """
    # Check before constructing LightPHE, which generates a throwaway key pair.
    modulus = ciphertext.modulus
    if modulus is None:
        raise ValueError("Ciphertext modulus is None, encryption may have failed")

    phe = LightPHE(
        algorithm_name=ciphertext.scheme,
        key_size=ciphertext.key_size,
        precision=PRECISION,
    )
    # Replace the throwaway key's moduli and public key with the ciphertext's.
    phe.cs.plaintext_modulo = modulus
    phe.cs.ciphertext_modulo = modulus * modulus
    phe.cs.keys["public_key"] = ciphertext.pk_data

    # Encode zero through the same path as real data for consistency.
    zero_encoded = _range_encode_mixed(
        0,
        ciphertext.max_value,
        ciphertext.fxp_bits,
        modulus,
        ciphertext.semantic_dtype,
    )

    return phe.encrypt(zero_encoded)
|
732
|
+
|
262
733
|
|
263
734
|
@kernel_def("phe.decrypt")
def _phe_decrypt(
    pfunc: PFunction, ciphertext: CipherText, private_key: PrivateKey
) -> Any:
    """Decrypt a CipherText and decode it back into a numpy array.

    Each element is decrypted with lightPHE, range-decoded with the
    ciphertext's encoding parameters, converted to the ciphertext's semantic
    dtype (integers are rounded and clamped to the dtype's range), and the
    result is reshaped to the ciphertext's semantic shape.

    Args:
        pfunc: The PFunction being executed (not used by this kernel).
        ciphertext: Encrypted tensor to decrypt.
        private_key: Private key matching ``ciphertext``'s scheme and key size.

    Returns:
        A one-element list holding the decrypted ``np.ndarray``.

    Raises:
        ValueError: If the arguments have the wrong types or the scheme/key
            size of ciphertext and key differ.
        RuntimeError: If decryption or decoding fails, including when the
            ciphertext carries no modulus.
    """
    # Validate argument types
    if not isinstance(ciphertext, CipherText):
        raise ValueError("First argument must be a CipherText instance")
    if not isinstance(private_key, PrivateKey):
        raise ValueError("Second argument must be a PrivateKey instance")

    # Validate key compatibility
    if (
        ciphertext.scheme != private_key.scheme
        or ciphertext.key_size != private_key.key_size
    ):
        raise ValueError("CipherText and PrivateKey must use same scheme and key size")

    try:
        # Range decoding needs the modulus. Check it before building LightPHE
        # and decrypting every element: without it the moduli could not be
        # pinned and decryption would run against the wrong structure. Raised
        # inside ``try`` so callers still see it wrapped in RuntimeError.
        if ciphertext.modulus is None:
            raise ValueError("Ciphertext modulus is None, encryption may have failed")

        # Create lightPHE instance with the same scheme/key_size
        phe = LightPHE(
            algorithm_name=private_key.scheme,
            key_size=private_key.key_size,
            precision=PRECISION,
        )

        # Force the instance to use the moduli from encryption time
        # (ciphertext modulus is N^2, the Paillier convention).
        phe.cs.plaintext_modulo = ciphertext.modulus
        phe.cs.ciphertext_modulo = ciphertext.modulus * ciphertext.modulus

        # Set both public and private keys (lightPHE needs both for decryption)
        phe.cs.keys["private_key"] = private_key.sk_data
        phe.cs.keys["public_key"] = private_key.pk_data

        target_dtype = ciphertext.semantic_dtype.to_numpy()

        decoded_data = []
        for ct in ciphertext.ct_data:
            decrypted = phe.decrypt(ct)

            # lightPHE may hand back a bare number or a sequence; take the
            # first element in the latter case.
            if isinstance(decrypted, (int, float)):
                raw_val = decrypted
            elif hasattr(decrypted, "__getitem__") and len(decrypted) > 0:
                raw_val = decrypted[0]
            else:
                raise ValueError(f"Cannot extract numeric value from {decrypted}")

            # Mixed decoding returns values according to the semantic dtype.
            decoded_data.append(
                _range_decode_mixed(
                    int(raw_val),
                    ciphertext.max_value,
                    ciphertext.fxp_bits,
                    ciphertext.modulus,
                    ciphertext.semantic_dtype,
                )
            )

        # Convert to target dtype
        if target_dtype.kind in "iu":  # integer types
            info = np.iinfo(target_dtype)
            # Round back to integers and clamp so narrow dtypes don't overflow.
            processed_data = [
                max(info.min, min(info.max, round(val))) for val in decoded_data
            ]
        else:  # float (and other non-integer) types
            processed_data = decoded_data

        plaintext_np = np.array(processed_data, dtype=target_dtype).reshape(
            ciphertext.semantic_shape
        )

        return [plaintext_np]

    except Exception as e:
        raise RuntimeError(f"Failed to decrypt data: {e}") from e
|
823
|
+
|
824
|
+
|
825
|
+
def _dot_result_shape(
    ct_shape: tuple[int, ...], pt_shape: tuple[int, ...]
) -> tuple[int, ...]:
    """Return the shape ``numpy.dot`` produces for operands of these shapes.

    Raises:
        ValueError: If the contracted dimensions do not match.
    """
    # A 0-d operand makes numpy.dot an element-wise multiply.
    if not ct_shape or not pt_shape:
        return ct_shape + pt_shape

    # numpy.dot contracts a's last axis with b's only axis (b 1-D) or with
    # b's second-to-last axis (b N-D, N >= 2).
    contract_axis = 0 if len(pt_shape) == 1 else len(pt_shape) - 2
    if ct_shape[-1] != pt_shape[contract_axis]:
        raise ValueError(
            f"Shapes are not compatible for dot product: CipherText shape {ct_shape} "
            f"vs plaintext shape {pt_shape}: CipherText last dimension "
            f"{ct_shape[-1]} != plaintext dimension {pt_shape[contract_axis]} "
            f"(axis {contract_axis})"
        )
    return ct_shape[:-1] + pt_shape[:contract_axis] + pt_shape[contract_axis + 1 :]


@kernel_def("phe.dot")
def _phe_dot(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorLike) -> Any:
    """Execute the homomorphic dot product ``ciphertext . plaintext``.

    Follows ``numpy.dot`` semantics for every rank combination:

    - either operand 0-d: element-wise multiplication by the scalar;
    - plaintext 1-D: contract the last CipherText axis with the vector,
      result shape ``ct.shape[:-1]``;
    - plaintext N-D (N >= 2): contract the last CipherText axis with the
      second-to-last plaintext axis, result shape
      ``ct.shape[:-1] + pt.shape[:-2] + pt.shape[-1:]``.

    Optimization: zero plaintext weights are skipped (no ciphertext
    multiplication). An output element whose weights are all zero becomes a
    fresh encryption of zero.

    Args:
        pfunc: The PFunction being executed (not used by this kernel).
        ciphertext: Encrypted left operand.
        plaintext: Integer (or bool) plaintext right operand.

    Returns:
        A one-element list with the result CipherText. It keeps the input's
        dtype, key and encoding parameters.

    Raises:
        ValueError: Invalid argument types, floating-point plaintext, or
            incompatible shapes.
        RuntimeError: Any other failure during the homomorphic computation.
    """
    # Validate argument kinds
    if not isinstance(ciphertext, CipherText):
        raise ValueError("First argument must be a CipherText instance")
    if isinstance(plaintext, CipherText):
        raise ValueError("Second argument must be a plaintext TensorLike")

    try:
        plaintext_np = _convert_to_numpy(plaintext)

        # Multiplying fixed-point ciphertexts by float constants is not supported.
        if np.issubdtype(plaintext_np.dtype, np.floating):
            raise ValueError(
                f"Homomorphic dot product with floating point plaintext is not supported. "
                f"Got plaintext dtype: {plaintext_np.dtype}"
            )

        ct_shape = tuple(ciphertext.semantic_shape)
        pt_shape = tuple(plaintext_np.shape)
        result_shape = _dot_result_shape(ct_shape, pt_shape)
        ct_data = ciphertext.ct_data

        def weighted_sum(terms: Any) -> Any:
            """Sum ``ct * w`` over (ct, w) pairs, skipping w == 0."""
            acc = None
            for ct_elem, weight in terms:
                if weight == 0:
                    continue
                product = ct_elem * weight
                acc = product if acc is None else acc + product
            return _create_encrypted_zero(ciphertext) if acc is None else acc

        # ``tolist`` yields exact Python ints (no int64 wrap for uint64 input).
        # Plaintext is always integral here, so weights are passed as int even
        # for floating semantic dtypes.
        if not ct_shape:
            # Scalar ciphertext times every plaintext element.
            weights = [int(w) for w in plaintext_np.reshape(-1).tolist()]
            result_ct_data = [weighted_sum([(ct_data[0], w)]) for w in weights]
        elif not pt_shape:
            # Every ciphertext element times the scalar plaintext.
            weight = int(plaintext_np.item())
            result_ct_data = [weighted_sum([(c, weight)]) for c in ct_data]
        else:
            # View ciphertext as (n_rows, k_dim) and plaintext as (k_dim, n_cols)
            # with the contracted axis moved to the front. Row-major
            # (row, col) order then matches result_shape's flattened layout.
            k_dim = ct_shape[-1]
            n_rows = int(np.prod(ct_shape[:-1]))
            if len(pt_shape) == 1:
                n_cols = 1
                pt_matrix = plaintext_np.reshape(k_dim, 1)
            else:
                n_cols = int(np.prod(pt_shape[:-2] + pt_shape[-1:]))
                pt_matrix = np.moveaxis(plaintext_np, -2, 0).reshape(k_dim, n_cols)
            weights = [[int(w) for w in row] for row in pt_matrix.tolist()]

            result_ct_data = [
                weighted_sum(
                    (ct_data[row * k_dim + k], weights[k][col]) for k in range(k_dim)
                )
                for row in range(n_rows)
                for col in range(n_cols)
            ]

        # Create result CipherText with computed shape and encoding parameters
        return [
            CipherText(
                ct_data=result_ct_data,
                semantic_dtype=ciphertext.semantic_dtype,
                semantic_shape=result_shape,
                scheme=ciphertext.scheme,
                key_size=ciphertext.key_size,
                pk_data=ciphertext.pk_data,
                max_value=ciphertext.max_value,
                fxp_bits=ciphertext.fxp_bits,
                modulus=ciphertext.modulus,
            )
        ]

    except ValueError:
        # Re-raise ValueError directly (validation errors)
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform dot product: {e}") from e
|
1109
|
+
|
1110
|
+
|
1111
|
+
@kernel_def("phe.gather")
def _phe_gather(pfunc: PFunction, ciphertext: CipherText, indices: Any) -> Any:
    """Gather CipherText elements along an axis, following ``numpy.take``.

    ``result.shape = ciphertext.shape[:axis] + indices.shape +
    ciphertext.shape[axis+1:]`` and
    ``result[pre..., idx..., post...] = ciphertext[pre..., indices[idx...], post...]``.
    The axis comes from ``pfunc.attrs["axis"]`` (default 0; negative values
    count from the end). Unlike numpy, negative *indices* are rejected.

    NOTE(review): before this fix the declared shape was
    ``indices.shape + shape[:axis] + ...``, which disagreed with the data
    layout for axis > 0. Confirm the frontend type inference uses the
    numpy.take shape above.

    Args:
        pfunc: PFunction carrying the optional ``axis`` attribute.
        ciphertext: Source tensor (at least 1-D).
        indices: Integer indices into ``ciphertext`` along ``axis``.

    Returns:
        A one-element list with the gathered CipherText.

    Raises:
        ValueError: Bad argument type, non-integer indices, scalar input,
            out-of-range axis or indices.
        RuntimeError: Any other failure while gathering.
    """
    # Validate that first argument is a CipherText
    if not isinstance(ciphertext, CipherText):
        raise ValueError("First argument must be a CipherText instance")

    # Get axis parameter from pfunc.attrs, default to 0
    axis = pfunc.attrs.get("axis", 0)

    try:
        indices_np = _convert_to_numpy(indices)
        if not np.issubdtype(indices_np.dtype, np.integer):
            raise ValueError("Indices must be of integer type")

        if len(ciphertext.semantic_shape) == 0:
            raise ValueError("Cannot gather from scalar CipherText")

        ct_shape = tuple(ciphertext.semantic_shape)
        ndim = len(ct_shape)
        norm_axis = axis + ndim if axis < 0 else axis
        if norm_axis < 0 or norm_axis >= ndim:
            raise ValueError(
                f"Axis {axis} is out of bounds for array of dimension {ndim}"
            )

        axis_size = ct_shape[norm_axis]
        if np.any(indices_np < 0) or np.any(indices_np >= axis_size):
            raise ValueError(
                f"Indices are out of bounds for axis {norm_axis} with size {axis_size}. "
                f"Got indices in range [{np.min(indices_np)}, {np.max(indices_np)}]"
            )

        # numpy.take layout: dims before axis, then index dims, then dims after.
        result_shape = (
            ct_shape[:norm_axis] + tuple(indices_np.shape) + ct_shape[norm_axis + 1 :]
        )

        # Flat row-major view: ct index = (outer * axis_size + a) * inner + r
        outer = int(np.prod(ct_shape[:norm_axis]))
        inner = int(np.prod(ct_shape[norm_axis + 1 :]))
        flat_indices = [int(i) for i in indices_np.reshape(-1)]

        gathered_ct_data = []
        for outer_idx in range(outer):
            base = outer_idx * axis_size * inner
            for idx in flat_indices:
                start = base + idx * inner
                gathered_ct_data.extend(ciphertext.ct_data[start : start + inner])

        # Validate we got the expected number of elements
        expected_size = int(np.prod(result_shape))
        if len(gathered_ct_data) != expected_size:
            raise RuntimeError(
                f"Internal error: Expected {expected_size} elements, got {len(gathered_ct_data)}"
            )

        return [
            CipherText(
                ct_data=gathered_ct_data,
                semantic_dtype=ciphertext.semantic_dtype,
                semantic_shape=result_shape,
                scheme=ciphertext.scheme,
                key_size=ciphertext.key_size,
                pk_data=ciphertext.pk_data,
                max_value=ciphertext.max_value,
                fxp_bits=ciphertext.fxp_bits,
                modulus=ciphertext.modulus,
            )
        ]

    except ValueError:
        # Re-raise ValueError directly (validation errors)
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform gather: {e}") from e
|
1223
|
+
|
1224
|
+
|
1225
|
+
@kernel_def("phe.scatter")
def _phe_scatter(
    pfunc: PFunction, ciphertext: CipherText, indices: TensorLike, updated: CipherText
) -> Any:
    """Scatter ``updated`` into a copy of ``ciphertext`` along an axis.

    This is the inverse of ``phe.gather`` (numpy.take layout):

    - ``updated.shape`` must equal
      ``ciphertext.shape[:axis] + indices.shape + ciphertext.shape[axis+1:]``;
    - ``result[pre..., indices[idx...], post...] = updated[pre..., idx..., post...]``;
    - the result has ``ciphertext``'s shape. With duplicate indices, the last
      write wins.

    The axis comes from ``pfunc.attrs["axis"]`` (default 0; negative values
    count from the end). Negative indices are rejected.

    NOTE(review): before this fix the expected ``updated`` shape was
    ``indices.shape + shape[:axis] + ...``, which disagreed with how the data
    was indexed for axis > 0. Confirm the frontend type check matches.

    Args:
        pfunc: PFunction carrying the optional ``axis`` attribute.
        ciphertext: Destination tensor (at least 1-D); it is not mutated.
        indices: Integer positions along ``axis``.
        updated: Replacement values, encrypted under the same key.

    Returns:
        A one-element list with the scattered CipherText.

    Raises:
        ValueError: Bad argument types, mismatched keys/schemes, non-integer
            indices, scalar input, out-of-range axis or indices, or a wrong
            ``updated`` shape.
        RuntimeError: Any other failure while scattering.
    """
    # Validate that first and third arguments are CipherTexts
    if not isinstance(ciphertext, CipherText) or not isinstance(updated, CipherText):
        raise ValueError("First and third arguments must be CipherText instances")

    # Validate that both ciphertexts use same scheme/key_size and key
    if ciphertext.scheme != updated.scheme or ciphertext.key_size != updated.key_size:
        raise ValueError("Both CipherTexts must use same scheme and key size")
    if ciphertext.pk_data != updated.pk_data:
        raise ValueError("Both CipherTexts must be encrypted with same key")

    # Get axis parameter from pfunc.attrs, default to 0
    axis = pfunc.attrs.get("axis", 0)

    try:
        indices_np = _convert_to_numpy(indices)
        if not np.issubdtype(indices_np.dtype, np.integer):
            raise ValueError("Indices must be of integer type")

        if len(ciphertext.semantic_shape) == 0:
            raise ValueError("Cannot scatter into scalar CipherText")

        ct_shape = tuple(ciphertext.semantic_shape)
        ndim = len(ct_shape)
        norm_axis = axis + ndim if axis < 0 else axis
        if norm_axis < 0 or norm_axis >= ndim:
            raise ValueError(
                f"Axis {axis} is out of bounds for array of dimension {ndim}"
            )

        axis_size = ct_shape[norm_axis]
        if np.any(indices_np < 0) or np.any(indices_np >= axis_size):
            raise ValueError(
                f"Indices are out of bounds for axis {norm_axis} with size {axis_size}. "
                f"Got indices in range [{np.min(indices_np)}, {np.max(indices_np)}]"
            )

        # ``updated`` uses the same layout phe.gather produces.
        expected_updated_shape = (
            ct_shape[:norm_axis] + tuple(indices_np.shape) + ct_shape[norm_axis + 1 :]
        )
        if tuple(updated.semantic_shape) != expected_updated_shape:
            raise ValueError(
                f"Updated CipherText shape mismatch. Expected {expected_updated_shape}, "
                f"got {updated.semantic_shape}. "
                f"Updated shape must be ciphertext.shape[:axis] + indices.shape + ciphertext.shape[axis+1:]"
            )

        # The old code silently skipped missing elements; report it instead.
        expected_updates = int(np.prod(expected_updated_shape))
        if len(updated.ct_data) != expected_updates:
            raise RuntimeError(
                f"Internal error: updated CipherText holds {len(updated.ct_data)} "
                f"elements, expected {expected_updates}"
            )

        # Flat row-major offsets:
        #   destination = (outer * axis_size + indices[i]) * inner + r
        #   source      = (outer * n_indices + i) * inner + r
        outer = int(np.prod(ct_shape[:norm_axis]))
        inner = int(np.prod(ct_shape[norm_axis + 1 :]))
        flat_indices = [int(i) for i in indices_np.reshape(-1)]
        n_indices = len(flat_indices)

        # Work on a copy so the input CipherText is left untouched.
        scattered_ct_data = list(ciphertext.ct_data)
        updated_ct_data = updated.ct_data
        for outer_idx in range(outer):
            dst_base = outer_idx * axis_size * inner
            for i, idx in enumerate(flat_indices):
                src = (outer_idx * n_indices + i) * inner
                dst = dst_base + idx * inner
                scattered_ct_data[dst : dst + inner] = updated_ct_data[src : src + inner]

        return [
            CipherText(
                ct_data=scattered_ct_data,
                semantic_dtype=ciphertext.semantic_dtype,
                semantic_shape=ciphertext.semantic_shape,
                scheme=ciphertext.scheme,
                key_size=ciphertext.key_size,
                pk_data=ciphertext.pk_data,
                max_value=ciphertext.max_value,
                fxp_bits=ciphertext.fxp_bits,
                modulus=ciphertext.modulus,
            )
        ]
    except ValueError:
        # Re-raise ValueError directly (validation errors)
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform scatter: {e}") from e
|
1362
|
+
|
1363
|
+
|
1364
|
+
@kernel_def("phe.concat")
def _phe_concat(pfunc: PFunction, c1: CipherText, c2: CipherText) -> Any:
    """Concatenate two CipherTexts along ``pfunc.attrs["axis"]`` (default 0).

    Both operands must share scheme, key size, public key, semantic dtype and
    rank, and must agree on every dimension except the concatenation axis.
    Negative axes count from the end. The result's encoding parameters are
    taken from ``c1``.

    Returns:
        A one-element list holding the concatenated CipherText.

    Raises:
        ValueError: On any of the compatibility violations above, scalar
            inputs, or an out-of-range axis.
        RuntimeError: On any other failure while assembling the result.
    """
    requested_axis = pfunc.attrs.get("axis", 0)

    both_ciphertexts = isinstance(c1, CipherText) and isinstance(c2, CipherText)
    if not both_ciphertexts:
        raise ValueError("All arguments must be CipherText instances")

    same_params = c1.scheme == c2.scheme and c1.key_size == c2.key_size
    if not same_params:
        raise ValueError("All CipherTexts must use same scheme and key size")
    if c1.pk_data != c2.pk_data:
        raise ValueError("All CipherTexts must be encrypted with same key")
    if c1.semantic_dtype != c2.semantic_dtype:
        raise ValueError(
            f"All CipherTexts must have same semantic dtype, got {c1.semantic_dtype} vs {c2.semantic_dtype}"
        )

    shape_a, shape_b = c1.semantic_shape, c2.semantic_shape
    rank = len(shape_a)
    if rank != len(shape_b):
        raise ValueError(
            f"All CipherTexts must have same number of dimensions for concat, got {len(shape_a)} vs {len(shape_b)}"
        )
    if rank == 0:
        raise ValueError("Cannot concatenate scalar CipherTexts")

    # Map a negative axis onto [0, rank) before range-checking it.
    cat_axis = requested_axis + rank if requested_axis < 0 else requested_axis
    if not 0 <= cat_axis < rank:
        raise ValueError(
            f"axis {requested_axis} is out of bounds for array of dimension {rank}"
        )

    # Report the first off-axis dimension that disagrees, if any.
    bad_dim = next(
        (d for d in range(rank) if d != cat_axis and shape_a[d] != shape_b[d]),
        None,
    )
    if bad_dim is not None:
        raise ValueError(
            f"All CipherTexts must have same shape except along concatenation axis {cat_axis}. "
            f"Shape mismatch at dimension {bad_dim}: {shape_a[bad_dim]} vs {shape_b[bad_dim]}"
        )

    try:
        out_dims = list(shape_a)
        out_dims[cat_axis] = shape_a[cat_axis] + shape_b[cat_axis]
        result_shape = tuple(out_dims)

        # Row-major layout: the data splits into ``n_chunks`` leading chunks;
        # each chunk holds ``width_a`` elements of c1 followed by ``width_b``
        # elements of c2.
        n_chunks = 1 if cat_axis == 0 else int(np.prod(shape_a[:cat_axis]))
        width_a = int(np.prod(shape_a[cat_axis:]))
        width_b = int(np.prod(shape_b[cat_axis:]))

        merged = [
            element
            for chunk in range(n_chunks)
            for source, width in ((c1.ct_data, width_a), (c2.ct_data, width_b))
            for element in source[chunk * width : chunk * width + width]
        ]

        expected_size = int(np.prod(result_shape))
        if len(merged) != expected_size:
            raise RuntimeError(
                f"Internal error: Expected {expected_size} elements, got {len(merged)}"
            )

        combined = CipherText(
            ct_data=merged,
            semantic_dtype=c1.semantic_dtype,
            semantic_shape=result_shape,
            scheme=c1.scheme,
            key_size=c1.key_size,
            pk_data=c1.pk_data,
            max_value=c1.max_value,
            fxp_bits=c1.fxp_bits,
            modulus=c1.modulus,
        )
        return [combined]

    except ValueError:
        # Validation errors propagate unchanged.
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform concat: {e}") from e
|
1471
|
+
|
1472
|
+
|
1473
|
+
@kernel_def("phe.reshape")
def _phe_reshape(pfunc: PFunction, ciphertext: CipherText) -> Any:
    """Reshape a CipherText without touching its encrypted payload.

    Only the semantic shape changes. The flat ``ct_data`` list is passed
    through to the result unchanged (the same list object), together with the
    dtype, scheme, key and encoding parameters of the input.

    Args:
        pfunc: PFunction whose ``attrs["new_shape"]`` (tuple or list) gives the
            target shape. At most one entry may be -1; that dimension is then
            inferred from the total element count.
        ciphertext: The CipherText to reshape.

    Returns:
        A single-element list holding the reshaped CipherText.

    Raises:
        ValueError: If the argument is not a CipherText, ``new_shape`` is
            missing or not a tuple/list, contains more than one -1, contains a
            non-positive entry other than -1, or does not describe the same
            number of elements as the input.
        RuntimeError: Wrapping any other exception raised while reshaping.
    """
    if not isinstance(ciphertext, CipherText):
        raise ValueError("Argument must be a CipherText instance")

    target = pfunc.attrs.get("new_shape")
    if target is None:
        raise ValueError("new_shape parameter is required for reshape operation")
    if isinstance(target, list):
        target = tuple(target)
    if not isinstance(target, tuple):
        raise ValueError("new_shape must be a tuple or list of integers")

    try:
        # A 0-d (empty-shape) ciphertext holds exactly one element.
        total = (
            int(np.prod(ciphertext.semantic_shape)) if ciphertext.semantic_shape else 1
        )

        unknown_positions = [pos for pos, d in enumerate(target) if d == -1]
        if len(unknown_positions) > 1:
            raise ValueError("can only specify one unknown dimension")

        resolved = list(target)
        if unknown_positions:
            # Exactly one -1: infer it from the product of the known dims.
            known = 1
            for d in target:
                if d == -1:
                    continue
                if d <= 0:
                    raise ValueError(
                        f"negative dimensions not allowed (except -1): {d}"
                    )
                known *= d

            if total % known != 0:
                raise ValueError(
                    f"cannot reshape array of size {total} into shape {target}"
                )
            resolved[unknown_positions[0]] = total // known
        else:
            # Fully specified shape: every dimension must be strictly positive.
            for d in target:
                if d <= 0:
                    raise ValueError(f"negative dimensions not allowed: {d}")

        final_shape = tuple(resolved)
        final_size = int(np.prod(final_shape)) if final_shape else 1
        if final_size != total:
            raise ValueError(
                f"Cannot reshape CipherText with {total} elements to shape {final_shape} "
                f"with {final_size} elements"
            )

        # Element order is unchanged by a reshape, so ct_data is reused as-is.
        reshaped = CipherText(
            ct_data=ciphertext.ct_data,
            semantic_dtype=ciphertext.semantic_dtype,
            semantic_shape=final_shape,
            scheme=ciphertext.scheme,
            key_size=ciphertext.key_size,
            pk_data=ciphertext.pk_data,
            max_value=ciphertext.max_value,
            fxp_bits=ciphertext.fxp_bits,
            modulus=ciphertext.modulus,
        )
        return [reshaped]

    except ValueError:
        # Validation errors propagate unchanged.
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform reshape: {e}") from e
|
1563
|
+
|
1564
|
+
|
1565
|
+
def _transpose_flat_ct_data(
    ct_data: list, old_shape: tuple, axes: tuple[int, ...]
) -> list:
    """Reorder a flat ciphertext list the way ``np.transpose`` would.

    Args:
        ct_data: Flat list of ciphertexts laid out in NumPy's default (C,
            row-major) order for ``old_shape``.
        old_shape: Semantic shape that ``ct_data`` currently represents.
        axes: Normalized (non-negative, duplicate-free) permutation of
            ``range(len(old_shape))``.

    Returns:
        ``ct_data`` itself for 0-d/1-d shapes (nothing to permute), otherwise
        a new list in the transposed order.
    """
    if len(old_shape) <= 1:
        return ct_data

    # Permute a grid of flat indices rather than the ciphertexts themselves,
    # then gather the ciphertexts in the resulting order.
    index_grid = np.arange(len(ct_data)).reshape(old_shape)
    order = np.transpose(index_grid, axes).ravel().tolist()
    return [ct_data[i] for i in order]


@kernel_def("phe.transpose")
def _phe_transpose(pfunc: PFunction, ciphertext: CipherText) -> Any:
    """Permute the dimensions of a CipherText.

    The flat ``ct_data`` list is rearranged so that it matches the layout
    ``np.transpose`` produces for the permuted shape. Dtype, scheme, key and
    encoding parameters are copied from the input.

    Args:
        pfunc: PFunction whose optional ``attrs["axes"]`` (tuple, list, or
            None) gives the permutation. None reverses all dimensions;
            negative axes count from the end.
        ciphertext: The CipherText to transpose.

    Returns:
        A single-element list holding the transposed CipherText. A 0-d input
        is returned as-is without inspecting ``axes``.

    Raises:
        ValueError: If the argument is not a CipherText, ``axes`` has the wrong
            type or length, an axis lies outside ``[-ndim, ndim)``, or axes
            repeat.
        RuntimeError: Wrapping any other exception raised while transposing.
    """
    if not isinstance(ciphertext, CipherText):
        raise ValueError("Argument must be a CipherText instance")

    # Transposing a scalar returns the same scalar.
    if len(ciphertext.semantic_shape) == 0:
        return [ciphertext]

    axes = pfunc.attrs.get("axes")

    # If axes is None, reverse all dimensions (default transpose behavior).
    if axes is None:
        axes = tuple(reversed(range(len(ciphertext.semantic_shape))))
    elif isinstance(axes, list):
        axes = tuple(axes)
    elif not isinstance(axes, tuple):
        raise ValueError("axes must be a tuple or list of integers, or None")

    try:
        ndim = len(ciphertext.semantic_shape)
        if len(axes) != ndim:
            raise ValueError(
                f"axes length {len(axes)} does not match tensor dimensions {ndim}"
            )

        normalized_axes = []
        for axis in axes:
            # Range-check before normalizing so the error names the axis the
            # caller actually passed (e.g. -5, not the shifted -3).
            if not -ndim <= axis < ndim:
                raise ValueError(
                    f"axis {axis} is out of bounds for array of dimension {ndim}"
                )
            normalized_axes.append(axis % ndim)
        axes = tuple(normalized_axes)

        if len(set(axes)) != len(axes):
            raise ValueError("axes cannot contain duplicate values")

        old_shape = ciphertext.semantic_shape
        new_shape = tuple(old_shape[axis] for axis in axes)
        transposed_ct_data = _transpose_flat_ct_data(
            ciphertext.ct_data, old_shape, axes
        )

        return [
            CipherText(
                ct_data=transposed_ct_data,
                semantic_dtype=ciphertext.semantic_dtype,
                semantic_shape=new_shape,
                scheme=ciphertext.scheme,
                key_size=ciphertext.key_size,
                pk_data=ciphertext.pk_data,
                max_value=ciphertext.max_value,
                fxp_bits=ciphertext.fxp_bits,
                modulus=ciphertext.modulus,
            )
        ]

    except ValueError:
        # Re-raise ValueError directly (validation errors).
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform transpose: {e}") from e
|