mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
mplang/v1/kernels/phe.py
DELETED
|
@@ -1,1864 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""PHE (Partially Homomorphic Encryption) backend implementation using lightPHE."""
|
|
16
|
-
|
|
17
|
-
from __future__ import annotations
|
|
18
|
-
|
|
19
|
-
import json
|
|
20
|
-
from typing import Any, ClassVar
|
|
21
|
-
|
|
22
|
-
import numpy as np
|
|
23
|
-
from lightphe import LightPHE
|
|
24
|
-
from lightphe.models.Ciphertext import Ciphertext
|
|
25
|
-
|
|
26
|
-
from mplang.v1.core import DType, PFunction
|
|
27
|
-
from mplang.v1.kernels.base import kernel_def
|
|
28
|
-
from mplang.v1.kernels.value import (
|
|
29
|
-
TensorValue,
|
|
30
|
-
Value,
|
|
31
|
-
ValueDecodeError,
|
|
32
|
-
ValueProtoBuilder,
|
|
33
|
-
ValueProtoReader,
|
|
34
|
-
register_value,
|
|
35
|
-
)
|
|
36
|
-
from mplang.v1.protos.v1alpha1 import value_pb2 as _value_pb2
|
|
37
|
-
|
|
38
|
-
# This controls the decimal precision used in lightPHE for float operations
|
|
39
|
-
# we force it to 0 to only support integer operations
|
|
40
|
-
# we will support negative and floating-point with our own encoding/decoding
|
|
41
|
-
PRECISION = 0
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
@register_value
|
|
45
|
-
class PublicKey(Value):
|
|
46
|
-
"""PHE Public Key Value type."""
|
|
47
|
-
|
|
48
|
-
KIND: ClassVar[str] = "mplang.phe.PublicKey"
|
|
49
|
-
WIRE_VERSION: ClassVar[int] = 1
|
|
50
|
-
|
|
51
|
-
def __init__(
|
|
52
|
-
self,
|
|
53
|
-
key_data: Any,
|
|
54
|
-
scheme: str,
|
|
55
|
-
key_size: int,
|
|
56
|
-
max_value: int = 2**100,
|
|
57
|
-
fxp_bits: int = 12,
|
|
58
|
-
modulus: int | None = None,
|
|
59
|
-
):
|
|
60
|
-
self.key_data = key_data
|
|
61
|
-
self.scheme = scheme
|
|
62
|
-
self.key_size = key_size
|
|
63
|
-
self.max_value = max_value # Maximum absolute value B for range encoding
|
|
64
|
-
self.fxp_bits = fxp_bits # Fixed-point precision bits for float encoding
|
|
65
|
-
self.modulus = modulus # Paillier modulus N for range encoding
|
|
66
|
-
|
|
67
|
-
@property
|
|
68
|
-
def dtype(self) -> Any:
|
|
69
|
-
return np.dtype("O") # Use object dtype for binary data
|
|
70
|
-
|
|
71
|
-
@property
|
|
72
|
-
def shape(self) -> tuple[int, ...]:
|
|
73
|
-
return ()
|
|
74
|
-
|
|
75
|
-
@property
|
|
76
|
-
def max_float_value(self) -> float:
|
|
77
|
-
"""Maximum float value that can be encoded."""
|
|
78
|
-
return float(self.max_value / (2**self.fxp_bits))
|
|
79
|
-
|
|
80
|
-
def to_proto(self) -> _value_pb2.ValueProto:
|
|
81
|
-
"""Serialize PublicKey to wire format."""
|
|
82
|
-
return (
|
|
83
|
-
ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
|
|
84
|
-
.set_attr("scheme", self.scheme)
|
|
85
|
-
.set_attr("key_size", self.key_size)
|
|
86
|
-
.set_attr("max_value", self.max_value)
|
|
87
|
-
.set_attr("fxp_bits", self.fxp_bits)
|
|
88
|
-
.set_attr("modulus", str(self.modulus) if self.modulus is not None else "")
|
|
89
|
-
.set_payload(json.dumps(self.key_data).encode("utf-8"))
|
|
90
|
-
.build()
|
|
91
|
-
)
|
|
92
|
-
|
|
93
|
-
@classmethod
|
|
94
|
-
def from_proto(cls, proto: _value_pb2.ValueProto) -> PublicKey:
|
|
95
|
-
"""Deserialize PublicKey from wire format."""
|
|
96
|
-
reader = ValueProtoReader(proto)
|
|
97
|
-
if reader.version != cls.WIRE_VERSION:
|
|
98
|
-
raise ValueDecodeError(f"Unsupported PublicKey version {reader.version}")
|
|
99
|
-
|
|
100
|
-
# Read metadata from runtime_attrs
|
|
101
|
-
scheme = reader.get_attr("scheme")
|
|
102
|
-
key_size = reader.get_attr("key_size")
|
|
103
|
-
max_value = reader.get_attr("max_value")
|
|
104
|
-
fxp_bits = reader.get_attr("fxp_bits")
|
|
105
|
-
modulus_str = reader.get_attr("modulus")
|
|
106
|
-
modulus = None if modulus_str == "" else int(modulus_str)
|
|
107
|
-
|
|
108
|
-
# JSON deserialize the public key dict
|
|
109
|
-
key_data = json.loads(reader.payload.decode("utf-8"))
|
|
110
|
-
|
|
111
|
-
return cls(
|
|
112
|
-
key_data=key_data,
|
|
113
|
-
scheme=scheme,
|
|
114
|
-
key_size=key_size,
|
|
115
|
-
max_value=max_value,
|
|
116
|
-
fxp_bits=fxp_bits,
|
|
117
|
-
modulus=modulus,
|
|
118
|
-
)
|
|
119
|
-
|
|
120
|
-
def __repr__(self) -> str:
|
|
121
|
-
return f"PublicKey(scheme={self.scheme}, key_size={self.key_size}, max_value={self.max_value}, fxp_bits={self.fxp_bits})"
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
@register_value
|
|
125
|
-
class PrivateKey(Value):
|
|
126
|
-
"""PHE Private Key Value type."""
|
|
127
|
-
|
|
128
|
-
KIND: ClassVar[str] = "mplang.phe.PrivateKey"
|
|
129
|
-
WIRE_VERSION: ClassVar[int] = 1
|
|
130
|
-
|
|
131
|
-
def __init__(
|
|
132
|
-
self,
|
|
133
|
-
sk_data: Any,
|
|
134
|
-
pk_data: Any,
|
|
135
|
-
scheme: str,
|
|
136
|
-
key_size: int,
|
|
137
|
-
max_value: int = 2**100,
|
|
138
|
-
fxp_bits: int = 12,
|
|
139
|
-
modulus: int | None = None,
|
|
140
|
-
):
|
|
141
|
-
self.sk_data = sk_data # Store private key data
|
|
142
|
-
self.pk_data = pk_data # Store public key data as well
|
|
143
|
-
self.scheme = scheme
|
|
144
|
-
self.key_size = key_size
|
|
145
|
-
self.max_value = max_value # Maximum absolute value B for range encoding
|
|
146
|
-
self.fxp_bits = fxp_bits # Fixed-point precision bits for float encoding
|
|
147
|
-
self.modulus = modulus # Paillier modulus N for range encoding
|
|
148
|
-
|
|
149
|
-
@property
|
|
150
|
-
def dtype(self) -> Any:
|
|
151
|
-
return np.dtype("O") # Use object dtype for binary data
|
|
152
|
-
|
|
153
|
-
@property
|
|
154
|
-
def shape(self) -> tuple[int, ...]:
|
|
155
|
-
return ()
|
|
156
|
-
|
|
157
|
-
@property
|
|
158
|
-
def max_float_value(self) -> float:
|
|
159
|
-
"""Maximum float value that can be encoded."""
|
|
160
|
-
return float(self.max_value / (2**self.fxp_bits))
|
|
161
|
-
|
|
162
|
-
def to_proto(self) -> _value_pb2.ValueProto:
|
|
163
|
-
"""Serialize PrivateKey to wire format."""
|
|
164
|
-
# JSON serialize both key dicts (contain int values)
|
|
165
|
-
# Store both keys in a single dict to avoid needing length metadata
|
|
166
|
-
keys_dict = {"sk": self.sk_data, "pk": self.pk_data}
|
|
167
|
-
|
|
168
|
-
return (
|
|
169
|
-
ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
|
|
170
|
-
.set_attr("scheme", self.scheme)
|
|
171
|
-
.set_attr("key_size", self.key_size)
|
|
172
|
-
.set_attr("max_value", self.max_value)
|
|
173
|
-
.set_attr("fxp_bits", self.fxp_bits)
|
|
174
|
-
.set_attr("modulus", str(self.modulus) if self.modulus is not None else "")
|
|
175
|
-
.set_payload(json.dumps(keys_dict).encode("utf-8"))
|
|
176
|
-
.build()
|
|
177
|
-
)
|
|
178
|
-
|
|
179
|
-
@classmethod
|
|
180
|
-
def from_proto(cls, proto: _value_pb2.ValueProto) -> PrivateKey:
|
|
181
|
-
"""Deserialize PrivateKey from wire format."""
|
|
182
|
-
reader = ValueProtoReader(proto)
|
|
183
|
-
if reader.version != cls.WIRE_VERSION:
|
|
184
|
-
raise ValueDecodeError(f"Unsupported PrivateKey version {reader.version}")
|
|
185
|
-
|
|
186
|
-
# Read metadata from runtime_attrs
|
|
187
|
-
scheme = reader.get_attr("scheme")
|
|
188
|
-
key_size = reader.get_attr("key_size")
|
|
189
|
-
max_value = reader.get_attr("max_value")
|
|
190
|
-
fxp_bits = reader.get_attr("fxp_bits")
|
|
191
|
-
modulus_str = reader.get_attr("modulus")
|
|
192
|
-
modulus = None if modulus_str == "" else int(modulus_str)
|
|
193
|
-
|
|
194
|
-
# JSON deserialize both key dicts
|
|
195
|
-
keys_dict = json.loads(reader.payload.decode("utf-8"))
|
|
196
|
-
sk_data = keys_dict["sk"]
|
|
197
|
-
pk_data = keys_dict["pk"]
|
|
198
|
-
|
|
199
|
-
return cls(
|
|
200
|
-
sk_data=sk_data,
|
|
201
|
-
pk_data=pk_data,
|
|
202
|
-
scheme=scheme,
|
|
203
|
-
key_size=key_size,
|
|
204
|
-
max_value=max_value,
|
|
205
|
-
fxp_bits=fxp_bits,
|
|
206
|
-
modulus=modulus,
|
|
207
|
-
)
|
|
208
|
-
|
|
209
|
-
def __repr__(self) -> str:
|
|
210
|
-
return f"PrivateKey(scheme={self.scheme}, key_size={self.key_size}, max_value={self.max_value}, fxp_bits={self.fxp_bits})"
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
@register_value
|
|
214
|
-
class CipherText(Value):
|
|
215
|
-
"""PHE CipherText Value type."""
|
|
216
|
-
|
|
217
|
-
KIND: ClassVar[str] = "mplang.phe.CipherText"
|
|
218
|
-
WIRE_VERSION: ClassVar[int] = 1
|
|
219
|
-
|
|
220
|
-
def __init__(
|
|
221
|
-
self,
|
|
222
|
-
ct_data: Any,
|
|
223
|
-
semantic_dtype: DType,
|
|
224
|
-
semantic_shape: tuple[int, ...],
|
|
225
|
-
scheme: str,
|
|
226
|
-
key_size: int,
|
|
227
|
-
pk_data: Any = None, # Store public key for operations
|
|
228
|
-
max_value: int = 2**100,
|
|
229
|
-
fxp_bits: int = 12,
|
|
230
|
-
modulus: int | None = None,
|
|
231
|
-
):
|
|
232
|
-
self.ct_data = ct_data
|
|
233
|
-
self.semantic_dtype = semantic_dtype
|
|
234
|
-
self.semantic_shape = semantic_shape
|
|
235
|
-
self.scheme = scheme
|
|
236
|
-
self.key_size = key_size
|
|
237
|
-
self.pk_data = pk_data
|
|
238
|
-
self.max_value = max_value
|
|
239
|
-
self.fxp_bits = fxp_bits
|
|
240
|
-
self.modulus = modulus
|
|
241
|
-
|
|
242
|
-
@property
|
|
243
|
-
def dtype(self) -> Any:
|
|
244
|
-
return self.semantic_dtype.to_numpy()
|
|
245
|
-
|
|
246
|
-
@property
|
|
247
|
-
def shape(self) -> tuple[int, ...]:
|
|
248
|
-
return self.semantic_shape
|
|
249
|
-
|
|
250
|
-
@property
|
|
251
|
-
def max_float_value(self) -> float:
|
|
252
|
-
"""Maximum float value that can be encoded."""
|
|
253
|
-
return float(self.max_value / (2**self.fxp_bits))
|
|
254
|
-
|
|
255
|
-
def to_proto(self) -> _value_pb2.ValueProto:
|
|
256
|
-
"""Serialize CipherText to wire format.
|
|
257
|
-
|
|
258
|
-
WARNING: This serialization is tightly coupled to lightphe.Ciphertext
|
|
259
|
-
internal attributes (value, algorithm_name, keys). Any changes to these
|
|
260
|
-
attributes in future lightphe versions will break serialization.
|
|
261
|
-
|
|
262
|
-
TODO: Check if lightphe provides official serialization methods and
|
|
263
|
-
migrate to them if available. Consider adding version compatibility checks.
|
|
264
|
-
"""
|
|
265
|
-
# JSON serialize ciphertext components
|
|
266
|
-
# ct_data is a list of lightPHE Ciphertext objects
|
|
267
|
-
# Each Ciphertext has: value, algorithm_name, keys
|
|
268
|
-
# We need to serialize the list of ciphertexts
|
|
269
|
-
if not isinstance(self.ct_data, list):
|
|
270
|
-
raise ValueError(f"ct_data should be a list, got {type(self.ct_data)}")
|
|
271
|
-
|
|
272
|
-
ct_list = []
|
|
273
|
-
for ct in self.ct_data:
|
|
274
|
-
if not isinstance(ct, Ciphertext):
|
|
275
|
-
raise TypeError(
|
|
276
|
-
f"ct_data must contain lightphe.Ciphertext objects, got {type(ct).__name__}"
|
|
277
|
-
)
|
|
278
|
-
ct_list.append({
|
|
279
|
-
"value": ct.value,
|
|
280
|
-
"algorithm_name": ct.algorithm_name,
|
|
281
|
-
"keys": ct.keys,
|
|
282
|
-
})
|
|
283
|
-
|
|
284
|
-
# Combine ct_data and pk_data into single dict
|
|
285
|
-
payload_dict = {
|
|
286
|
-
"ct_list": ct_list,
|
|
287
|
-
"pk": self.pk_data if self.pk_data is not None else None,
|
|
288
|
-
}
|
|
289
|
-
|
|
290
|
-
return (
|
|
291
|
-
ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
|
|
292
|
-
.set_attr("semantic_dtype", str(self.semantic_dtype))
|
|
293
|
-
.set_attr("semantic_shape", list(self.semantic_shape))
|
|
294
|
-
.set_attr("scheme", self.scheme)
|
|
295
|
-
.set_attr("key_size", self.key_size)
|
|
296
|
-
.set_attr("max_value", self.max_value)
|
|
297
|
-
.set_attr("fxp_bits", self.fxp_bits)
|
|
298
|
-
.set_attr("modulus", str(self.modulus) if self.modulus is not None else "")
|
|
299
|
-
.set_payload(json.dumps(payload_dict).encode("utf-8"))
|
|
300
|
-
.build()
|
|
301
|
-
)
|
|
302
|
-
|
|
303
|
-
@classmethod
|
|
304
|
-
def from_proto(cls, proto: _value_pb2.ValueProto) -> CipherText:
|
|
305
|
-
"""Deserialize CipherText from wire format."""
|
|
306
|
-
reader = ValueProtoReader(proto)
|
|
307
|
-
if reader.version != cls.WIRE_VERSION:
|
|
308
|
-
raise ValueDecodeError(f"Unsupported CipherText version {reader.version}")
|
|
309
|
-
|
|
310
|
-
# Read metadata from runtime_attrs
|
|
311
|
-
semantic_dtype_str = reader.get_attr("semantic_dtype")
|
|
312
|
-
semantic_shape = reader.get_attr("semantic_shape")
|
|
313
|
-
scheme = reader.get_attr("scheme")
|
|
314
|
-
key_size = reader.get_attr("key_size")
|
|
315
|
-
max_value = reader.get_attr("max_value")
|
|
316
|
-
fxp_bits = reader.get_attr("fxp_bits")
|
|
317
|
-
modulus_str = reader.get_attr("modulus")
|
|
318
|
-
modulus = None if modulus_str == "" else int(modulus_str)
|
|
319
|
-
|
|
320
|
-
# JSON deserialize ciphertext and public key
|
|
321
|
-
payload_dict = json.loads(reader.payload.decode("utf-8"))
|
|
322
|
-
ct_list = payload_dict["ct_list"]
|
|
323
|
-
pk_data = payload_dict["pk"]
|
|
324
|
-
|
|
325
|
-
# Reconstruct ct_data: list of Ciphertext objects
|
|
326
|
-
ct_data = []
|
|
327
|
-
for ct_dict in ct_list:
|
|
328
|
-
if ct_dict["keys"] is None or ct_dict["algorithm_name"] is None:
|
|
329
|
-
raise ValueDecodeError(
|
|
330
|
-
"Invalid CipherText: missing keys or algorithm_name in serialized data"
|
|
331
|
-
)
|
|
332
|
-
ct_data.append(
|
|
333
|
-
Ciphertext(
|
|
334
|
-
algorithm_name=ct_dict["algorithm_name"],
|
|
335
|
-
keys=ct_dict["keys"],
|
|
336
|
-
value=ct_dict["value"],
|
|
337
|
-
)
|
|
338
|
-
)
|
|
339
|
-
|
|
340
|
-
# Parse dtype string back to DType
|
|
341
|
-
dtype = DType.from_any(semantic_dtype_str)
|
|
342
|
-
|
|
343
|
-
return cls(
|
|
344
|
-
ct_data=ct_data,
|
|
345
|
-
semantic_dtype=dtype,
|
|
346
|
-
semantic_shape=tuple(semantic_shape),
|
|
347
|
-
scheme=scheme,
|
|
348
|
-
key_size=key_size,
|
|
349
|
-
pk_data=pk_data,
|
|
350
|
-
max_value=max_value,
|
|
351
|
-
fxp_bits=fxp_bits,
|
|
352
|
-
modulus=modulus,
|
|
353
|
-
)
|
|
354
|
-
|
|
355
|
-
def __repr__(self) -> str:
|
|
356
|
-
return f"CipherText(dtype={self.semantic_dtype}, shape={self.semantic_shape}, scheme={self.scheme})"
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
# Range-based encoding functions for negative numbers and floats
|
|
360
|
-
def _range_encode_integer(value: int, max_value: int, modulus: int) -> int:
|
|
361
|
-
"""
|
|
362
|
-
Range encoding function for integers.
|
|
363
|
-
- Positive numbers: encode(m) = m
|
|
364
|
-
- Negative numbers: encode(m) = N + m
|
|
365
|
-
"""
|
|
366
|
-
if not (-max_value <= value <= max_value):
|
|
367
|
-
raise ValueError(
|
|
368
|
-
f"Integer value {value} out of range [-{max_value}, {max_value}]"
|
|
369
|
-
)
|
|
370
|
-
|
|
371
|
-
if value >= 0:
|
|
372
|
-
encoded = value
|
|
373
|
-
else:
|
|
374
|
-
encoded = modulus + value
|
|
375
|
-
|
|
376
|
-
return encoded
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
def _range_encode_float(
|
|
380
|
-
value: float, max_value: int, fxp_bits: int, modulus: int
|
|
381
|
-
) -> int:
|
|
382
|
-
"""
|
|
383
|
-
Range encoding function for floats.
|
|
384
|
-
1. Fixed-point conversion: scaled_int = round(value * 2^fxp_bits)
|
|
385
|
-
2. Integer encoding rules
|
|
386
|
-
"""
|
|
387
|
-
max_float = max_value / (2**fxp_bits)
|
|
388
|
-
if not (-max_float <= value <= max_float):
|
|
389
|
-
raise ValueError(
|
|
390
|
-
f"Float value {value} out of range [-{max_float}, {max_float}]"
|
|
391
|
-
)
|
|
392
|
-
|
|
393
|
-
# Fixed-point encoding: float → scaled integer
|
|
394
|
-
scaled_int = round(value * (2**fxp_bits))
|
|
395
|
-
|
|
396
|
-
# Use integer encoding rules
|
|
397
|
-
return _range_encode_integer(scaled_int, max_value, modulus)
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
def _range_encode_mixed(
|
|
401
|
-
value: Any, max_value: int, fxp_bits: int, modulus: int, semantic_dtype: DType
|
|
402
|
-
) -> int:
|
|
403
|
-
"""
|
|
404
|
-
Mixed encoding function - automatically handle integers and floats based on semantic type.
|
|
405
|
-
Use semantic_dtype to choose between integer and float encoding.
|
|
406
|
-
"""
|
|
407
|
-
if semantic_dtype.is_floating:
|
|
408
|
-
# For floating semantic types, always use float encoding
|
|
409
|
-
return _range_encode_float(float(value), max_value, fxp_bits, modulus)
|
|
410
|
-
else:
|
|
411
|
-
# For integer semantic types, use integer encoding
|
|
412
|
-
return _range_encode_integer(int(value), max_value, modulus)
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
def _range_decode_integer(encoded_value: int, max_value: int, modulus: int) -> int:
|
|
416
|
-
"""
|
|
417
|
-
Range decoding function for integers.
|
|
418
|
-
- If r <= max_value: decode(r) = r
|
|
419
|
-
- If r >= N - max_value: decode(r) = r - N
|
|
420
|
-
- If max_value < r < N - max_value: overflow error
|
|
421
|
-
"""
|
|
422
|
-
|
|
423
|
-
# Ensure handling integer
|
|
424
|
-
if isinstance(encoded_value, (list, tuple)):
|
|
425
|
-
encoded_value = encoded_value[0]
|
|
426
|
-
encoded_value = int(encoded_value) % modulus
|
|
427
|
-
|
|
428
|
-
if encoded_value <= max_value:
|
|
429
|
-
return encoded_value
|
|
430
|
-
elif encoded_value >= modulus - max_value:
|
|
431
|
-
return encoded_value - modulus
|
|
432
|
-
else:
|
|
433
|
-
raise ValueError(f"Decoded value {encoded_value} is in overflow region")
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
def _range_decode_float(
|
|
437
|
-
encoded_value: int, max_value: int, fxp_bits: int, modulus: int
|
|
438
|
-
) -> float:
|
|
439
|
-
"""
|
|
440
|
-
Range decoding function for floats.
|
|
441
|
-
1. Integer decoding: decoded_int = range_decode_integer(encoded_value)
|
|
442
|
-
2. Fixed-point conversion: value = decoded_int / 2^fxp_bits
|
|
443
|
-
"""
|
|
444
|
-
# First decode as integer
|
|
445
|
-
decoded_int = _range_decode_integer(encoded_value, max_value, modulus)
|
|
446
|
-
|
|
447
|
-
# Fixed-point decoding: scaled integer → float
|
|
448
|
-
return float(decoded_int / (2**fxp_bits))
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
def _range_decode_mixed(
|
|
452
|
-
encoded_value: int,
|
|
453
|
-
max_value: int,
|
|
454
|
-
fxp_bits: int,
|
|
455
|
-
modulus: int,
|
|
456
|
-
semantic_dtype: DType,
|
|
457
|
-
) -> Any:
|
|
458
|
-
"""
|
|
459
|
-
Mixed decoding function - decode based on semantic type.
|
|
460
|
-
Use semantic_dtype to choose between integer and float decoding.
|
|
461
|
-
"""
|
|
462
|
-
if semantic_dtype.is_floating:
|
|
463
|
-
# For floating semantic types, decode as float
|
|
464
|
-
return _range_decode_float(encoded_value, max_value, fxp_bits, modulus)
|
|
465
|
-
else:
|
|
466
|
-
# For integer semantic types, decode as integer
|
|
467
|
-
return _range_decode_integer(encoded_value, max_value, modulus)
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
@kernel_def("phe.keygen")
|
|
471
|
-
def _phe_keygen(pfunc: PFunction) -> Any:
|
|
472
|
-
scheme = pfunc.attrs.get("scheme", "paillier")
|
|
473
|
-
# use small key_size to speed up tests
|
|
474
|
-
# in production use at least 2048 bits or 3072 bits for better security
|
|
475
|
-
key_size = pfunc.attrs.get("key_size", 2048)
|
|
476
|
-
# Accept very large max_value; allow decimal string input, kept simple like other attrs
|
|
477
|
-
max_value = int(pfunc.attrs.get("max_value", 2**32))
|
|
478
|
-
fxp_bits = int(pfunc.attrs.get("fxp_bits", 12))
|
|
479
|
-
|
|
480
|
-
# Validate scheme
|
|
481
|
-
if scheme.lower() not in ["paillier"]:
|
|
482
|
-
raise ValueError(f"Unsupported PHE scheme: {scheme}")
|
|
483
|
-
|
|
484
|
-
scheme = scheme.capitalize()
|
|
485
|
-
|
|
486
|
-
try:
|
|
487
|
-
# Set higher precision for better accuracy with floats
|
|
488
|
-
phe = LightPHE(
|
|
489
|
-
algorithm_name=scheme,
|
|
490
|
-
key_size=key_size,
|
|
491
|
-
precision=PRECISION,
|
|
492
|
-
)
|
|
493
|
-
|
|
494
|
-
pk_data = phe.cs.keys["public_key"]
|
|
495
|
-
sk_data = phe.cs.keys["private_key"]
|
|
496
|
-
modulus = phe.cs.plaintext_modulo # Get Paillier modulus N
|
|
497
|
-
|
|
498
|
-
# Validate safety: N should be much larger than 3*max_value
|
|
499
|
-
if modulus <= 3 * max_value:
|
|
500
|
-
raise ValueError(
|
|
501
|
-
f"Modulus {modulus} is too small for max_value {max_value}. Require N >> 3*B"
|
|
502
|
-
)
|
|
503
|
-
|
|
504
|
-
public_key = PublicKey(
|
|
505
|
-
key_data=pk_data,
|
|
506
|
-
scheme=scheme,
|
|
507
|
-
key_size=key_size,
|
|
508
|
-
max_value=max_value,
|
|
509
|
-
fxp_bits=fxp_bits,
|
|
510
|
-
modulus=modulus,
|
|
511
|
-
)
|
|
512
|
-
private_key = PrivateKey(
|
|
513
|
-
sk_data=sk_data,
|
|
514
|
-
pk_data=pk_data,
|
|
515
|
-
scheme=scheme,
|
|
516
|
-
key_size=key_size,
|
|
517
|
-
max_value=max_value,
|
|
518
|
-
fxp_bits=fxp_bits,
|
|
519
|
-
modulus=modulus,
|
|
520
|
-
)
|
|
521
|
-
|
|
522
|
-
return [public_key, private_key]
|
|
523
|
-
|
|
524
|
-
except Exception as e:
|
|
525
|
-
raise RuntimeError(f"Failed to generate PHE keys: {e}") from e
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
@kernel_def("phe.encrypt")
|
|
529
|
-
def _phe_encrypt(
|
|
530
|
-
pfunc: PFunction, plaintext: TensorValue, public_key: PublicKey
|
|
531
|
-
) -> Any:
|
|
532
|
-
# Validate public_key type
|
|
533
|
-
if not isinstance(public_key, PublicKey):
|
|
534
|
-
raise ValueError("Second argument must be a PublicKey instance")
|
|
535
|
-
|
|
536
|
-
try:
|
|
537
|
-
# Convert plaintext to numpy to get semantic type info
|
|
538
|
-
plaintext_np = plaintext.to_numpy()
|
|
539
|
-
semantic_dtype = DType.from_numpy(plaintext_np.dtype)
|
|
540
|
-
semantic_shape = plaintext_np.shape
|
|
541
|
-
|
|
542
|
-
# Create lightPHE instance with the same scheme/key_size as the key
|
|
543
|
-
phe = LightPHE(
|
|
544
|
-
algorithm_name=public_key.scheme,
|
|
545
|
-
key_size=public_key.key_size,
|
|
546
|
-
precision=PRECISION,
|
|
547
|
-
)
|
|
548
|
-
|
|
549
|
-
# CRITICAL: Set the same modulus as the key to ensure consistency
|
|
550
|
-
if public_key.modulus is not None:
|
|
551
|
-
phe.cs.plaintext_modulo = public_key.modulus
|
|
552
|
-
phe.cs.ciphertext_modulo = public_key.modulus * public_key.modulus
|
|
553
|
-
|
|
554
|
-
# Set the public key
|
|
555
|
-
phe.cs.keys["public_key"] = public_key.key_data
|
|
556
|
-
|
|
557
|
-
# Prepare data for encryption using range encoding
|
|
558
|
-
flat_data = plaintext_np.flatten()
|
|
559
|
-
|
|
560
|
-
# Use mixed encoding for consistent handling of integers and floats
|
|
561
|
-
encoded_data_list = []
|
|
562
|
-
for val in flat_data:
|
|
563
|
-
# Use mixed encoding to handle both integers and floats uniformly
|
|
564
|
-
if public_key.modulus is None:
|
|
565
|
-
raise ValueError(
|
|
566
|
-
"Public key modulus is None, key generation may have failed"
|
|
567
|
-
)
|
|
568
|
-
encoded_val = _range_encode_mixed(
|
|
569
|
-
val,
|
|
570
|
-
public_key.max_value,
|
|
571
|
-
public_key.fxp_bits,
|
|
572
|
-
public_key.modulus,
|
|
573
|
-
semantic_dtype,
|
|
574
|
-
)
|
|
575
|
-
encoded_data_list.append(encoded_val)
|
|
576
|
-
|
|
577
|
-
# Encrypt the encoded values (note: not passing as list, just the value)
|
|
578
|
-
lightphe_ciphertext = [phe.encrypt(val) for val in encoded_data_list]
|
|
579
|
-
|
|
580
|
-
# Create CipherText object with encoding parameters
|
|
581
|
-
ciphertext = CipherText(
|
|
582
|
-
ct_data=lightphe_ciphertext,
|
|
583
|
-
semantic_dtype=semantic_dtype,
|
|
584
|
-
semantic_shape=semantic_shape,
|
|
585
|
-
scheme=public_key.scheme,
|
|
586
|
-
key_size=public_key.key_size,
|
|
587
|
-
pk_data=public_key.key_data,
|
|
588
|
-
max_value=public_key.max_value,
|
|
589
|
-
fxp_bits=public_key.fxp_bits,
|
|
590
|
-
modulus=public_key.modulus,
|
|
591
|
-
)
|
|
592
|
-
|
|
593
|
-
return [ciphertext]
|
|
594
|
-
|
|
595
|
-
except Exception as e:
|
|
596
|
-
raise RuntimeError(f"Failed to encrypt data: {e}") from e
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
@kernel_def("phe.mul")
|
|
600
|
-
def _phe_mul(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorValue) -> Any:
|
|
601
|
-
# Validate that first argument is a CipherText
|
|
602
|
-
if not isinstance(ciphertext, CipherText):
|
|
603
|
-
raise ValueError("First argument must be a CipherText instance")
|
|
604
|
-
|
|
605
|
-
try:
|
|
606
|
-
# Convert plaintext to numpy
|
|
607
|
-
plaintext_np = plaintext.to_numpy()
|
|
608
|
-
|
|
609
|
-
# Check if plaintext is floating point type - multiplication not supported
|
|
610
|
-
if np.issubdtype(plaintext_np.dtype, np.floating):
|
|
611
|
-
raise ValueError(
|
|
612
|
-
f"Homomorphic multiplication with floating point plaintext is not supported. "
|
|
613
|
-
f"Got plaintext dtype: {plaintext_np.dtype}"
|
|
614
|
-
)
|
|
615
|
-
|
|
616
|
-
# Use numpy broadcasting to determine result shape and broadcast operands
|
|
617
|
-
# Create dummy arrays with the same shapes to test broadcasting
|
|
618
|
-
try:
|
|
619
|
-
dummy_ct = np.zeros(ciphertext.semantic_shape)
|
|
620
|
-
dummy_pt = np.zeros(plaintext_np.shape)
|
|
621
|
-
broadcasted_dummy = dummy_ct * dummy_pt
|
|
622
|
-
result_shape = broadcasted_dummy.shape
|
|
623
|
-
except ValueError as e:
|
|
624
|
-
raise ValueError(
|
|
625
|
-
f"Operands cannot be broadcast together: CipherText shape {ciphertext.semantic_shape} "
|
|
626
|
-
f"vs plaintext shape {plaintext_np.shape}: {e}"
|
|
627
|
-
) from e
|
|
628
|
-
|
|
629
|
-
# Broadcast plaintext to match result shape if needed
|
|
630
|
-
if plaintext_np.shape != result_shape:
|
|
631
|
-
plaintext_broadcasted = np.broadcast_to(plaintext_np, result_shape)
|
|
632
|
-
else:
|
|
633
|
-
plaintext_broadcasted = plaintext_np
|
|
634
|
-
|
|
635
|
-
# If ciphertext needs broadcasting, we need to replicate its encrypted values
|
|
636
|
-
if ciphertext.semantic_shape != result_shape:
|
|
637
|
-
# Use numpy to create a properly broadcasted index mapping
|
|
638
|
-
# Create a dummy array with same shape as ciphertext, fill with indices
|
|
639
|
-
dummy_ct = (
|
|
640
|
-
np
|
|
641
|
-
.arange(np.prod(ciphertext.semantic_shape))
|
|
642
|
-
.reshape(ciphertext.semantic_shape)
|
|
643
|
-
.astype(np.int64)
|
|
644
|
-
)
|
|
645
|
-
# Broadcast this to the result shape
|
|
646
|
-
broadcasted_indices = np.broadcast_to(dummy_ct, result_shape).flatten()
|
|
647
|
-
|
|
648
|
-
# Replicate ciphertext data according to the broadcasted indices
|
|
649
|
-
raw_ct: list[Any] = ciphertext.ct_data
|
|
650
|
-
broadcasted_ct_data = [raw_ct[int(idx)] for idx in broadcasted_indices]
|
|
651
|
-
else:
|
|
652
|
-
# No broadcasting needed for ciphertext
|
|
653
|
-
broadcasted_ct_data = ciphertext.ct_data
|
|
654
|
-
|
|
655
|
-
# Flatten the broadcasted plaintext data for element-wise multiplication
|
|
656
|
-
target_dtype = ciphertext.semantic_dtype
|
|
657
|
-
flat_data = plaintext_broadcasted.flatten()
|
|
658
|
-
|
|
659
|
-
# For multiplication, plaintext multipliers should NOT be encoded
|
|
660
|
-
# The ciphertext already contains the encoded value, multiplying by raw plaintext preserves semantics
|
|
661
|
-
raw_multipliers = []
|
|
662
|
-
for val in flat_data:
|
|
663
|
-
# Convert to appropriate numeric type but don't apply any encoding
|
|
664
|
-
if target_dtype.is_floating:
|
|
665
|
-
raw_val = float(val)
|
|
666
|
-
else:
|
|
667
|
-
raw_val = int(val)
|
|
668
|
-
raw_multipliers.append(raw_val)
|
|
669
|
-
|
|
670
|
-
# Perform homomorphic multiplication
|
|
671
|
-
# In Paillier, ciphertext * plaintext is supported
|
|
672
|
-
result_ciphertext = [
|
|
673
|
-
broadcasted_ct_data[i] * raw_multipliers[i]
|
|
674
|
-
for i in range(len(raw_multipliers))
|
|
675
|
-
]
|
|
676
|
-
|
|
677
|
-
# Create result CipherText with the broadcasted shape and encoding parameters
|
|
678
|
-
return [
|
|
679
|
-
CipherText(
|
|
680
|
-
ct_data=result_ciphertext,
|
|
681
|
-
semantic_dtype=ciphertext.semantic_dtype,
|
|
682
|
-
semantic_shape=result_shape,
|
|
683
|
-
scheme=ciphertext.scheme,
|
|
684
|
-
key_size=ciphertext.key_size,
|
|
685
|
-
pk_data=ciphertext.pk_data,
|
|
686
|
-
max_value=ciphertext.max_value,
|
|
687
|
-
fxp_bits=ciphertext.fxp_bits,
|
|
688
|
-
modulus=ciphertext.modulus,
|
|
689
|
-
)
|
|
690
|
-
]
|
|
691
|
-
|
|
692
|
-
except ValueError:
|
|
693
|
-
# Re-raise ValueError directly (validation errors)
|
|
694
|
-
raise
|
|
695
|
-
except Exception as e:
|
|
696
|
-
raise RuntimeError(f"Failed to perform multiplication: {e}") from e
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
@kernel_def("phe.add")
|
|
700
|
-
def _phe_add(pfunc: PFunction, lhs: Any, rhs: Any) -> Any:
|
|
701
|
-
try:
|
|
702
|
-
if isinstance(lhs, CipherText) and isinstance(rhs, CipherText):
|
|
703
|
-
return _phe_add_ct2ct(lhs, rhs)
|
|
704
|
-
elif isinstance(lhs, CipherText):
|
|
705
|
-
return _phe_add_ct2pt(lhs, rhs)
|
|
706
|
-
elif isinstance(rhs, CipherText):
|
|
707
|
-
return _phe_add_ct2pt(rhs, lhs)
|
|
708
|
-
else:
|
|
709
|
-
return TensorValue(lhs.to_numpy() + rhs.to_numpy())
|
|
710
|
-
except ValueError:
|
|
711
|
-
raise
|
|
712
|
-
except Exception as e: # pragma: no cover
|
|
713
|
-
raise RuntimeError(f"Failed to perform addition: {e}") from e
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
def _phe_add_ct2ct(ct1: CipherText, ct2: CipherText) -> CipherText:
|
|
717
|
-
# Validate compatibility
|
|
718
|
-
if ct1.scheme != ct2.scheme or ct1.key_size != ct2.key_size:
|
|
719
|
-
raise ValueError("CipherText operands must use same scheme and key size")
|
|
720
|
-
|
|
721
|
-
if ct1.pk_data != ct2.pk_data:
|
|
722
|
-
raise ValueError("CipherText operands must be encrypted with same key")
|
|
723
|
-
|
|
724
|
-
# Check for mixed precision issue: floating point ciphertext + integer ciphertext
|
|
725
|
-
# This would cause decode failures due to different fixed-point encoding scales
|
|
726
|
-
if ct1.semantic_dtype.is_floating != ct2.semantic_dtype.is_floating:
|
|
727
|
-
raise ValueError(
|
|
728
|
-
f"Cannot add ciphertexts with different numeric types due to fixed-point encoding. "
|
|
729
|
-
f"First CipherText dtype: {ct1.semantic_dtype}, second CipherText dtype: {ct2.semantic_dtype}. "
|
|
730
|
-
f"Both operands must have the same numeric type (both floating or both integer)."
|
|
731
|
-
)
|
|
732
|
-
|
|
733
|
-
# Use numpy broadcasting to determine result shape and broadcast operands
|
|
734
|
-
try:
|
|
735
|
-
dummy_ct1 = np.zeros(ct1.semantic_shape)
|
|
736
|
-
dummy_ct2 = np.zeros(ct2.semantic_shape)
|
|
737
|
-
broadcasted_dummy = dummy_ct1 + dummy_ct2
|
|
738
|
-
result_shape = broadcasted_dummy.shape
|
|
739
|
-
except ValueError as e:
|
|
740
|
-
raise ValueError(
|
|
741
|
-
f"CipherText operands cannot be broadcast together: shape {ct1.semantic_shape} "
|
|
742
|
-
f"vs shape {ct2.semantic_shape}: {e}"
|
|
743
|
-
) from e
|
|
744
|
-
|
|
745
|
-
# Broadcast ct1 if needed
|
|
746
|
-
if ct1.semantic_shape != result_shape:
|
|
747
|
-
dummy_ct1 = (
|
|
748
|
-
np
|
|
749
|
-
.arange(np.prod(ct1.semantic_shape))
|
|
750
|
-
.reshape(ct1.semantic_shape)
|
|
751
|
-
.astype(np.int64)
|
|
752
|
-
)
|
|
753
|
-
broadcasted_indices1 = np.broadcast_to(dummy_ct1, result_shape).flatten()
|
|
754
|
-
raw_ct1: list[Any] = ct1.ct_data
|
|
755
|
-
broadcasted_ct1_data = [raw_ct1[int(idx)] for idx in broadcasted_indices1]
|
|
756
|
-
else:
|
|
757
|
-
broadcasted_ct1_data = ct1.ct_data
|
|
758
|
-
|
|
759
|
-
# Broadcast ct2 if needed
|
|
760
|
-
if ct2.semantic_shape != result_shape:
|
|
761
|
-
dummy_ct2 = (
|
|
762
|
-
np
|
|
763
|
-
.arange(np.prod(ct2.semantic_shape))
|
|
764
|
-
.reshape(ct2.semantic_shape)
|
|
765
|
-
.astype(np.int64)
|
|
766
|
-
)
|
|
767
|
-
broadcasted_indices2 = np.broadcast_to(dummy_ct2, result_shape).flatten()
|
|
768
|
-
raw_ct2: list[Any] = ct2.ct_data
|
|
769
|
-
broadcasted_ct2_data = [raw_ct2[int(idx)] for idx in broadcasted_indices2]
|
|
770
|
-
else:
|
|
771
|
-
broadcasted_ct2_data = ct2.ct_data
|
|
772
|
-
|
|
773
|
-
# Perform homomorphic addition
|
|
774
|
-
result_ciphertext = [
|
|
775
|
-
broadcasted_ct1_data[i] + broadcasted_ct2_data[i]
|
|
776
|
-
for i in range(len(broadcasted_ct1_data))
|
|
777
|
-
]
|
|
778
|
-
|
|
779
|
-
# Create result CipherText with broadcasted shape and encoding parameters
|
|
780
|
-
return CipherText(
|
|
781
|
-
ct_data=result_ciphertext,
|
|
782
|
-
semantic_dtype=ct1.semantic_dtype,
|
|
783
|
-
semantic_shape=result_shape,
|
|
784
|
-
scheme=ct1.scheme,
|
|
785
|
-
key_size=ct1.key_size,
|
|
786
|
-
pk_data=ct1.pk_data,
|
|
787
|
-
max_value=ct1.max_value,
|
|
788
|
-
fxp_bits=ct1.fxp_bits,
|
|
789
|
-
modulus=ct1.modulus,
|
|
790
|
-
)
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
def _phe_add_ct2pt(ciphertext: CipherText, plaintext: TensorValue) -> CipherText:
|
|
794
|
-
# Convert plaintext to numpy
|
|
795
|
-
plaintext_np = plaintext.to_numpy()
|
|
796
|
-
plaintext_dtype = DType.from_numpy(plaintext_np.dtype)
|
|
797
|
-
|
|
798
|
-
# Check for mixed precision issue: floating point ciphertext + integer plaintext
|
|
799
|
-
# This would cause decode failures due to 2**fxp * f + i scaling mismatch
|
|
800
|
-
if ciphertext.semantic_dtype.is_floating and not plaintext_dtype.is_floating:
|
|
801
|
-
raise ValueError(
|
|
802
|
-
f"Cannot add integer plaintext to floating point ciphertext due to fixed-point encoding. "
|
|
803
|
-
f"CipherText dtype: {ciphertext.semantic_dtype}, plaintext dtype: {plaintext_dtype}. "
|
|
804
|
-
f"Both operands must have the same numeric type (both floating or both integer)."
|
|
805
|
-
)
|
|
806
|
-
|
|
807
|
-
# Check for mixed precision issue: integer ciphertext + floating point plaintext
|
|
808
|
-
if not ciphertext.semantic_dtype.is_floating and plaintext_dtype.is_floating:
|
|
809
|
-
raise ValueError(
|
|
810
|
-
f"Cannot add floating point plaintext to integer ciphertext due to fixed-point encoding. "
|
|
811
|
-
f"CipherText dtype: {ciphertext.semantic_dtype}, plaintext dtype: {plaintext_dtype}. "
|
|
812
|
-
f"Both operands must have the same numeric type (both floating or both integer)."
|
|
813
|
-
)
|
|
814
|
-
|
|
815
|
-
# Use numpy broadcasting to determine result shape and broadcast operands
|
|
816
|
-
try:
|
|
817
|
-
dummy_ct = np.zeros(ciphertext.semantic_shape)
|
|
818
|
-
dummy_pt = np.zeros(plaintext_np.shape)
|
|
819
|
-
broadcasted_dummy = dummy_ct + dummy_pt
|
|
820
|
-
result_shape = broadcasted_dummy.shape
|
|
821
|
-
except ValueError as e:
|
|
822
|
-
raise ValueError(
|
|
823
|
-
f"Operands cannot be broadcast together: CipherText shape {ciphertext.semantic_shape} "
|
|
824
|
-
f"vs plaintext shape {plaintext_np.shape}: {e}"
|
|
825
|
-
) from e
|
|
826
|
-
|
|
827
|
-
# Broadcast plaintext to match result shape if needed
|
|
828
|
-
if plaintext_np.shape != result_shape:
|
|
829
|
-
plaintext_broadcasted = np.broadcast_to(plaintext_np, result_shape)
|
|
830
|
-
else:
|
|
831
|
-
plaintext_broadcasted = plaintext_np
|
|
832
|
-
|
|
833
|
-
# Broadcast ciphertext if needed
|
|
834
|
-
if ciphertext.semantic_shape != result_shape:
|
|
835
|
-
dummy_ct = (
|
|
836
|
-
np
|
|
837
|
-
.arange(np.prod(ciphertext.semantic_shape))
|
|
838
|
-
.reshape(ciphertext.semantic_shape)
|
|
839
|
-
.astype(np.int64)
|
|
840
|
-
)
|
|
841
|
-
broadcasted_indices = np.broadcast_to(dummy_ct, result_shape).flatten()
|
|
842
|
-
raw_ct: list[Any] = ciphertext.ct_data
|
|
843
|
-
broadcasted_ct_data = [raw_ct[int(idx)] for idx in broadcasted_indices]
|
|
844
|
-
else:
|
|
845
|
-
broadcasted_ct_data = ciphertext.ct_data
|
|
846
|
-
|
|
847
|
-
# For ciphertext + plaintext addition, we encrypt the plaintext first
|
|
848
|
-
# and then do ciphertext + ciphertext addition
|
|
849
|
-
if ciphertext.pk_data is None:
|
|
850
|
-
raise ValueError(
|
|
851
|
-
"CipherText must contain public key data for plaintext addition"
|
|
852
|
-
)
|
|
853
|
-
|
|
854
|
-
# Create lightPHE instance to encrypt the plaintext
|
|
855
|
-
phe = LightPHE(
|
|
856
|
-
algorithm_name=ciphertext.scheme,
|
|
857
|
-
key_size=ciphertext.key_size,
|
|
858
|
-
precision=PRECISION,
|
|
859
|
-
)
|
|
860
|
-
phe.cs.keys["public_key"] = ciphertext.pk_data
|
|
861
|
-
|
|
862
|
-
# Encrypt the broadcasted plaintext using same method as original encryption
|
|
863
|
-
target_dtype = ciphertext.semantic_dtype
|
|
864
|
-
flat_data = plaintext_broadcasted.flatten()
|
|
865
|
-
|
|
866
|
-
# Use range encoding for consistency with encryption
|
|
867
|
-
encoded_data_list = []
|
|
868
|
-
for val in flat_data:
|
|
869
|
-
if ciphertext.modulus is None:
|
|
870
|
-
raise ValueError("Ciphertext modulus is None, encryption may have failed")
|
|
871
|
-
encoded_val = _range_encode_mixed(
|
|
872
|
-
val,
|
|
873
|
-
ciphertext.max_value,
|
|
874
|
-
ciphertext.fxp_bits,
|
|
875
|
-
ciphertext.modulus,
|
|
876
|
-
target_dtype,
|
|
877
|
-
)
|
|
878
|
-
encoded_data_list.append(encoded_val)
|
|
879
|
-
|
|
880
|
-
encrypted_plaintext = [phe.encrypt(val) for val in encoded_data_list]
|
|
881
|
-
|
|
882
|
-
# Perform addition
|
|
883
|
-
result_ciphertext = [
|
|
884
|
-
encrypted_plaintext[i] + broadcasted_ct_data[i]
|
|
885
|
-
for i in range(len(encrypted_plaintext))
|
|
886
|
-
]
|
|
887
|
-
|
|
888
|
-
# Create result CipherText with broadcasted shape and encoding parameters
|
|
889
|
-
return CipherText(
|
|
890
|
-
ct_data=result_ciphertext,
|
|
891
|
-
semantic_dtype=ciphertext.semantic_dtype,
|
|
892
|
-
semantic_shape=result_shape,
|
|
893
|
-
scheme=ciphertext.scheme,
|
|
894
|
-
key_size=ciphertext.key_size,
|
|
895
|
-
pk_data=ciphertext.pk_data,
|
|
896
|
-
max_value=ciphertext.max_value,
|
|
897
|
-
fxp_bits=ciphertext.fxp_bits,
|
|
898
|
-
modulus=ciphertext.modulus,
|
|
899
|
-
)
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
def _create_encrypted_zero(ciphertext: CipherText) -> Any:
|
|
903
|
-
# Create lightPHE instance with the same configuration
|
|
904
|
-
phe = LightPHE(
|
|
905
|
-
algorithm_name=ciphertext.scheme,
|
|
906
|
-
key_size=ciphertext.key_size,
|
|
907
|
-
precision=PRECISION,
|
|
908
|
-
)
|
|
909
|
-
|
|
910
|
-
# CRITICAL: Set the same modulus as the original ciphertext
|
|
911
|
-
if ciphertext.modulus is not None:
|
|
912
|
-
phe.cs.plaintext_modulo = ciphertext.modulus
|
|
913
|
-
phe.cs.ciphertext_modulo = ciphertext.modulus * ciphertext.modulus
|
|
914
|
-
|
|
915
|
-
phe.cs.keys["public_key"] = ciphertext.pk_data
|
|
916
|
-
|
|
917
|
-
# Encrypt zero value using range encoding for consistency
|
|
918
|
-
if ciphertext.modulus is None:
|
|
919
|
-
raise ValueError("Ciphertext modulus is None, encryption may have failed")
|
|
920
|
-
|
|
921
|
-
zero_encoded = _range_encode_mixed(
|
|
922
|
-
0,
|
|
923
|
-
ciphertext.max_value,
|
|
924
|
-
ciphertext.fxp_bits,
|
|
925
|
-
ciphertext.modulus,
|
|
926
|
-
ciphertext.semantic_dtype,
|
|
927
|
-
)
|
|
928
|
-
|
|
929
|
-
return phe.encrypt(zero_encoded)
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
@kernel_def("phe.decrypt")
|
|
933
|
-
def _phe_decrypt(
|
|
934
|
-
pfunc: PFunction, ciphertext: CipherText, private_key: PrivateKey
|
|
935
|
-
) -> Any:
|
|
936
|
-
# Validate argument types
|
|
937
|
-
if not isinstance(ciphertext, CipherText):
|
|
938
|
-
raise ValueError("First argument must be a CipherText instance")
|
|
939
|
-
if not isinstance(private_key, PrivateKey):
|
|
940
|
-
raise ValueError("Second argument must be a PrivateKey instance")
|
|
941
|
-
|
|
942
|
-
# Validate key compatibility
|
|
943
|
-
if (
|
|
944
|
-
ciphertext.scheme != private_key.scheme
|
|
945
|
-
or ciphertext.key_size != private_key.key_size
|
|
946
|
-
):
|
|
947
|
-
raise ValueError("CipherText and PrivateKey must use same scheme and key size")
|
|
948
|
-
|
|
949
|
-
try:
|
|
950
|
-
# Create lightPHE instance with the same scheme/key_size
|
|
951
|
-
phe = LightPHE(
|
|
952
|
-
algorithm_name=private_key.scheme,
|
|
953
|
-
key_size=private_key.key_size,
|
|
954
|
-
precision=PRECISION,
|
|
955
|
-
)
|
|
956
|
-
|
|
957
|
-
# CRITICAL FIX: Manually set the moduli to match the original encryption
|
|
958
|
-
# This ensures the decryption uses the same mathematical structure
|
|
959
|
-
if ciphertext.modulus is not None:
|
|
960
|
-
# Force the lightPHE instance to use the same modulus as during encryption
|
|
961
|
-
phe.cs.plaintext_modulo = ciphertext.modulus
|
|
962
|
-
# For Paillier: ciphertext_modulo = N^2
|
|
963
|
-
phe.cs.ciphertext_modulo = ciphertext.modulus * ciphertext.modulus
|
|
964
|
-
|
|
965
|
-
# Set both public and private keys (lightPHE needs both for proper decryption)
|
|
966
|
-
phe.cs.keys["private_key"] = private_key.sk_data
|
|
967
|
-
phe.cs.keys["public_key"] = private_key.pk_data
|
|
968
|
-
|
|
969
|
-
# Decrypt the data
|
|
970
|
-
target_dtype = ciphertext.semantic_dtype.to_numpy()
|
|
971
|
-
decrypted_raw = [phe.decrypt(ct) for ct in ciphertext.ct_data]
|
|
972
|
-
|
|
973
|
-
# Decode using range decoding
|
|
974
|
-
if ciphertext.modulus is None:
|
|
975
|
-
raise ValueError("Ciphertext modulus is None, encryption may have failed")
|
|
976
|
-
|
|
977
|
-
decoded_data = []
|
|
978
|
-
for encrypted_val in decrypted_raw:
|
|
979
|
-
# Extract numeric value from lightPHE result
|
|
980
|
-
if isinstance(encrypted_val, (int, float)):
|
|
981
|
-
raw_val = encrypted_val
|
|
982
|
-
elif hasattr(encrypted_val, "__getitem__") and len(encrypted_val) > 0:
|
|
983
|
-
raw_val = encrypted_val[0]
|
|
984
|
-
else:
|
|
985
|
-
raise ValueError(f"Cannot extract numeric value from {encrypted_val}")
|
|
986
|
-
|
|
987
|
-
# Convert to int for decoding
|
|
988
|
-
int_val = int(
|
|
989
|
-
raw_val
|
|
990
|
-
) # Use mixed decoding which returns values based on semantic type
|
|
991
|
-
decoded_val = _range_decode_mixed(
|
|
992
|
-
int_val,
|
|
993
|
-
ciphertext.max_value,
|
|
994
|
-
ciphertext.fxp_bits,
|
|
995
|
-
ciphertext.modulus,
|
|
996
|
-
ciphertext.semantic_dtype,
|
|
997
|
-
)
|
|
998
|
-
decoded_data.append(decoded_val)
|
|
999
|
-
|
|
1000
|
-
# Convert to target dtype
|
|
1001
|
-
if target_dtype.kind in "iu": # integer types
|
|
1002
|
-
# Convert floats back to integers for integer semantic types
|
|
1003
|
-
# decoded_data are numeric (ints or floats); normalize to Python int
|
|
1004
|
-
ints = [round(v) if isinstance(v, float) else v for v in decoded_data]
|
|
1005
|
-
if np.issubdtype(target_dtype, np.unsignedinteger):
|
|
1006
|
-
# Reduce modulo 2^k for unsigned to preserve ring semantics
|
|
1007
|
-
width = np.iinfo(target_dtype).bits
|
|
1008
|
-
mod = 1 << width
|
|
1009
|
-
processed_data = [v % mod for v in ints]
|
|
1010
|
-
else:
|
|
1011
|
-
# Signed integers: clamp to dtype range
|
|
1012
|
-
info = np.iinfo(target_dtype)
|
|
1013
|
-
processed_data = [max(info.min, min(info.max, v)) for v in ints]
|
|
1014
|
-
else: # float types
|
|
1015
|
-
processed_data = decoded_data
|
|
1016
|
-
|
|
1017
|
-
# Create array and reshape to target shape
|
|
1018
|
-
plaintext_np = np.array(processed_data, dtype=target_dtype).reshape(
|
|
1019
|
-
ciphertext.semantic_shape
|
|
1020
|
-
)
|
|
1021
|
-
|
|
1022
|
-
return [TensorValue(plaintext_np)]
|
|
1023
|
-
|
|
1024
|
-
except Exception as e:
|
|
1025
|
-
raise RuntimeError(f"Failed to decrypt data: {e}") from e
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
@kernel_def("phe.dot")
|
|
1029
|
-
def _phe_dot(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorValue) -> Any:
|
|
1030
|
-
"""Execute homomorphic dot product with zero-value optimization.
|
|
1031
|
-
|
|
1032
|
-
Supports various dot product operations:
|
|
1033
|
-
- Scalar * Scalar -> Scalar
|
|
1034
|
-
- Vector * Vector -> Scalar (inner product)
|
|
1035
|
-
- Matrix * Vector -> Vector
|
|
1036
|
-
- N-D tensor * M-D tensor -> result based on numpy.dot semantics
|
|
1037
|
-
|
|
1038
|
-
Optimization: Skip multiplication when plaintext value is 0, and handle
|
|
1039
|
-
the special case where all plaintext values are 0.
|
|
1040
|
-
|
|
1041
|
-
"""
|
|
1042
|
-
# Validate that first argument is a CipherText
|
|
1043
|
-
if not isinstance(ciphertext, CipherText):
|
|
1044
|
-
raise ValueError("First argument must be a CipherText instance")
|
|
1045
|
-
if isinstance(plaintext, CipherText):
|
|
1046
|
-
raise ValueError("Second argument must be a plaintext TensorLike")
|
|
1047
|
-
|
|
1048
|
-
try:
|
|
1049
|
-
# Convert plaintext to numpy
|
|
1050
|
-
plaintext_np = plaintext.to_numpy()
|
|
1051
|
-
|
|
1052
|
-
# Check if plaintext is floating point type - dot product not supported
|
|
1053
|
-
if np.issubdtype(plaintext_np.dtype, np.floating):
|
|
1054
|
-
raise ValueError(
|
|
1055
|
-
f"Homomorphic dot product with floating point plaintext is not supported. "
|
|
1056
|
-
f"Got plaintext dtype: {plaintext_np.dtype}"
|
|
1057
|
-
)
|
|
1058
|
-
|
|
1059
|
-
# Use numpy.dot to determine result shape and validate compatibility
|
|
1060
|
-
# Create dummy arrays with same shapes to test dot product compatibility
|
|
1061
|
-
try:
|
|
1062
|
-
dummy_ct = np.zeros(ciphertext.semantic_shape)
|
|
1063
|
-
dummy_pt = np.zeros(plaintext_np.shape)
|
|
1064
|
-
dummy_result = np.dot(dummy_ct, dummy_pt)
|
|
1065
|
-
result_shape = dummy_result.shape
|
|
1066
|
-
except ValueError as e:
|
|
1067
|
-
raise ValueError(
|
|
1068
|
-
f"Shapes are not compatible for dot product: CipherText shape {ciphertext.semantic_shape} "
|
|
1069
|
-
f"vs plaintext shape {plaintext_np.shape}: {e}"
|
|
1070
|
-
) from e
|
|
1071
|
-
|
|
1072
|
-
# Perform dot product based on input dimensions
|
|
1073
|
-
ct_shape = ciphertext.semantic_shape
|
|
1074
|
-
pt_shape = plaintext_np.shape
|
|
1075
|
-
target_dtype = ciphertext.semantic_dtype
|
|
1076
|
-
|
|
1077
|
-
if target_dtype.is_floating:
|
|
1078
|
-
pt_data = plaintext_np.astype(float)
|
|
1079
|
-
# Use a small epsilon for floating point zero comparison
|
|
1080
|
-
epsilon = 1e-15
|
|
1081
|
-
is_zero_func = lambda x: abs(x) < epsilon
|
|
1082
|
-
else: # integer types
|
|
1083
|
-
pt_data = plaintext_np.astype(int)
|
|
1084
|
-
is_zero_func = lambda x: x == 0
|
|
1085
|
-
|
|
1086
|
-
# Helper function to create encrypted zero when needed
|
|
1087
|
-
def get_encrypted_zero() -> Any:
|
|
1088
|
-
return _create_encrypted_zero(ciphertext)
|
|
1089
|
-
|
|
1090
|
-
if len(ct_shape) == 0 and len(pt_shape) == 0:
|
|
1091
|
-
# Scalar * Scalar
|
|
1092
|
-
pt_val = pt_data.item()
|
|
1093
|
-
if is_zero_func(pt_val):
|
|
1094
|
-
result_ciphertext = get_encrypted_zero()
|
|
1095
|
-
else:
|
|
1096
|
-
# Use single value (not list) for multiplication
|
|
1097
|
-
val = float(pt_val) if target_dtype.is_floating else int(pt_val)
|
|
1098
|
-
result_ciphertext = ciphertext.ct_data[0] * val
|
|
1099
|
-
result_ct_data = [result_ciphertext]
|
|
1100
|
-
|
|
1101
|
-
elif len(ct_shape) == 1 and len(pt_shape) == 1:
|
|
1102
|
-
# Vector * Vector -> Scalar (inner product)
|
|
1103
|
-
if ct_shape[0] != pt_shape[0]:
|
|
1104
|
-
raise ValueError(
|
|
1105
|
-
f"Vector size mismatch: CipherText size {ct_shape[0]} "
|
|
1106
|
-
f"vs plaintext size {pt_shape[0]}"
|
|
1107
|
-
)
|
|
1108
|
-
|
|
1109
|
-
# Compute element-wise products, skipping zeros
|
|
1110
|
-
non_zero_products = []
|
|
1111
|
-
for i in range(ct_shape[0]):
|
|
1112
|
-
pt_val = pt_data[i]
|
|
1113
|
-
if not is_zero_func(pt_val):
|
|
1114
|
-
# Convert to appropriate type and use single value (not list)
|
|
1115
|
-
val = float(pt_val) if target_dtype.is_floating else int(pt_val)
|
|
1116
|
-
product = ciphertext.ct_data[i] * val
|
|
1117
|
-
non_zero_products.append(product)
|
|
1118
|
-
|
|
1119
|
-
# Handle result
|
|
1120
|
-
if not non_zero_products:
|
|
1121
|
-
# All plaintext values are zero
|
|
1122
|
-
result_ciphertext = get_encrypted_zero()
|
|
1123
|
-
else:
|
|
1124
|
-
# Sum all non-zero products
|
|
1125
|
-
result_ciphertext = non_zero_products[0]
|
|
1126
|
-
for i in range(1, len(non_zero_products)):
|
|
1127
|
-
result_ciphertext = result_ciphertext + non_zero_products[i]
|
|
1128
|
-
|
|
1129
|
-
result_ct_data = [result_ciphertext]
|
|
1130
|
-
|
|
1131
|
-
elif len(ct_shape) == 2 and len(pt_shape) == 1:
|
|
1132
|
-
# Matrix * Vector -> Vector
|
|
1133
|
-
if ct_shape[1] != pt_shape[0]:
|
|
1134
|
-
raise ValueError(
|
|
1135
|
-
f"Matrix-vector dimension mismatch: Matrix shape {ct_shape} "
|
|
1136
|
-
f"vs vector shape {pt_shape}"
|
|
1137
|
-
)
|
|
1138
|
-
|
|
1139
|
-
result_ct_data = []
|
|
1140
|
-
for i in range(ct_shape[0]): # For each row of the matrix
|
|
1141
|
-
# Compute dot product of row i with the vector, skipping zeros
|
|
1142
|
-
row_products = []
|
|
1143
|
-
for j in range(ct_shape[1]): # For each column in the row
|
|
1144
|
-
pt_val = pt_data[j]
|
|
1145
|
-
if not is_zero_func(pt_val):
|
|
1146
|
-
ct_idx = i * ct_shape[1] + j
|
|
1147
|
-
# Use single value (not list) for multiplication
|
|
1148
|
-
val = float(pt_val) if target_dtype.is_floating else int(pt_val)
|
|
1149
|
-
product = ciphertext.ct_data[ct_idx] * val
|
|
1150
|
-
row_products.append(product)
|
|
1151
|
-
|
|
1152
|
-
# Handle row result
|
|
1153
|
-
if not row_products:
|
|
1154
|
-
# All plaintext values in this row are zero
|
|
1155
|
-
row_result = get_encrypted_zero()
|
|
1156
|
-
else:
|
|
1157
|
-
# Sum non-zero products for this row
|
|
1158
|
-
row_result = row_products[0]
|
|
1159
|
-
for k in range(1, len(row_products)):
|
|
1160
|
-
row_result = row_result + row_products[k]
|
|
1161
|
-
|
|
1162
|
-
result_ct_data.append(row_result)
|
|
1163
|
-
|
|
1164
|
-
elif len(ct_shape) == 1 and len(pt_shape) == 2:
|
|
1165
|
-
# Vector * Matrix -> Vector
|
|
1166
|
-
if ct_shape[0] != pt_shape[0]:
|
|
1167
|
-
raise ValueError(
|
|
1168
|
-
f"Vector-matrix dimension mismatch: Vector shape {ct_shape} "
|
|
1169
|
-
f"vs matrix shape {pt_shape}"
|
|
1170
|
-
)
|
|
1171
|
-
|
|
1172
|
-
result_ct_data = []
|
|
1173
|
-
for j in range(pt_shape[1]): # For each column of the matrix
|
|
1174
|
-
# Compute dot product of vector with column j, skipping zeros
|
|
1175
|
-
col_products = []
|
|
1176
|
-
for i in range(pt_shape[0]): # For each row in the column
|
|
1177
|
-
pt_val = pt_data[i, j]
|
|
1178
|
-
if not is_zero_func(pt_val):
|
|
1179
|
-
# Use single value (not list) for multiplication
|
|
1180
|
-
val = float(pt_val) if target_dtype.is_floating else int(pt_val)
|
|
1181
|
-
product = ciphertext.ct_data[i] * val
|
|
1182
|
-
col_products.append(product)
|
|
1183
|
-
|
|
1184
|
-
# Handle column result
|
|
1185
|
-
if not col_products:
|
|
1186
|
-
# All plaintext values in this column are zero
|
|
1187
|
-
col_result = get_encrypted_zero()
|
|
1188
|
-
else:
|
|
1189
|
-
# Sum non-zero products for this column
|
|
1190
|
-
col_result = col_products[0]
|
|
1191
|
-
for k in range(1, len(col_products)):
|
|
1192
|
-
col_result = col_result + col_products[k]
|
|
1193
|
-
|
|
1194
|
-
result_ct_data.append(col_result)
|
|
1195
|
-
|
|
1196
|
-
elif len(ct_shape) == 2 and len(pt_shape) == 2:
|
|
1197
|
-
# Matrix * Matrix -> Matrix
|
|
1198
|
-
if ct_shape[1] != pt_shape[0]:
|
|
1199
|
-
raise ValueError(
|
|
1200
|
-
f"Matrix dimension mismatch: First matrix shape {ct_shape} "
|
|
1201
|
-
f"vs second matrix shape {pt_shape}"
|
|
1202
|
-
)
|
|
1203
|
-
|
|
1204
|
-
result_ct_data = []
|
|
1205
|
-
for i in range(ct_shape[0]): # For each row of first matrix
|
|
1206
|
-
for j in range(pt_shape[1]): # For each column of second matrix
|
|
1207
|
-
# Compute dot product of row i with column j, skipping zeros
|
|
1208
|
-
products = []
|
|
1209
|
-
for k in range(ct_shape[1]): # Sum over common dimension
|
|
1210
|
-
pt_val = pt_data[k, j]
|
|
1211
|
-
if not is_zero_func(pt_val):
|
|
1212
|
-
ct_idx = i * ct_shape[1] + k
|
|
1213
|
-
# Use single value (not list) for multiplication
|
|
1214
|
-
val = (
|
|
1215
|
-
float(pt_val)
|
|
1216
|
-
if target_dtype.is_floating
|
|
1217
|
-
else int(pt_val)
|
|
1218
|
-
)
|
|
1219
|
-
product = ciphertext.ct_data[ct_idx] * val
|
|
1220
|
-
products.append(product)
|
|
1221
|
-
|
|
1222
|
-
# Handle element result
|
|
1223
|
-
if not products:
|
|
1224
|
-
# All plaintext values for this element are zero
|
|
1225
|
-
element_result = get_encrypted_zero()
|
|
1226
|
-
else:
|
|
1227
|
-
# Sum non-zero products for this element
|
|
1228
|
-
element_result = products[0]
|
|
1229
|
-
for p in range(1, len(products)):
|
|
1230
|
-
element_result = element_result + products[p]
|
|
1231
|
-
|
|
1232
|
-
result_ct_data.append(element_result)
|
|
1233
|
-
|
|
1234
|
-
else:
|
|
1235
|
-
# General N-D tensor dot product
|
|
1236
|
-
# Flatten both tensors and perform generalized dot product
|
|
1237
|
-
ct_flat = ciphertext.ct_data
|
|
1238
|
-
pt_flat = pt_data.flatten()
|
|
1239
|
-
|
|
1240
|
-
# For general case, we implement numpy.dot semantics
|
|
1241
|
-
# This is a simplified implementation for common cases
|
|
1242
|
-
if len(ct_shape) >= 2 and len(pt_shape) >= 1:
|
|
1243
|
-
# Treat as matrix multiplication on the last axis of ct and first axis of pt
|
|
1244
|
-
last_dim_ct = ct_shape[-1]
|
|
1245
|
-
first_dim_pt = pt_shape[0]
|
|
1246
|
-
|
|
1247
|
-
if last_dim_ct != first_dim_pt:
|
|
1248
|
-
raise ValueError(
|
|
1249
|
-
f"Tensor dimension mismatch: CipherText last dimension {last_dim_ct} "
|
|
1250
|
-
f"vs plaintext first dimension {first_dim_pt}"
|
|
1251
|
-
)
|
|
1252
|
-
|
|
1253
|
-
# Reshape for matrix multiplication
|
|
1254
|
-
ct_reshaped_size = int(np.prod(ct_shape[:-1]))
|
|
1255
|
-
pt_reshaped_size = int(np.prod(pt_shape[1:]))
|
|
1256
|
-
|
|
1257
|
-
result_ct_data = []
|
|
1258
|
-
for i in range(ct_reshaped_size):
|
|
1259
|
-
for j in range(pt_reshaped_size):
|
|
1260
|
-
# Compute dot product for element (i, j), skipping zeros
|
|
1261
|
-
products = []
|
|
1262
|
-
for k in range(last_dim_ct):
|
|
1263
|
-
pt_idx = k * pt_reshaped_size + j
|
|
1264
|
-
pt_val = pt_flat[pt_idx]
|
|
1265
|
-
if not is_zero_func(pt_val):
|
|
1266
|
-
ct_idx = i * last_dim_ct + k
|
|
1267
|
-
# Use single value (not list) for multiplication
|
|
1268
|
-
val = (
|
|
1269
|
-
float(pt_val)
|
|
1270
|
-
if target_dtype.is_floating
|
|
1271
|
-
else int(pt_val)
|
|
1272
|
-
)
|
|
1273
|
-
product = ct_flat[ct_idx] * val
|
|
1274
|
-
products.append(product)
|
|
1275
|
-
|
|
1276
|
-
# Handle element result
|
|
1277
|
-
if not products:
|
|
1278
|
-
# All plaintext values for this element are zero
|
|
1279
|
-
element_result = get_encrypted_zero()
|
|
1280
|
-
else:
|
|
1281
|
-
# Sum non-zero products
|
|
1282
|
-
element_result = products[0]
|
|
1283
|
-
for p in range(1, len(products)):
|
|
1284
|
-
element_result = element_result + products[p]
|
|
1285
|
-
result_ct_data.append(element_result)
|
|
1286
|
-
else:
|
|
1287
|
-
raise ValueError(
|
|
1288
|
-
f"Unsupported tensor shapes for dot product: "
|
|
1289
|
-
f"CipherText shape {ct_shape}, plaintext shape {pt_shape}"
|
|
1290
|
-
)
|
|
1291
|
-
|
|
1292
|
-
# Create result CipherText with computed shape and encoding parameters
|
|
1293
|
-
return [
|
|
1294
|
-
CipherText(
|
|
1295
|
-
ct_data=result_ct_data,
|
|
1296
|
-
semantic_dtype=ciphertext.semantic_dtype,
|
|
1297
|
-
semantic_shape=result_shape,
|
|
1298
|
-
scheme=ciphertext.scheme,
|
|
1299
|
-
key_size=ciphertext.key_size,
|
|
1300
|
-
pk_data=ciphertext.pk_data,
|
|
1301
|
-
max_value=ciphertext.max_value,
|
|
1302
|
-
fxp_bits=ciphertext.fxp_bits,
|
|
1303
|
-
modulus=ciphertext.modulus,
|
|
1304
|
-
)
|
|
1305
|
-
]
|
|
1306
|
-
|
|
1307
|
-
except ValueError:
|
|
1308
|
-
# Re-raise ValueError directly (validation errors)
|
|
1309
|
-
raise
|
|
1310
|
-
except Exception as e:
|
|
1311
|
-
raise RuntimeError(f"Failed to perform dot product: {e}") from e
|
|
1312
|
-
|
|
1313
|
-
|
|
1314
|
-
@kernel_def("phe.gather")
def _phe_gather(pfunc: PFunction, ciphertext: CipherText, indices: TensorValue) -> Any:
    """Gather encrypted elements of a CipherText along one axis.

    The axis is read from ``pfunc.attrs["axis"]`` (default 0; a negative value
    counts from the last dimension). ``indices`` is converted with
    ``to_numpy()`` and must contain non-negative integers smaller than the
    size of the selected axis.

    Returns:
        A one-element list holding a new CipherText whose ``semantic_shape``
        is ``indices.shape + ciphertext.shape[:axis] + ciphertext.shape[axis+1:]``;
        all other fields are copied from ``ciphertext``.

    Raises:
        ValueError: If ``ciphertext`` is not a CipherText or is a scalar, the
            indices are not integers or fall outside the axis, or the axis is
            out of bounds.
        RuntimeError: If any other error occurs while gathering (including
            the internal element-count consistency check).

    NOTE(review): elements are emitted in (outer, index, inner) order, i.e.
    the layout of ``shape[:axis] + indices.shape + shape[axis+1:]``. For
    axis > 0 with more than one outer slice this does not appear to match the
    indices-first ``semantic_shape`` declared below -- confirm against the
    frontend shape inference before relying on it.
    """
    if not isinstance(ciphertext, CipherText):
        raise ValueError("First argument must be a CipherText instance")

    axis = pfunc.attrs.get("axis", 0)

    try:
        indices_np = indices.to_numpy()
        if not np.issubdtype(indices_np.dtype, np.integer):
            raise ValueError("Indices must be of integer type")

        ct_shape = ciphertext.semantic_shape
        ndim = len(ct_shape)
        if ndim == 0:
            raise ValueError("Cannot gather from scalar CipherText")

        # Map a negative axis onto its positive counterpart, then range-check.
        if axis < 0:
            axis = ndim + axis
        if not 0 <= axis < ndim:
            raise ValueError(
                f"Axis {pfunc.attrs.get('axis', 0)} is out of bounds for array of dimension {ndim}"
            )

        # Negative indices are rejected rather than wrapped.
        axis_size = ct_shape[axis]
        if np.any(indices_np < 0) or np.any(indices_np >= axis_size):
            raise ValueError(
                f"Indices are out of bounds for axis {axis} with size {axis_size}. "
                f"Got indices in range [{np.min(indices_np)}, {np.max(indices_np)}]"
            )

        result_shape = indices_np.shape + ct_shape[:axis] + ct_shape[axis + 1 :]

        # Number of slices before the axis and number of contiguous elements
        # covered by one step along the axis in the flat ct_data list.
        outer_count = int(np.prod(ct_shape[:axis])) if axis > 0 else 1
        inner_count = int(np.prod(ct_shape[axis + 1 :])) if axis < ndim - 1 else 1
        slab_size = axis_size * inner_count

        # With axis == 0 there is a single outer slice, so this one loop also
        # covers the plain leading-axis gather.
        flat_indices = [int(i) for i in indices_np.flatten()]
        gathered_ct_data = []
        for outer in range(outer_count):
            base = outer * slab_size
            for idx in flat_indices:
                start = base + idx * inner_count
                gathered_ct_data.extend(ciphertext.ct_data[start : start + inner_count])

        # Sanity check: the flat data must match the declared result shape.
        expected_size = int(np.prod(result_shape)) if result_shape else 1
        if len(gathered_ct_data) != expected_size:
            raise RuntimeError(
                f"Internal error: Expected {expected_size} elements, got {len(gathered_ct_data)}"
            )

        return [
            CipherText(
                ct_data=gathered_ct_data,
                semantic_dtype=ciphertext.semantic_dtype,
                semantic_shape=result_shape,
                scheme=ciphertext.scheme,
                key_size=ciphertext.key_size,
                pk_data=ciphertext.pk_data,
                max_value=ciphertext.max_value,
                fxp_bits=ciphertext.fxp_bits,
                modulus=ciphertext.modulus,
            )
        ]

    except ValueError:
        # Validation errors are surfaced unchanged.
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform gather: {e}") from e
|
|
1426
|
-
|
|
1427
|
-
|
|
1428
|
-
@kernel_def("phe.scatter")
def _phe_scatter(
    pfunc: PFunction,
    ciphertext: CipherText,
    indices: TensorValue,
    updated: CipherText,
) -> Any:
    """Scatter encrypted elements of ``updated`` into a copy of ``ciphertext``.

    The axis is read from ``pfunc.attrs["axis"]`` (default 0; a negative value
    counts from the last dimension). ``indices`` is converted with
    ``to_numpy()`` and must contain non-negative integers smaller than the
    size of the selected axis. ``updated.semantic_shape`` must equal
    ``indices.shape + ciphertext.shape[:axis] + ciphertext.shape[axis+1:]``.
    When an index occurs more than once, the later update overwrites the
    earlier one.

    Returns:
        A one-element list holding a new CipherText with the same shape and
        metadata as ``ciphertext``; the input's ``ct_data`` list is copied,
        not modified.

    Raises:
        ValueError: If either operand is not a CipherText, the operands
            differ in scheme, key size, public key or semantic dtype, the
            target is a scalar, the indices are not integers or fall outside
            the axis, the axis is out of bounds, or ``updated`` has the wrong
            shape.
        RuntimeError: If any other error occurs while scattering.

    NOTE(review): ``updated.ct_data`` is read in (outer, index, inner) order.
    For axis > 0 with more than one outer slice this does not appear to match
    the indices-first shape validated below -- confirm against the frontend
    shape inference.
    """
    if not isinstance(ciphertext, CipherText) or not isinstance(updated, CipherText):
        raise ValueError("First and third arguments must be CipherText instances")

    if ciphertext.scheme != updated.scheme or ciphertext.key_size != updated.key_size:
        raise ValueError("Both CipherTexts must use same scheme and key size")

    if ciphertext.pk_data != updated.pk_data:
        raise ValueError("Both CipherTexts must be encrypted with same key")

    # The result carries the target's dtype metadata, so elements encoded for
    # a different dtype must not be mixed in (same rule as phe.concat).
    if ciphertext.semantic_dtype != updated.semantic_dtype:
        raise ValueError(
            f"Both CipherTexts must have same semantic dtype, got {ciphertext.semantic_dtype} vs {updated.semantic_dtype}"
        )

    axis = pfunc.attrs.get("axis", 0)

    try:
        indices_np = indices.to_numpy()

        if not np.issubdtype(indices_np.dtype, np.integer):
            raise ValueError("Indices must be of integer type")

        ct_shape = ciphertext.semantic_shape
        ndim = len(ct_shape)
        if ndim == 0:
            raise ValueError("Cannot scatter into scalar CipherText")

        # Map a negative axis onto its positive counterpart, then range-check.
        if axis < 0:
            axis = ndim + axis
        if axis < 0 or axis >= ndim:
            raise ValueError(
                f"Axis {pfunc.attrs.get('axis', 0)} is out of bounds for array of dimension {ndim}"
            )

        # Negative indices are rejected rather than wrapped.
        axis_size = ct_shape[axis]
        if np.any(indices_np < 0) or np.any(indices_np >= axis_size):
            raise ValueError(
                f"Indices are out of bounds for axis {axis} with size {axis_size}. "
                f"Got indices in range [{np.min(indices_np)}, {np.max(indices_np)}]"
            )

        expected_updated_shape = indices_np.shape + ct_shape[:axis] + ct_shape[axis + 1 :]
        if updated.semantic_shape != expected_updated_shape:
            raise ValueError(
                f"Updated CipherText shape mismatch. Expected {expected_updated_shape}, "
                f"got {updated.semantic_shape}. "
                f"Updated shape must be indices.shape + ciphertext.shape[:axis] + ciphertext.shape[axis+1:]"
            )

        # Number of slices before the axis and number of contiguous elements
        # covered by one step along the axis in the flat ct_data list.
        outer_stride = int(np.prod(ct_shape[:axis])) if axis > 0 else 1
        inner_stride = int(np.prod(ct_shape[axis + 1 :])) if axis < ndim - 1 else 1
        slab_size = axis_size * inner_stride

        # Work on a copy so the input ciphertext's list is left untouched.
        scattered_ct_data = ciphertext.ct_data.copy()
        updated_ct_data = updated.ct_data
        updated_len = len(updated_ct_data)
        flat_indices = [int(i) for i in indices_np.flatten()]
        index_count = len(flat_indices)

        # With axis == 0 there is a single outer slice, so this one loop also
        # covers the plain leading-axis scatter.
        for outer_idx in range(outer_stride):
            for i, scatter_idx in enumerate(flat_indices):
                dst = outer_idx * slab_size + scatter_idx * inner_stride
                src = (outer_idx * index_count + i) * inner_stride
                for j in range(inner_stride):
                    # Defensive: positions past the end of the update data are
                    # skipped, leaving the original element in place.
                    if src + j < updated_len:
                        scattered_ct_data[dst + j] = updated_ct_data[src + j]

        return [
            CipherText(
                ct_data=scattered_ct_data,
                semantic_dtype=ciphertext.semantic_dtype,
                semantic_shape=ciphertext.semantic_shape,
                scheme=ciphertext.scheme,
                key_size=ciphertext.key_size,
                pk_data=ciphertext.pk_data,
                max_value=ciphertext.max_value,
                fxp_bits=ciphertext.fxp_bits,
                modulus=ciphertext.modulus,
            )
        ]
    except ValueError:
        # Validation errors are surfaced unchanged.
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform scatter: {e}") from e
|
|
1568
|
-
|
|
1569
|
-
|
|
1570
|
-
@kernel_def("phe.concat")
def _phe_concat(pfunc: PFunction, c1: CipherText, c2: CipherText) -> Any:
    """Concatenate two CipherTexts along one axis.

    The axis is read from ``pfunc.attrs["axis"]`` (default 0; a negative value
    counts from the last dimension). Both operands must share scheme, key
    size, public key, semantic dtype and rank, and must agree on every
    dimension except the concatenation axis.

    Returns:
        A one-element list holding a new CipherText whose shape is that of
        ``c1`` with the axis extended by ``c2``'s size along it; all other
        metadata is copied from ``c1``.

    Raises:
        ValueError: If any of the compatibility checks above fails, the
            operands are scalars, or the axis is out of bounds.
        RuntimeError: If any other error occurs while concatenating
            (including the internal element-count consistency check).
    """
    axis = pfunc.attrs.get("axis", 0)

    if not (isinstance(c1, CipherText) and isinstance(c2, CipherText)):
        raise ValueError("All arguments must be CipherText instances")

    # Both operands must come from the same key material and encoding.
    if c1.scheme != c2.scheme or c1.key_size != c2.key_size:
        raise ValueError("All CipherTexts must use same scheme and key size")
    if c1.pk_data != c2.pk_data:
        raise ValueError("All CipherTexts must be encrypted with same key")
    if c1.semantic_dtype != c2.semantic_dtype:
        raise ValueError(
            f"All CipherTexts must have same semantic dtype, got {c1.semantic_dtype} vs {c2.semantic_dtype}"
        )

    shape1 = c1.semantic_shape
    shape2 = c2.semantic_shape
    ndim = len(shape1)

    if ndim != len(shape2):
        raise ValueError(
            f"All CipherTexts must have same number of dimensions for concat, got {len(c1.semantic_shape)} vs {len(c2.semantic_shape)}"
        )

    if ndim == 0:
        raise ValueError("Cannot concatenate scalar CipherTexts")

    # Map a negative axis onto its positive counterpart, then range-check.
    if axis < 0:
        axis = ndim + axis
    if not 0 <= axis < ndim:
        raise ValueError(
            f"axis {pfunc.attrs.get('axis', 0)} is out of bounds for array of dimension {ndim}"
        )

    # Every dimension other than the concatenation axis has to match.
    for i, (size1, size2) in enumerate(zip(shape1, shape2)):
        if i != axis and size1 != size2:
            raise ValueError(
                f"All CipherTexts must have same shape except along concatenation axis {axis}. "
                f"Shape mismatch at dimension {i}: {c1.semantic_shape[i]} vs {c2.semantic_shape[i]}"
            )

    try:
        result_dims = list(shape1)
        result_dims[axis] = shape1[axis] + shape2[axis]
        result_shape = tuple(result_dims)

        # Number of leading slices, and how many flat elements each operand
        # contributes per slice (the axis itself plus everything after it).
        pre_axis_size = int(np.prod(shape1[:axis])) if axis > 0 else 1
        chunk1 = int(np.prod(shape1[axis:]))
        chunk2 = int(np.prod(shape2[axis:]))

        # Interleave one chunk of c1 and one chunk of c2 per leading slice.
        concatenated_ct_data = []
        for slice_idx in range(pre_axis_size):
            concatenated_ct_data += c1.ct_data[slice_idx * chunk1 : slice_idx * chunk1 + chunk1]
            concatenated_ct_data += c2.ct_data[slice_idx * chunk2 : slice_idx * chunk2 + chunk2]

        # Sanity check: the flat data must match the computed result shape.
        expected_size = int(np.prod(result_shape))
        if len(concatenated_ct_data) != expected_size:
            raise RuntimeError(
                f"Internal error: Expected {expected_size} elements, got {len(concatenated_ct_data)}"
            )

        return [
            CipherText(
                ct_data=concatenated_ct_data,
                semantic_dtype=c1.semantic_dtype,
                semantic_shape=result_shape,
                scheme=c1.scheme,
                key_size=c1.key_size,
                pk_data=c1.pk_data,
                max_value=c1.max_value,
                fxp_bits=c1.fxp_bits,
                modulus=c1.modulus,
            )
        ]

    except ValueError:
        # Validation errors are surfaced unchanged.
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform concat: {e}") from e
|
|
1677
|
-
|
|
1678
|
-
|
|
1679
|
-
@kernel_def("phe.reshape")
def _phe_reshape(pfunc: PFunction, ciphertext: CipherText) -> Any:
    """Reinterpret a CipherText under a new shape.

    The target shape is read from ``pfunc.attrs["new_shape"]`` and must be a
    tuple or list. At most one entry may be ``-1``; it is inferred from the
    element count of the input. Zero-length dimensions are accepted, so an
    empty CipherText can be reshaped to another empty shape.

    Returns:
        A one-element list holding a new CipherText that references the same
        ``ct_data`` list as the input (the encrypted elements are not copied
        or reordered) with ``semantic_shape`` set to the resolved shape.

    Raises:
        ValueError: If ``ciphertext`` is not a CipherText, ``new_shape`` is
            missing or not a tuple/list, contains a negative entry other than
            a single ``-1``, or does not describe the same number of elements
            as the input.
        RuntimeError: If any other error occurs while reshaping.
    """
    if not isinstance(ciphertext, CipherText):
        raise ValueError("Argument must be a CipherText instance")

    new_shape = pfunc.attrs.get("new_shape")
    if new_shape is None:
        raise ValueError("new_shape parameter is required for reshape operation")

    if isinstance(new_shape, list):
        new_shape = tuple(new_shape)
    elif not isinstance(new_shape, tuple):
        raise ValueError("new_shape must be a tuple or list of integers")

    try:
        # A scalar (empty shape) holds exactly one element.
        old_size = (
            int(np.prod(ciphertext.semantic_shape)) if ciphertext.semantic_shape else 1
        )

        inferred_shape = list(new_shape)
        negative_ones = [i for i, dim in enumerate(new_shape) if dim == -1]

        if len(negative_ones) > 1:
            raise ValueError("can only specify one unknown dimension")
        elif len(negative_ones) == 1:
            # Product of the explicitly given dimensions.
            known_size = 1
            for dim in new_shape:
                if dim != -1:
                    if dim < 0:
                        raise ValueError(
                            f"negative dimensions not allowed (except -1): {dim}"
                        )
                    known_size *= dim

            # With a zero-length explicit dimension the unknown one cannot be
            # determined; checking that first also avoids a modulo by zero.
            if known_size == 0 or old_size % known_size != 0:
                raise ValueError(
                    f"cannot reshape array of size {old_size} into shape {new_shape}"
                )

            inferred_shape[negative_ones[0]] = old_size // known_size
        else:
            # Zero is a legal (empty) dimension; only negatives are rejected.
            for dim in new_shape:
                if dim < 0:
                    raise ValueError(f"negative dimensions not allowed: {dim}")

        final_shape = tuple(inferred_shape)

        # The resolved shape must describe exactly the same number of elements.
        new_size = int(np.prod(final_shape)) if final_shape else 1

        if old_size != new_size:
            raise ValueError(
                f"Cannot reshape CipherText with {old_size} elements to shape {final_shape} "
                f"with {new_size} elements"
            )

        return [
            CipherText(
                ct_data=ciphertext.ct_data,  # Same encrypted data, not copied
                semantic_dtype=ciphertext.semantic_dtype,
                semantic_shape=final_shape,
                scheme=ciphertext.scheme,
                key_size=ciphertext.key_size,
                pk_data=ciphertext.pk_data,
                max_value=ciphertext.max_value,
                fxp_bits=ciphertext.fxp_bits,
                modulus=ciphertext.modulus,
            )
        ]

    except ValueError:
        # Validation errors are surfaced unchanged.
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform reshape: {e}") from e
|
|
1769
|
-
|
|
1770
|
-
|
|
1771
|
-
@kernel_def("phe.transpose")
def _phe_transpose(pfunc: PFunction, ciphertext: CipherText) -> Any:
    """Permute the dimensions of a CipherText.

    The permutation is read from ``pfunc.attrs["axes"]``: a tuple or list with
    one entry per dimension (negative entries count from the last dimension),
    or ``None`` to reverse the dimension order.

    Returns:
        A one-element list. For a scalar input it holds the input object
        itself; otherwise it holds a new CipherText whose ``ct_data`` is
        reordered to the permuted layout (for a 1-D input the original list
        is reused as is) and whose ``semantic_shape`` is permuted accordingly.

    Raises:
        ValueError: If ``ciphertext`` is not a CipherText, ``axes`` has the
            wrong type or length, or it contains out-of-range or duplicate
            entries.
        RuntimeError: If any other error occurs while transposing.
    """
    if not isinstance(ciphertext, CipherText):
        raise ValueError("Argument must be a CipherText instance")

    # A scalar has nothing to permute.
    if len(ciphertext.semantic_shape) == 0:
        return [ciphertext]

    axes = pfunc.attrs.get("axes")

    if axes is None:
        # Default: reverse the order of all dimensions.
        axes = tuple(reversed(range(len(ciphertext.semantic_shape))))
    elif isinstance(axes, list):
        axes = tuple(axes)
    elif not isinstance(axes, tuple):
        raise ValueError("axes must be a tuple or list of integers, or None")

    try:
        old_shape = ciphertext.semantic_shape
        ndim = len(old_shape)
        if len(axes) != ndim:
            raise ValueError(
                f"axes length {len(axes)} does not match tensor dimensions {ndim}"
            )

        # Resolve negative entries and range-check each one.
        resolved = []
        for axis in axes:
            if axis < 0:
                axis = ndim + axis
            if not 0 <= axis < ndim:
                raise ValueError(
                    f"axis {axis} is out of bounds for array of dimension {ndim}"
                )
            resolved.append(axis)
        axes = tuple(resolved)

        if len(set(axes)) != len(axes):
            raise ValueError("axes cannot contain duplicate values")

        new_shape = tuple(old_shape[axis] for axis in axes)

        source = ciphertext.ct_data
        if ndim <= 1:
            # A 1-D permutation can only be the identity; reuse the data.
            transposed_ct_data = source
        else:
            # Transpose an array of flat positions with numpy and read the
            # encrypted elements back in that order.
            positions = np.arange(len(source)).reshape(old_shape)
            order = np.transpose(positions, axes).flatten()
            transposed_ct_data = [source[pos] for pos in order]

        return [
            CipherText(
                ct_data=transposed_ct_data,
                semantic_dtype=ciphertext.semantic_dtype,
                semantic_shape=new_shape,
                scheme=ciphertext.scheme,
                key_size=ciphertext.key_size,
                pk_data=ciphertext.pk_data,
                max_value=ciphertext.max_value,
                fxp_bits=ciphertext.fxp_bits,
                modulus=ciphertext.modulus,
            )
        ]

    except ValueError:
        # Validation errors are surfaced unchanged.
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to perform transpose: {e}") from e
|