mplang-nightly 0.1.dev141__py3-none-any.whl → 0.1.dev143__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mplang/backend/phe.py CHANGED
@@ -25,16 +25,29 @@ from mplang.core.mptype import TensorLike
25
25
  from mplang.core.pfunc import PFunction
26
26
 
# Decimal precision handed to lightPHE for float operations. It is pinned to 0
# so lightPHE only ever handles integers; negative numbers and floats are mapped
# to integers by this module's own range / fixed-point encoding instead.
PRECISION = 0
29
31
 
30
32
 
31
33
  class PublicKey:
32
34
  """PHE Public Key that implements TensorLike protocol."""
33
35
 
34
- def __init__(self, key_data: Any, scheme: str, key_size: int):
36
+ def __init__(
37
+ self,
38
+ key_data: Any,
39
+ scheme: str,
40
+ key_size: int,
41
+ max_value: int = 2**100,
42
+ fxp_bits: int = 12,
43
+ modulus: int | None = None,
44
+ ):
35
45
  self.key_data = key_data
36
46
  self.scheme = scheme
37
47
  self.key_size = key_size
48
+ self.max_value = max_value # Maximum absolute value B for range encoding
49
+ self.fxp_bits = fxp_bits # Fixed-point precision bits for float encoding
50
+ self.modulus = modulus # Paillier modulus N for range encoding
38
51
 
39
52
  @property
40
53
  def dtype(self) -> Any:
@@ -44,18 +57,35 @@ class PublicKey:
44
57
  def shape(self) -> tuple[int, ...]:
45
58
  return ()
46
59
 
60
+ @property
61
+ def max_float_value(self) -> float:
62
+ """Maximum float value that can be encoded."""
63
+ return float(self.max_value / (2**self.fxp_bits))
64
+
47
65
  def __repr__(self) -> str:
48
- return f"PublicKey(scheme={self.scheme}, key_size={self.key_size})"
66
+ return f"PublicKey(scheme={self.scheme}, key_size={self.key_size}, max_value={self.max_value}, fxp_bits={self.fxp_bits})"
49
67
 
50
68
 
51
69
  class PrivateKey:
52
70
  """PHE Private Key that implements TensorLike protocol."""
53
71
 
54
- def __init__(self, sk_data: Any, pk_data: Any, scheme: str, key_size: int):
72
+ def __init__(
73
+ self,
74
+ sk_data: Any,
75
+ pk_data: Any,
76
+ scheme: str,
77
+ key_size: int,
78
+ max_value: int = 2**100,
79
+ fxp_bits: int = 12,
80
+ modulus: int | None = None,
81
+ ):
55
82
  self.sk_data = sk_data # Store private key data
56
83
  self.pk_data = pk_data # Store public key data as well
57
84
  self.scheme = scheme
58
85
  self.key_size = key_size
86
+ self.max_value = max_value # Maximum absolute value B for range encoding
87
+ self.fxp_bits = fxp_bits # Fixed-point precision bits for float encoding
88
+ self.modulus = modulus # Paillier modulus N for range encoding
59
89
 
60
90
  @property
61
91
  def dtype(self) -> Any:
@@ -65,8 +95,13 @@ class PrivateKey:
65
95
  def shape(self) -> tuple[int, ...]:
66
96
  return ()
67
97
 
98
+ @property
99
+ def max_float_value(self) -> float:
100
+ """Maximum float value that can be encoded."""
101
+ return float(self.max_value / (2**self.fxp_bits))
102
+
68
103
  def __repr__(self) -> str:
69
- return f"PrivateKey(scheme={self.scheme}, key_size={self.key_size})"
104
+ return f"PrivateKey(scheme={self.scheme}, key_size={self.key_size}, max_value={self.max_value}, fxp_bits={self.fxp_bits})"
70
105
 
71
106
 
72
107
  class CipherText:
@@ -80,6 +115,9 @@ class CipherText:
80
115
  scheme: str,
81
116
  key_size: int,
82
117
  pk_data: Any = None, # Store public key for operations
118
+ max_value: int = 2**100,
119
+ fxp_bits: int = 12,
120
+ modulus: int | None = None,
83
121
  ):
84
122
  self.ct_data = ct_data
85
123
  self.semantic_dtype = semantic_dtype
@@ -87,6 +125,9 @@ class CipherText:
87
125
  self.scheme = scheme
88
126
  self.key_size = key_size
89
127
  self.pk_data = pk_data
128
+ self.max_value = max_value
129
+ self.fxp_bits = fxp_bits
130
+ self.modulus = modulus
90
131
 
91
132
  @property
92
133
  def dtype(self) -> Any:
@@ -96,102 +137,367 @@ class CipherText:
96
137
  def shape(self) -> tuple[int, ...]:
97
138
  return self.semantic_shape
98
139
 
140
+ @property
141
+ def max_float_value(self) -> float:
142
+ """Maximum float value that can be encoded."""
143
+ return float(self.max_value / (2**self.fxp_bits))
144
+
99
145
  def __repr__(self) -> str:
100
146
  return f"CipherText(dtype={self.semantic_dtype}, shape={self.semantic_shape}, scheme={self.scheme})"
101
147
 
102
148
 
103
- def _to_numpy(obj: TensorLike) -> np.ndarray:
149
+ # Range-based encoding functions for negative numbers and floats
150
+ def _range_encode_integer(value: int, max_value: int, modulus: int) -> int:
151
+ """
152
+ Range encoding function for integers.
153
+ - Positive numbers: encode(m) = m
154
+ - Negative numbers: encode(m) = N + m
155
+ """
156
+ if not (-max_value <= value <= max_value):
157
+ raise ValueError(
158
+ f"Integer value {value} out of range [-{max_value}, {max_value}]"
159
+ )
160
+
161
+ if value >= 0:
162
+ encoded = value
163
+ else:
164
+ encoded = modulus + value
165
+
166
+ return encoded
167
+
168
+
169
+ def _range_encode_float(
170
+ value: float, max_value: int, fxp_bits: int, modulus: int
171
+ ) -> int:
172
+ """
173
+ Range encoding function for floats.
174
+ 1. Fixed-point conversion: scaled_int = round(value * 2^fxp_bits)
175
+ 2. Integer encoding rules
176
+ """
177
+ max_float = max_value / (2**fxp_bits)
178
+ if not (-max_float <= value <= max_float):
179
+ raise ValueError(
180
+ f"Float value {value} out of range [-{max_float}, {max_float}]"
181
+ )
182
+
183
+ # Fixed-point encoding: float → scaled integer
184
+ scaled_int = round(value * (2**fxp_bits))
185
+
186
+ # Use integer encoding rules
187
+ return _range_encode_integer(scaled_int, max_value, modulus)
def _range_encode_mixed(
    value: Any, max_value: int, fxp_bits: int, modulus: int, semantic_dtype: DType
) -> int:
    """Encode one element, choosing the scheme from its semantic dtype.

    Non-floating dtypes are encoded as ``int(value)`` with the integer rules;
    floating dtypes are encoded as ``float(value)`` with fixed-point scaling.
    """
    if not semantic_dtype.is_floating:
        return _range_encode_integer(int(value), max_value, modulus)
    return _range_encode_float(float(value), max_value, fxp_bits, modulus)
203
+
204
+
205
+ def _range_decode_integer(encoded_value: int, max_value: int, modulus: int) -> int:
206
+ """
207
+ Range decoding function for integers.
208
+ - If r <= max_value: decode(r) = r
209
+ - If r >= N - max_value: decode(r) = r - N
210
+ - If max_value < r < N - max_value: overflow error
211
+ """
212
+
213
+ # Ensure handling integer
214
+ if isinstance(encoded_value, (list, tuple)):
215
+ encoded_value = encoded_value[0]
216
+ encoded_value = int(encoded_value) % modulus
217
+
218
+ if encoded_value <= max_value:
219
+ return encoded_value
220
+ elif encoded_value >= modulus - max_value:
221
+ return encoded_value - modulus
222
+ else:
223
+ raise ValueError(f"Decoded value {encoded_value} is in overflow region")
def _range_decode_float(
    encoded_value: int, max_value: int, fxp_bits: int, modulus: int
) -> float:
    """Decode a fixed-point residue back into a float.

    The residue is first decoded to a signed integer with
    ``_range_decode_integer`` and then divided by ``2**fxp_bits``.
    """
    signed = _range_decode_integer(encoded_value, max_value, modulus)
    scale = 2**fxp_bits
    return float(signed / scale)
def _range_decode_mixed(
    encoded_value: int,
    max_value: int,
    fxp_bits: int,
    modulus: int,
    semantic_dtype: DType,
) -> Any:
    """Decode one residue, choosing the scheme from its semantic dtype.

    Floating dtypes are decoded with fixed-point scaling and return a float;
    all other dtypes are decoded as signed integers and return an int.
    """
    if not semantic_dtype.is_floating:
        return _range_decode_integer(encoded_value, max_value, modulus)
    return _range_decode_float(encoded_value, max_value, fxp_bits, modulus)
def _convert_to_numpy(obj: TensorLike) -> np.ndarray:
    """Return ``obj`` as a NumPy array.

    An ndarray is returned as-is. Otherwise, if ``obj`` has a callable
    ``numpy`` attribute, its result is passed through ``np.asarray``; if that
    call raises, the code falls back to ``np.asarray(obj)``.
    """
    if isinstance(obj, np.ndarray):
        return obj

    to_numpy = getattr(obj, "numpy", None)
    if callable(to_numpy):
        try:
            return np.asarray(to_numpy())
        except Exception:
            pass  # best effort: fall through to the generic conversion

    return np.asarray(obj)
@kernel_def("phe.keygen")
def _phe_keygen(pfunc: PFunction) -> Any:
    """Generate a PHE key pair configured for this module's range encoding.

    Attributes read from ``pfunc.attrs`` (defaults in parentheses):
        scheme ("paillier"): PHE scheme name; only Paillier is supported
            (case-insensitive).
        key_size (2048): key size in bits passed to lightPHE.
        max_value (2**32): bound B on the absolute value of encoded integers;
            floats are limited to ``B / 2**fxp_bits``.
        fxp_bits (12): fractional bits used by the fixed-point float encoding.

    Returns:
        ``[PublicKey, PrivateKey]``, both carrying max_value, fxp_bits and the
        Paillier modulus N taken from lightPHE.

    Raises:
        ValueError: unsupported scheme, non-positive ``max_value`` or negative
            ``fxp_bits``.
        RuntimeError: key generation failed, including the case where the
            generated modulus N is not larger than ``3 * max_value``.
    """
    scheme = pfunc.attrs.get("scheme", "paillier")
    # Default is 2048 bits; production use should keep >= 2048 (3072 for a
    # larger margin). Tests may pass a smaller key_size explicitly for speed.
    key_size = pfunc.attrs.get("key_size", 2048)
    max_value = pfunc.attrs.get("max_value", 2**32)
    fxp_bits = pfunc.attrs.get("fxp_bits", 12)

    if scheme.lower() not in ("paillier",):
        raise ValueError(f"Unsupported PHE scheme: {scheme}")
    # Reject parameters for which the signed / fixed-point encoding is meaningless.
    if max_value <= 0:
        raise ValueError(f"max_value must be positive, got {max_value}")
    if fxp_bits < 0:
        raise ValueError(f"fxp_bits must be non-negative, got {fxp_bits}")

    algorithm_name = scheme.capitalize()

    try:
        # PRECISION is 0: lightPHE only ever sees integers, because signed and
        # float values are handled by this module's own encoding.
        phe = LightPHE(
            algorithm_name=algorithm_name,
            key_size=key_size,
            precision=PRECISION,
        )

        pk_data = phe.cs.keys["public_key"]
        sk_data = phe.cs.keys["private_key"]
        modulus = phe.cs.plaintext_modulo  # Paillier modulus N

        # Non-negative values occupy [0, B] and negatives [N - B, N); N must
        # leave a wide gap between them so overflow is detectable on decode.
        if modulus <= 3 * max_value:
            raise ValueError(
                f"Modulus {modulus} is too small for max_value {max_value}. Require N >> 3*B"
            )

        public_key = PublicKey(
            key_data=pk_data,
            scheme=algorithm_name,
            key_size=key_size,
            max_value=max_value,
            fxp_bits=fxp_bits,
            modulus=modulus,
        )
        private_key = PrivateKey(
            sk_data=sk_data,
            pk_data=pk_data,
            scheme=algorithm_name,
            key_size=key_size,
            max_value=max_value,
            fxp_bits=fxp_bits,
            modulus=modulus,
        )
        return [public_key, private_key]

    except Exception as e:
        raise RuntimeError(f"Failed to generate PHE keys: {e}") from e
@kernel_def("phe.encrypt")
def _phe_encrypt(pfunc: PFunction, plaintext: Any, public_key: PublicKey) -> Any:
    """Encrypt a tensor element-wise under a PHE public key.

    Each element is range-encoded with the key's parameters: non-negative
    integers as-is, negatives as ``N + m``, and floats first scaled to
    ``round(x * 2**fxp_bits)``. Each encoded value is then encrypted on its own
    with lightPHE.

    Args:
        pfunc: kernel descriptor (its attributes are not used here).
        plaintext: tensor-like value; its dtype and shape become the
            CipherText's semantic dtype and shape.
        public_key: key produced by ``phe.keygen``.

    Returns:
        ``[CipherText]`` holding one lightPHE ciphertext per element (in
        ``flatten()`` order) plus the key's encoding parameters.

    Raises:
        ValueError: ``public_key`` is not a PublicKey.
        RuntimeError: anything else fails, e.g. an element outside the
            encodable range or a key whose modulus is None.
    """
    if not isinstance(public_key, PublicKey):
        raise ValueError("Second argument must be a PublicKey instance")

    try:
        modulus = public_key.modulus
        # Checked once, up front: this also covers empty inputs and avoids
        # building a LightPHE instance for a key that cannot be used.
        if modulus is None:
            raise ValueError(
                "Public key modulus is None, key generation may have failed"
            )

        plaintext_np = _convert_to_numpy(plaintext)
        semantic_dtype = DType.from_numpy(plaintext_np.dtype)
        semantic_shape = plaintext_np.shape

        # Encode everything before encrypting anything so that an
        # out-of-range element fails fast.
        encoded_values = [
            _range_encode_mixed(
                val,
                public_key.max_value,
                public_key.fxp_bits,
                modulus,
                semantic_dtype,
            )
            for val in plaintext_np.flatten()
        ]

        phe = LightPHE(
            algorithm_name=public_key.scheme,
            key_size=public_key.key_size,
            precision=PRECISION,
        )
        # Pin the instance to the key's modulus N (ciphertext space N^2) and
        # public key so encryption is consistent with the key pair.
        phe.cs.plaintext_modulo = modulus
        phe.cs.ciphertext_modulo = modulus * modulus
        phe.cs.keys["public_key"] = public_key.key_data

        # One lightPHE ciphertext per element (scalars, not lists).
        ct_data = [phe.encrypt(val) for val in encoded_values]

        ciphertext = CipherText(
            ct_data=ct_data,
            semantic_dtype=semantic_dtype,
            semantic_shape=semantic_shape,
            scheme=public_key.scheme,
            key_size=public_key.key_size,
            pk_data=public_key.key_data,
            max_value=public_key.max_value,
            fxp_bits=public_key.fxp_bits,
            modulus=modulus,
        )
        return [ciphertext]

    except Exception as e:
        raise RuntimeError(f"Failed to encrypt data: {e}") from e
@kernel_def("phe.mul")
def _phe_mul(pfunc: PFunction, ciphertext: CipherText, plaintext: Any) -> Any:
    """Multiply a CipherText element-wise by an integer plaintext.

    Operands are broadcast with NumPy semantics. Plaintext multipliers are used
    raw (not range-encoded): scaling the encoding of ``m`` by an integer ``k``
    yields the encoding of ``m * k``. This holds for integer and fixed-point
    float ciphertexts alike, as long as the product stays within ``max_value``.

    Args:
        pfunc: kernel descriptor (its attributes are not used here).
        ciphertext: encrypted left operand.
        plaintext: tensor-like multiplier with a non-floating dtype.

    Returns:
        ``[CipherText]`` with the broadcast shape and the ciphertext's dtype
        and encoding parameters.

    Raises:
        ValueError: wrong argument type, floating-point plaintext, or shapes
            that cannot be broadcast together.
        RuntimeError: any other failure during the multiplication.
    """
    if not isinstance(ciphertext, CipherText):
        raise ValueError("First argument must be a CipherText instance")

    try:
        plaintext_np = _convert_to_numpy(plaintext)

        # A float multiplier would require rescaling the fixed-point encoding,
        # so it is rejected outright.
        if np.issubdtype(plaintext_np.dtype, np.floating):
            raise ValueError(
                "Homomorphic multiplication with floating point plaintext is not supported. "
                f"Got plaintext dtype: {plaintext_np.dtype}"
            )

        ct_shape = ciphertext.semantic_shape
        # Compute the broadcast shape without allocating dummy arrays.
        try:
            result_shape = np.broadcast_shapes(ct_shape, plaintext_np.shape)
        except ValueError as e:
            raise ValueError(
                f"Operands cannot be broadcast together: CipherText shape {ct_shape} "
                f"vs plaintext shape {plaintext_np.shape}: {e}"
            ) from e

        # Replicate ciphertext elements by broadcasting a flat index grid.
        ct_data: list[Any] = ciphertext.ct_data
        if ct_shape != result_shape:
            index_grid = np.arange(int(np.prod(ct_shape)), dtype=np.int64).reshape(
                ct_shape
            )
            ct_data = [
                ct_data[int(idx)]
                for idx in np.broadcast_to(index_grid, result_shape).ravel()
            ]

        # Floats were rejected above, so every multiplier is integral. Use
        # int() even for float-semantic ciphertexts: float() would lose
        # precision above 2**53 and hand lightPHE (PRECISION = 0) a float.
        multipliers = [
            int(val) for val in np.broadcast_to(plaintext_np, result_shape).ravel()
        ]

        # Paillier supports ciphertext * plaintext-scalar homomorphically.
        result_ciphertext = [ct * k for ct, k in zip(ct_data, multipliers)]

        return [
            CipherText(
                ct_data=result_ciphertext,
                semantic_dtype=ciphertext.semantic_dtype,
                semantic_shape=result_shape,
                scheme=ciphertext.scheme,
                key_size=ciphertext.key_size,
                pk_data=ciphertext.pk_data,
                max_value=ciphertext.max_value,
                fxp_bits=ciphertext.fxp_bits,
                modulus=ciphertext.modulus,
            )
        ]

    except ValueError:
        # Validation errors propagate unchanged.
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform multiplication: {e}") from e
196
502
 
197
503
 
@@ -205,7 +511,7 @@ def _phe_add(pfunc: PFunction, lhs: Any, rhs: Any) -> Any:
205
511
  elif isinstance(rhs, CipherText):
206
512
  return _phe_add_ct2pt(rhs, lhs)
207
513
  else:
208
- return _to_numpy(lhs) + _to_numpy(rhs)
514
+ return _convert_to_numpy(lhs) + _convert_to_numpy(rhs)
209
515
  except ValueError:
210
516
  raise
211
517
  except Exception as e: # pragma: no cover
@@ -213,89 +519,1140 @@ def _phe_add(pfunc: PFunction, lhs: Any, rhs: Any) -> Any:
def _phe_add_ct2ct(ct1: CipherText, ct2: CipherText) -> CipherText:
    """Add two CipherTexts element-wise with NumPy-style broadcasting.

    Both operands must come from the same key and share the same numeric kind
    (both floating or both integer). The result takes ``ct1``'s dtype and
    encoding parameters.

    Raises:
        ValueError: scheme / key size / key mismatch, mixed numeric kinds, or
            shapes that cannot be broadcast together.
    """
    if ct1.scheme != ct2.scheme or ct1.key_size != ct2.key_size:
        raise ValueError("CipherText operands must use same scheme and key size")
    if ct1.pk_data != ct2.pk_data:
        raise ValueError("CipherText operands must be encrypted with same key")

    # Float ciphertexts hold round(x * 2**fxp_bits) while integer ones hold x,
    # so their sum could not be decoded with either scale.
    if ct1.semantic_dtype.is_floating != ct2.semantic_dtype.is_floating:
        raise ValueError(
            "Cannot add ciphertexts with different numeric types due to fixed-point encoding. "
            f"First CipherText dtype: {ct1.semantic_dtype}, second CipherText dtype: {ct2.semantic_dtype}. "
            "Both operands must have the same numeric type (both floating or both integer)."
        )

    # Compute the broadcast shape without allocating dummy arrays.
    try:
        result_shape = np.broadcast_shapes(ct1.semantic_shape, ct2.semantic_shape)
    except ValueError as e:
        raise ValueError(
            f"CipherText operands cannot be broadcast together: shape {ct1.semantic_shape} "
            f"vs shape {ct2.semantic_shape}: {e}"
        ) from e

    def _expand(ct: CipherText) -> list[Any]:
        """Return ct's elements replicated to ``result_shape`` in C order."""
        if ct.semantic_shape == result_shape:
            return ct.ct_data
        index_grid = np.arange(
            int(np.prod(ct.semantic_shape)), dtype=np.int64
        ).reshape(ct.semantic_shape)
        raw: list[Any] = ct.ct_data
        return [raw[int(idx)] for idx in np.broadcast_to(index_grid, result_shape).ravel()]

    # Homomorphic addition of the aligned encrypted elements.
    result_ciphertext = [a + b for a, b in zip(_expand(ct1), _expand(ct2))]

    return CipherText(
        ct_data=result_ciphertext,
        semantic_dtype=ct1.semantic_dtype,
        semantic_shape=result_shape,
        scheme=ct1.scheme,
        key_size=ct1.key_size,
        pk_data=ct1.pk_data,
        max_value=ct1.max_value,
        fxp_bits=ct1.fxp_bits,
        modulus=ct1.modulus,
    )
def _phe_add_ct2pt(ciphertext: CipherText, plaintext: TensorLike) -> CipherText:
    """Add a plaintext tensor to a CipherText element-wise (with broadcasting).

    The plaintext is range-encoded with the ciphertext's parameters, encrypted
    under the ciphertext's public key and modulus, and added homomorphically.
    Both operands must have the same numeric kind (both floating or both
    integer).

    Raises:
        ValueError: mixed numeric kinds, shapes that cannot be broadcast, or a
            CipherText lacking public key data or a modulus.
    """
    plaintext_np = _convert_to_numpy(plaintext)
    plaintext_dtype = DType.from_numpy(plaintext_np.dtype)
    ct_is_float = ciphertext.semantic_dtype.is_floating

    # Float ciphertexts hold round(x * 2**fxp_bits) while integers hold x, so
    # the two scales cannot be mixed in one sum.
    if ct_is_float and not plaintext_dtype.is_floating:
        raise ValueError(
            "Cannot add integer plaintext to floating point ciphertext due to fixed-point encoding. "
            f"CipherText dtype: {ciphertext.semantic_dtype}, plaintext dtype: {plaintext_dtype}. "
            "Both operands must have the same numeric type (both floating or both integer)."
        )
    if not ct_is_float and plaintext_dtype.is_floating:
        raise ValueError(
            "Cannot add floating point plaintext to integer ciphertext due to fixed-point encoding. "
            f"CipherText dtype: {ciphertext.semantic_dtype}, plaintext dtype: {plaintext_dtype}. "
            "Both operands must have the same numeric type (both floating or both integer)."
        )

    ct_shape = ciphertext.semantic_shape
    # Compute the broadcast shape without allocating dummy arrays.
    try:
        result_shape = np.broadcast_shapes(ct_shape, plaintext_np.shape)
    except ValueError as e:
        raise ValueError(
            f"Operands cannot be broadcast together: CipherText shape {ct_shape} "
            f"vs plaintext shape {plaintext_np.shape}: {e}"
        ) from e

    if ciphertext.pk_data is None:
        raise ValueError(
            "CipherText must contain public key data for plaintext addition"
        )
    modulus = ciphertext.modulus
    # Checked once, up front (this also covers empty plaintexts).
    if modulus is None:
        raise ValueError("Ciphertext modulus is None, encryption may have failed")

    # Replicate ciphertext elements by broadcasting a flat index grid.
    ct_data: list[Any] = ciphertext.ct_data
    if ct_shape != result_shape:
        index_grid = np.arange(int(np.prod(ct_shape)), dtype=np.int64).reshape(ct_shape)
        ct_data = [
            ct_data[int(idx)] for idx in np.broadcast_to(index_grid, result_shape).ravel()
        ]

    # Encode with the same rules and parameters as phe.encrypt, before any
    # encryption work, so that out-of-range values fail fast.
    encoded_values = [
        _range_encode_mixed(
            val,
            ciphertext.max_value,
            ciphertext.fxp_bits,
            modulus,
            ciphertext.semantic_dtype,
        )
        for val in np.broadcast_to(plaintext_np, result_shape).ravel()
    ]

    phe = LightPHE(
        algorithm_name=ciphertext.scheme,
        key_size=ciphertext.key_size,
        precision=PRECISION,
    )
    # Mirror _phe_encrypt: pin the instance to the ciphertext's modulus N
    # (ciphertext space N^2) and public key. Otherwise the plaintext would be
    # processed against this instance's own, unrelated modulus.
    phe.cs.plaintext_modulo = modulus
    phe.cs.ciphertext_modulo = modulus * modulus
    phe.cs.keys["public_key"] = ciphertext.pk_data

    encrypted_plaintext = [phe.encrypt(val) for val in encoded_values]

    # Homomorphic addition: Enc(pt) + Enc(ct).
    result_ciphertext = [
        pt_ct + ct for pt_ct, ct in zip(encrypted_plaintext, ct_data)
    ]

    return CipherText(
        ct_data=result_ciphertext,
        semantic_dtype=ciphertext.semantic_dtype,
        semantic_shape=result_shape,
        scheme=ciphertext.scheme,
        key_size=ciphertext.key_size,
        pk_data=ciphertext.pk_data,
        max_value=ciphertext.max_value,
        fxp_bits=ciphertext.fxp_bits,
        modulus=ciphertext.modulus,
    )
def _create_encrypted_zero(ciphertext: CipherText) -> Any:
    """Encrypt a fresh zero under ``ciphertext``'s key and encoding parameters.

    Args:
        ciphertext: template supplying scheme, key size, public key, modulus,
            max_value, fxp_bits and semantic dtype.

    Returns:
        A single lightPHE ciphertext (not wrapped in a CipherText) whose
        encoded plaintext is zero.

    Raises:
        ValueError: if ``ciphertext.modulus`` is None, or if zero is outside
            the encodable range (only possible for a negative ``max_value``).
    """
    modulus = ciphertext.modulus
    # Validate first so a bad CipherText fails before any LightPHE setup.
    if modulus is None:
        raise ValueError("Ciphertext modulus is None, encryption may have failed")

    phe = LightPHE(
        algorithm_name=ciphertext.scheme,
        key_size=ciphertext.key_size,
        precision=PRECISION,
    )
    # Pin the instance to the original ciphertext's modulus N (space N^2) and
    # public key, as phe.encrypt does.
    phe.cs.plaintext_modulo = modulus
    phe.cs.ciphertext_modulo = modulus * modulus
    phe.cs.keys["public_key"] = ciphertext.pk_data

    # Encode zero through the regular path for consistency with phe.encrypt.
    zero_encoded = _range_encode_mixed(
        0,
        ciphertext.max_value,
        ciphertext.fxp_bits,
        modulus,
        ciphertext.semantic_dtype,
    )
    return phe.encrypt(zero_encoded)
@kernel_def("phe.decrypt")
def _phe_decrypt(
    pfunc: PFunction, ciphertext: CipherText, private_key: PrivateKey
) -> Any:
    """Decrypt a CipherText and undo the range / fixed-point encoding.

    Args:
        pfunc: kernel descriptor (its attributes are not used here).
        ciphertext: value produced by the phe encrypt / add / mul kernels.
        private_key: private key of the pair the ciphertext was encrypted under.

    Returns:
        ``[np.ndarray]`` with the ciphertext's semantic dtype and shape. For
        integer dtypes, decoded values outside the dtype's range are clamped
        to its min / max.

    Raises:
        ValueError: wrong argument types, or a scheme / key size / public key
            mismatch between the ciphertext and the private key.
        RuntimeError: decryption or decoding failed, e.g. a value landed in the
            overflow region or the ciphertext has no modulus.
    """
    if not isinstance(ciphertext, CipherText):
        raise ValueError("First argument must be a CipherText instance")
    if not isinstance(private_key, PrivateKey):
        raise ValueError("Second argument must be a PrivateKey instance")

    if (
        ciphertext.scheme != private_key.scheme
        or ciphertext.key_size != private_key.key_size
    ):
        raise ValueError("CipherText and PrivateKey must use same scheme and key size")
    # Same check as ciphertext + ciphertext addition: a ciphertext from another
    # key pair would otherwise decrypt to an unrelated value.
    if ciphertext.pk_data is not None and ciphertext.pk_data != private_key.pk_data:
        raise ValueError("CipherText was not encrypted with this PrivateKey's public key")

    try:
        modulus = ciphertext.modulus
        # Checked before any (per-element) decryption work is done.
        if modulus is None:
            raise ValueError("Ciphertext modulus is None, encryption may have failed")

        phe = LightPHE(
            algorithm_name=private_key.scheme,
            key_size=private_key.key_size,
            precision=PRECISION,
        )
        # Pin the instance to the modulus used at encryption time: N for
        # plaintexts and N^2 for Paillier ciphertexts.
        phe.cs.plaintext_modulo = modulus
        phe.cs.ciphertext_modulo = modulus * modulus
        # Both keys are installed for decryption.
        phe.cs.keys["private_key"] = private_key.sk_data
        phe.cs.keys["public_key"] = private_key.pk_data

        semantic_dtype = ciphertext.semantic_dtype
        target_dtype = semantic_dtype.to_numpy()

        decoded_data = []
        for ct in ciphertext.ct_data:
            decrypted = phe.decrypt(ct)
            # Accept either a bare number or a non-empty sequence from lightPHE.
            if isinstance(decrypted, (int, float)):
                raw_val = decrypted
            elif hasattr(decrypted, "__getitem__") and len(decrypted) > 0:
                raw_val = decrypted[0]
            else:
                raise ValueError(f"Cannot extract numeric value from {decrypted}")

            # Returns an int for integer dtypes and a float for floating ones.
            decoded_data.append(
                _range_decode_mixed(
                    int(raw_val),
                    ciphertext.max_value,
                    ciphertext.fxp_bits,
                    modulus,
                    semantic_dtype,
                )
            )

        if target_dtype.kind in "iu":
            # Saturate values that do not fit the (possibly narrow) int dtype.
            info = np.iinfo(target_dtype)
            processed_data = [max(info.min, min(info.max, val)) for val in decoded_data]
        else:
            processed_data = decoded_data

        plaintext_np = np.array(processed_data, dtype=target_dtype).reshape(
            ciphertext.semantic_shape
        )
        return [plaintext_np]

    except Exception as e:
        raise RuntimeError(f"Failed to decrypt data: {e}") from e
823
+
824
+
825
def _dot_linear_combination(ct_items: Any, coeffs: Any, ciphertext: CipherText) -> Any:
    """Homomorphically compute ``sum(c * k for c, k in zip(ct_items, coeffs))``.

    Zero coefficients are skipped rather than multiplied. If no coefficient is
    non-zero (including the empty case), the value returned by
    ``_create_encrypted_zero(ciphertext)`` is used instead.
    """
    total = None
    for ct_item, coeff in zip(ct_items, coeffs):
        # Coefficients come from an integer/bool plaintext, so int() is exact.
        k = int(coeff)
        if k == 0:
            continue
        term = ct_item * k
        total = term if total is None else total + term
    return _create_encrypted_zero(ciphertext) if total is None else total


@kernel_def("phe.dot")
def _phe_dot(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorLike) -> Any:
    """Execute a homomorphic dot product ``ciphertext . plaintext``.

    Shapes follow ``numpy.dot`` semantics:

    - If either operand is a scalar (0-d), the product is element-wise and the
      result takes the other operand's shape.
    - Otherwise the last axis of ``ciphertext`` is contracted with the
      second-to-last axis of ``plaintext`` (its only axis when it is 1-D), giving
      shape ``ct.shape[:-1] + pt.shape[:-2] + pt.shape[-1:]``.

    Zero plaintext coefficients are skipped. An output element whose
    coefficients are all zero is produced by ``_create_encrypted_zero``.

    Args:
        pfunc: The PFunction being executed (not read by this kernel).
        ciphertext: Encrypted left operand.
        plaintext: Non-floating (integer/bool) plaintext right operand.

    Returns:
        A single-element list holding the result CipherText. It keeps the input's
        semantic dtype, scheme, key and encoding parameters.

    Raises:
        ValueError: If an argument has the wrong type, the plaintext dtype is
            floating point, or the shapes are incompatible for ``numpy.dot``.
        RuntimeError: If any other error occurs during the computation.
    """
    # Validate argument kinds before touching any data.
    if not isinstance(ciphertext, CipherText):
        raise ValueError("First argument must be a CipherText instance")
    if isinstance(plaintext, CipherText):
        raise ValueError("Second argument must be a plaintext TensorLike")

    try:
        plaintext_np = _convert_to_numpy(plaintext)

        if np.issubdtype(plaintext_np.dtype, np.floating):
            raise ValueError(
                f"Homomorphic dot product with floating point plaintext is not supported. "
                f"Got plaintext dtype: {plaintext_np.dtype}"
            )

        # Let numpy validate shape compatibility and compute the output shape
        # on placeholder arrays of the same shapes.
        try:
            dummy_ct = np.zeros(ciphertext.semantic_shape)
            dummy_pt = np.zeros(plaintext_np.shape)
            result_shape = np.dot(dummy_ct, dummy_pt).shape
        except ValueError as e:
            raise ValueError(
                f"Shapes are not compatible for dot product: CipherText shape {ciphertext.semantic_shape} "
                f"vs plaintext shape {plaintext_np.shape}: {e}"
            ) from e

        ct_data = ciphertext.ct_data
        ct_shape = tuple(ciphertext.semantic_shape)
        pt_shape = plaintext_np.shape

        if not ct_shape or not pt_shape:
            # numpy.dot with a 0-d operand is plain element-wise scaling.
            pt_flat = plaintext_np.reshape(-1).tolist()
            if not ct_shape:
                result_ct_data = [
                    _dot_linear_combination([ct_data[0]], [v], ciphertext)
                    for v in pt_flat
                ]
            else:
                result_ct_data = [
                    _dot_linear_combination([c], pt_flat, ciphertext)
                    for c in ct_data
                ]
        else:
            # Contract ct's last axis (length k_dim) with pt's second-to-last
            # axis (or its only axis when pt is 1-D).
            k_dim = ct_shape[-1]
            m_dim = int(np.prod(ct_shape[:-1]))
            contract_axis = 0 if len(pt_shape) == 1 else len(pt_shape) - 2
            p_dim = int(
                np.prod(pt_shape[:contract_axis] + pt_shape[contract_axis + 1 :])
            )
            # Move the contracted axis last and flatten the remaining axes in C
            # order. Row p then holds the k_dim coefficients of output column p,
            # in the same order numpy.dot lays out its result.
            pt_rows = (
                np.moveaxis(plaintext_np, contract_axis, -1)
                .reshape(p_dim, k_dim)
                .tolist()
            )
            result_ct_data = [
                _dot_linear_combination(
                    ct_data[m * k_dim : (m + 1) * k_dim], coeffs, ciphertext
                )
                for m in range(m_dim)
                for coeffs in pt_rows
            ]

        # Create result CipherText with computed shape and encoding parameters
        return [
            CipherText(
                ct_data=result_ct_data,
                semantic_dtype=ciphertext.semantic_dtype,
                semantic_shape=result_shape,
                scheme=ciphertext.scheme,
                key_size=ciphertext.key_size,
                pk_data=ciphertext.pk_data,
                max_value=ciphertext.max_value,
                fxp_bits=ciphertext.fxp_bits,
                modulus=ciphertext.modulus,
            )
        ]

    except ValueError:
        # Re-raise ValueError directly (validation errors)
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform dot product: {e}") from e
1109
+
1110
+
1111
@kernel_def("phe.gather")
def _phe_gather(pfunc: PFunction, ciphertext: CipherText, indices: Any) -> Any:
    """Gather slices of a CipherText along one axis.

    The axis is read from ``pfunc.attrs["axis"]`` (default 0; negative values
    count from the end). The result has shape
    ``indices.shape + ciphertext.shape[:axis] + ciphertext.shape[axis+1:]``. Its
    flat ``ct_data`` is laid out in C order for exactly that shape, i.e.
    ``result[i..., o..., n...] == ciphertext[o..., indices[i...], n...]``.
    For ``axis == 0`` this matches ``numpy.take``. For other axes it is
    ``numpy.take`` with the index dimensions moved to the front.

    Args:
        pfunc: PFunction whose ``attrs`` may carry ``axis``.
        ciphertext: Non-scalar CipherText to gather from.
        indices: Integer tensor of non-negative positions along ``axis``.

    Returns:
        A single-element list holding the gathered CipherText. It keeps the
        input's semantic dtype, scheme, key and encoding parameters.

    Raises:
        ValueError: On a non-CipherText input, non-integer indices, a scalar
            ciphertext, an out-of-range axis, or out-of-range indices.
        RuntimeError: If any other error occurs while gathering.
    """
    # Validate that first argument is a CipherText
    if not isinstance(ciphertext, CipherText):
        raise ValueError("First argument must be a CipherText instance")

    # Get axis parameter from pfunc.attrs, default to 0
    axis = pfunc.attrs.get("axis", 0)

    try:
        indices_np = _convert_to_numpy(indices)

        if not np.issubdtype(indices_np.dtype, np.integer):
            raise ValueError("Indices must be of integer type")

        if len(ciphertext.semantic_shape) == 0:
            raise ValueError("Cannot gather from scalar CipherText")

        # Normalize axis to a non-negative value
        ndim = len(ciphertext.semantic_shape)
        if axis < 0:
            axis = ndim + axis
        if axis < 0 or axis >= ndim:
            raise ValueError(
                f"Axis {pfunc.attrs.get('axis', 0)} is out of bounds for array of dimension {ndim}"
            )

        # Negative indices are rejected rather than wrapped.
        axis_size = ciphertext.semantic_shape[axis]
        if np.any(indices_np < 0) or np.any(indices_np >= axis_size):
            raise ValueError(
                f"Indices are out of bounds for axis {axis} with size {axis_size}. "
                f"Got indices in range [{np.min(indices_np)}, {np.max(indices_np)}]"
            )

        ct_shape = ciphertext.semantic_shape
        result_shape = indices_np.shape + ct_shape[:axis] + ct_shape[axis + 1 :]

        # Flat layout of the source: outer_size blocks (dims before axis), each
        # holding ct_shape[axis] contiguous runs of inner_size elements.
        outer_size = int(np.prod(ct_shape[:axis]))
        inner_size = int(np.prod(ct_shape[axis + 1 :]))
        outer_step = ct_shape[axis] * inner_size

        # Index-major, then outer, then inner: C order for result_shape.
        gathered_ct_data = []
        for gather_idx in indices_np.flatten():
            offset = int(gather_idx) * inner_size
            for outer_idx in range(outer_size):
                start = outer_idx * outer_step + offset
                gathered_ct_data.extend(ciphertext.ct_data[start : start + inner_size])

        # Validate we got the expected number of elements
        expected_size = int(np.prod(result_shape)) if result_shape else 1
        if len(gathered_ct_data) != expected_size:
            raise RuntimeError(
                f"Internal error: Expected {expected_size} elements, got {len(gathered_ct_data)}"
            )

        return [
            CipherText(
                ct_data=gathered_ct_data,
                semantic_dtype=ciphertext.semantic_dtype,
                semantic_shape=result_shape,
                scheme=ciphertext.scheme,
                key_size=ciphertext.key_size,
                pk_data=ciphertext.pk_data,
                max_value=ciphertext.max_value,
                fxp_bits=ciphertext.fxp_bits,
                modulus=ciphertext.modulus,
            )
        ]

    except ValueError:
        # Re-raise ValueError directly (validation errors)
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform gather: {e}") from e
1223
+
1224
+
1225
@kernel_def("phe.scatter")
def _phe_scatter(
    pfunc: PFunction, ciphertext: CipherText, indices: TensorLike, updated: CipherText
) -> Any:
    """Scatter the elements of ``updated`` into a copy of ``ciphertext``.

    The axis is read from ``pfunc.attrs["axis"]`` (default 0; negative values
    count from the end). ``updated`` must have shape
    ``indices.shape + ciphertext.shape[:axis] + ciphertext.shape[axis+1:]``.
    Its element at ``[i..., o..., n...]`` (C-order layout for that shape) is
    written to ``ciphertext[o..., indices[i...], n...]``. When an index repeats,
    the later position in the flattened indices wins. The result has the shape
    of ``ciphertext``, and the input CipherText's data list is not mutated.

    Args:
        pfunc: PFunction whose ``attrs`` may carry ``axis``.
        ciphertext: Non-scalar CipherText that receives the updates.
        indices: Integer tensor of non-negative positions along ``axis``.
        updated: CipherText holding the replacement elements.

    Returns:
        A single-element list holding the scattered CipherText. It takes its
        dtype, scheme, key and encoding parameters from ``ciphertext``.

    Raises:
        ValueError: On non-CipherText inputs, a scheme/key-size/key mismatch,
            non-integer indices, a scalar ciphertext, an out-of-range axis or
            indices, or a wrongly shaped ``updated``.
        RuntimeError: If any other error occurs while scattering.
    """
    # Validate that first and third arguments are CipherTexts
    if not isinstance(ciphertext, CipherText) or not isinstance(updated, CipherText):
        raise ValueError("First and third arguments must be CipherText instances")

    # Validate that both ciphertexts use same scheme/key_size
    if ciphertext.scheme != updated.scheme or ciphertext.key_size != updated.key_size:
        raise ValueError("Both CipherTexts must use same scheme and key size")

    if ciphertext.pk_data != updated.pk_data:
        raise ValueError("Both CipherTexts must be encrypted with same key")

    # Get axis parameter from pfunc.attrs, default to 0
    axis = pfunc.attrs.get("axis", 0)

    try:
        indices_np = _convert_to_numpy(indices)

        if not np.issubdtype(indices_np.dtype, np.integer):
            raise ValueError("Indices must be of integer type")

        if len(ciphertext.semantic_shape) == 0:
            raise ValueError("Cannot scatter into scalar CipherText")

        # Normalize axis to a non-negative value
        ndim = len(ciphertext.semantic_shape)
        if axis < 0:
            axis = ndim + axis
        if axis < 0 or axis >= ndim:
            raise ValueError(
                f"Axis {pfunc.attrs.get('axis', 0)} is out of bounds for array of dimension {ndim}"
            )

        # Negative indices are rejected rather than wrapped.
        axis_size = ciphertext.semantic_shape[axis]
        if np.any(indices_np < 0) or np.any(indices_np >= axis_size):
            raise ValueError(
                f"Indices are out of bounds for axis {axis} with size {axis_size}. "
                f"Got indices in range [{np.min(indices_np)}, {np.max(indices_np)}]"
            )

        expected_updated_shape = (
            indices_np.shape
            + ciphertext.semantic_shape[:axis]
            + ciphertext.semantic_shape[axis + 1 :]
        )
        if updated.semantic_shape != expected_updated_shape:
            raise ValueError(
                f"Updated CipherText shape mismatch. Expected {expected_updated_shape}, "
                f"got {updated.semantic_shape}. "
                f"Updated shape must be indices.shape + ciphertext.shape[:axis] + ciphertext.shape[axis+1:]"
            )

        # Flat layout of the destination: outer_size blocks (dims before axis),
        # each holding ct_shape[axis] contiguous runs of inner_size elements.
        ct_shape = ciphertext.semantic_shape
        outer_size = int(np.prod(ct_shape[:axis]))
        inner_size = int(np.prod(ct_shape[axis + 1 :]))
        outer_step = ct_shape[axis] * inner_size

        scattered_ct_data = list(ciphertext.ct_data)
        updated_ct_data = updated.ct_data

        for i, scatter_idx in enumerate(indices_np.flatten()):
            dst_offset = int(scatter_idx) * inner_size
            for outer_idx in range(outer_size):
                dst = outer_idx * outer_step + dst_offset
                # `updated` is C-ordered for indices.shape + prefix + suffix,
                # so the index position is the slowest-varying coordinate.
                src = (i * outer_size + outer_idx) * inner_size
                for j in range(inner_size):
                    if src + j < len(updated_ct_data):
                        scattered_ct_data[dst + j] = updated_ct_data[src + j]

        return [
            CipherText(
                ct_data=scattered_ct_data,
                semantic_dtype=ciphertext.semantic_dtype,
                semantic_shape=ciphertext.semantic_shape,
                scheme=ciphertext.scheme,
                key_size=ciphertext.key_size,
                pk_data=ciphertext.pk_data,
                max_value=ciphertext.max_value,
                fxp_bits=ciphertext.fxp_bits,
                modulus=ciphertext.modulus,
            )
        ]
    except ValueError:
        # Re-raise ValueError directly (validation errors)
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform scatter: {e}") from e
1362
+
1363
+
1364
@kernel_def("phe.concat")
def _phe_concat(pfunc: PFunction, c1: CipherText, c2: CipherText) -> Any:
    """Concatenate two CipherTexts along one axis.

    The axis is read from ``pfunc.attrs["axis"]`` (default 0; negative values
    count from the end). Both operands must share scheme, key size, public key,
    semantic dtype and rank, and agree on every dimension except ``axis``.
    The result's encoding parameters (max_value, fxp_bits, modulus) come from ``c1``.

    Returns:
        A single-element list holding the concatenated CipherText.

    Raises:
        ValueError: On non-CipherText inputs, mismatched scheme/key/dtype/rank,
            scalar inputs, an out-of-range axis, or mismatched dimensions.
        RuntimeError: If any other error occurs while concatenating.
    """
    requested_axis = pfunc.attrs.get("axis", 0)

    if not (isinstance(c1, CipherText) and isinstance(c2, CipherText)):
        raise ValueError("All arguments must be CipherText instances")

    if c1.scheme != c2.scheme or c1.key_size != c2.key_size:
        raise ValueError("All CipherTexts must use same scheme and key size")
    if c1.pk_data != c2.pk_data:
        raise ValueError("All CipherTexts must be encrypted with same key")
    if c1.semantic_dtype != c2.semantic_dtype:
        raise ValueError(
            f"All CipherTexts must have same semantic dtype, got {c1.semantic_dtype} vs {c2.semantic_dtype}"
        )

    shape1, shape2 = c1.semantic_shape, c2.semantic_shape
    if len(shape1) != len(shape2):
        raise ValueError(
            f"All CipherTexts must have same number of dimensions for concat, got {len(shape1)} vs {len(shape2)}"
        )
    if len(shape1) == 0:
        raise ValueError("Cannot concatenate scalar CipherTexts")

    ndim = len(shape1)
    axis = requested_axis + ndim if requested_axis < 0 else requested_axis
    if not 0 <= axis < ndim:
        raise ValueError(
            f"axis {requested_axis} is out of bounds for array of dimension {ndim}"
        )

    # First dimension (other than the concat axis) where the shapes disagree.
    mismatch = next(
        (d for d in range(ndim) if d != axis and shape1[d] != shape2[d]), None
    )
    if mismatch is not None:
        raise ValueError(
            f"All CipherTexts must have same shape except along concatenation axis {axis}. "
            f"Shape mismatch at dimension {mismatch}: {shape1[mismatch]} vs {shape2[mismatch]}"
        )

    try:
        result_shape = tuple(
            n1 + n2 if d == axis else n1
            for d, (n1, n2) in enumerate(zip(shape1, shape2))
        )

        # Each operand is a sequence of `n_blocks` contiguous chunks (one per
        # index over the dims before `axis`). Interleave them chunk by chunk.
        n_blocks = int(np.prod(shape1[:axis]))
        chunk1 = int(np.prod(shape1[axis:]))
        chunk2 = int(np.prod(shape2[axis:]))

        merged = []
        for block in range(n_blocks):
            merged += c1.ct_data[block * chunk1 : (block + 1) * chunk1]
            merged += c2.ct_data[block * chunk2 : (block + 1) * chunk2]

        expected_size = int(np.prod(result_shape))
        if len(merged) != expected_size:
            raise RuntimeError(
                f"Internal error: Expected {expected_size} elements, got {len(merged)}"
            )

        result = CipherText(
            ct_data=merged,
            semantic_dtype=c1.semantic_dtype,
            semantic_shape=result_shape,
            scheme=c1.scheme,
            key_size=c1.key_size,
            pk_data=c1.pk_data,
            max_value=c1.max_value,
            fxp_bits=c1.fxp_bits,
            modulus=c1.modulus,
        )
        return [result]

    except ValueError:
        # Validation errors propagate unchanged.
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform concat: {e}") from e
1471
+
1472
+
1473
@kernel_def("phe.reshape")
def _phe_reshape(pfunc: PFunction, ciphertext: CipherText) -> Any:
    """Reinterpret a CipherText under a new shape without touching its data.

    The target shape is read from ``pfunc.attrs["new_shape"]`` (tuple or list).
    At most one entry may be ``-1``, in which case it is inferred from the
    element count. All other entries must be positive. The returned CipherText
    reuses the input's ``ct_data`` object as-is.

    Returns:
        A single-element list holding the reshaped CipherText.

    Raises:
        ValueError: On a non-CipherText input, a missing or non-sequence
            ``new_shape``, more than one ``-1``, non-positive dimensions, or an
            element-count mismatch.
        RuntimeError: If any other error occurs while reshaping.
    """
    if not isinstance(ciphertext, CipherText):
        raise ValueError("Argument must be a CipherText instance")

    requested = pfunc.attrs.get("new_shape")
    if requested is None:
        raise ValueError("new_shape parameter is required for reshape operation")
    if isinstance(requested, list):
        requested = tuple(requested)
    if not isinstance(requested, tuple):
        raise ValueError("new_shape must be a tuple or list of integers")

    try:
        src_shape = ciphertext.semantic_shape
        old_size = int(np.prod(src_shape)) if src_shape else 1

        wildcard_positions = [pos for pos, dim in enumerate(requested) if dim == -1]
        if len(wildcard_positions) > 1:
            raise ValueError("can only specify one unknown dimension")

        resolved = list(requested)
        if wildcard_positions:
            # Product of the explicitly given dimensions.
            known_size = 1
            for dim in requested:
                if dim == -1:
                    continue
                if dim <= 0:
                    raise ValueError(
                        f"negative dimensions not allowed (except -1): {dim}"
                    )
                known_size *= dim

            if old_size % known_size != 0:
                raise ValueError(
                    f"cannot reshape array of size {old_size} into shape {requested}"
                )
            resolved[wildcard_positions[0]] = old_size // known_size
        else:
            for dim in requested:
                if dim <= 0:
                    raise ValueError(f"negative dimensions not allowed: {dim}")

        final_shape = tuple(resolved)
        new_size = int(np.prod(final_shape)) if final_shape else 1
        if new_size != old_size:
            raise ValueError(
                f"Cannot reshape CipherText with {old_size} elements to shape {final_shape} "
                f"with {new_size} elements"
            )

        reshaped = CipherText(
            ct_data=ciphertext.ct_data,
            semantic_dtype=ciphertext.semantic_dtype,
            semantic_shape=final_shape,
            scheme=ciphertext.scheme,
            key_size=ciphertext.key_size,
            pk_data=ciphertext.pk_data,
            max_value=ciphertext.max_value,
            fxp_bits=ciphertext.fxp_bits,
            modulus=ciphertext.modulus,
        )
        return [reshaped]

    except ValueError:
        # Validation errors propagate unchanged.
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform reshape: {e}") from e
1563
+
1564
+
1565
@kernel_def("phe.transpose")
def _phe_transpose(pfunc: PFunction, ciphertext: CipherText) -> Any:
    """Permute the dimensions of a CipherText.

    The permutation is read from ``pfunc.attrs["axes"]`` (tuple or list;
    negative entries count from the end). When it is absent/None, all
    dimensions are reversed. A scalar CipherText is returned unchanged without
    validating ``axes``.

    Returns:
        A single-element list holding the transposed CipherText. Its elements are
        reordered so that the flat data is C-ordered for the new shape.

    Raises:
        ValueError: On a non-CipherText input, a non-sequence ``axes``, a length
            mismatch, an out-of-range axis (reported as given by the caller), or
            duplicate axes.
        RuntimeError: If any other error occurs while transposing.
    """
    # Validate that argument is a CipherText
    if not isinstance(ciphertext, CipherText):
        raise ValueError("Argument must be a CipherText instance")

    # Transposing a scalar returns the same scalar
    if len(ciphertext.semantic_shape) == 0:
        return [ciphertext]

    axes = pfunc.attrs.get("axes")

    # If axes is None, reverse all dimensions (default transpose behavior)
    if axes is None:
        axes = tuple(reversed(range(len(ciphertext.semantic_shape))))
    elif isinstance(axes, list):
        axes = tuple(axes)
    elif not isinstance(axes, tuple):
        raise ValueError("axes must be a tuple or list of integers, or None")

    try:
        ndim = len(ciphertext.semantic_shape)
        if len(axes) != ndim:
            raise ValueError(
                f"axes length {len(axes)} does not match tensor dimensions {ndim}"
            )

        # Normalize negative axes. Errors report the caller's value, not the
        # normalized one, matching the other phe kernels.
        normalized_axes = []
        for requested in axes:
            axis = requested + ndim if requested < 0 else requested
            if axis < 0 or axis >= ndim:
                raise ValueError(
                    f"axis {requested} is out of bounds for array of dimension {ndim}"
                )
            normalized_axes.append(axis)
        axes = tuple(normalized_axes)

        if len(set(axes)) != len(axes):
            raise ValueError("axes cannot contain duplicate values")

        old_shape = ciphertext.semantic_shape
        new_shape = tuple(old_shape[axis] for axis in axes)

        if len(old_shape) <= 1:
            # 1-D: the only valid permutation is the identity.
            transposed_ct_data = ciphertext.ct_data
        else:
            # Transpose an array of flat positions, then read ct_data in that order.
            source_positions = np.transpose(
                np.arange(len(ciphertext.ct_data)).reshape(old_shape), axes
            )
            transposed_ct_data = [
                ciphertext.ct_data[pos] for pos in source_positions.flatten()
            ]

        return [
            CipherText(
                ct_data=transposed_ct_data,
                semantic_dtype=ciphertext.semantic_dtype,
                semantic_shape=new_shape,
                scheme=ciphertext.scheme,
                key_size=ciphertext.key_size,
                pk_data=ciphertext.pk_data,
                max_value=ciphertext.max_value,
                fxp_bits=ciphertext.fxp_bits,
                modulus=ciphertext.modulus,
            )
        ]

    except ValueError:
        # Re-raise ValueError directly (validation errors)
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform transpose: {e}") from e