mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -20,8 +20,9 @@ import warnings
|
|
|
20
20
|
import numpy as np
|
|
21
21
|
from numpy.typing import NDArray
|
|
22
22
|
|
|
23
|
-
from mplang.core
|
|
24
|
-
from mplang.kernels.base import cur_kctx, kernel_def
|
|
23
|
+
from mplang.v1.core import PFunction
|
|
24
|
+
from mplang.v1.kernels.base import cur_kctx, kernel_def
|
|
25
|
+
from mplang.v1.kernels.value import TensorValue
|
|
25
26
|
|
|
26
27
|
__all__: list[str] = []
|
|
27
28
|
|
|
@@ -46,28 +47,26 @@ def _quote_from_pk(pk: np.ndarray) -> NDArray[np.uint8]:
|
|
|
46
47
|
|
|
47
48
|
|
|
48
49
|
@kernel_def("mock_tee.quote_gen")
|
|
49
|
-
def _tee_quote_gen(pfunc: PFunction, pk:
|
|
50
|
+
def _tee_quote_gen(pfunc: PFunction, pk: TensorValue) -> TensorValue:
|
|
50
51
|
warnings.warn(
|
|
51
52
|
"Insecure mock TEE kernel 'mock_tee.quote_gen' in use. NOT secure; for local testing only.",
|
|
52
53
|
stacklevel=3,
|
|
53
54
|
)
|
|
54
|
-
|
|
55
|
+
pk_arr = pk.to_numpy().astype(np.uint8, copy=False)
|
|
55
56
|
# rng access ensures deterministic seeding per rank even if unused now
|
|
56
57
|
_rng()
|
|
57
|
-
|
|
58
|
+
quote = _quote_from_pk(pk_arr)
|
|
59
|
+
return TensorValue(np.array(quote, copy=True))
|
|
58
60
|
|
|
59
61
|
|
|
60
62
|
@kernel_def("mock_tee.attest")
|
|
61
|
-
def _tee_attest(pfunc: PFunction, quote:
|
|
63
|
+
def _tee_attest(pfunc: PFunction, quote: TensorValue) -> TensorValue:
|
|
62
64
|
warnings.warn(
|
|
63
65
|
"Insecure mock TEE kernel 'mock_tee.attest' in use. NOT secure; for local testing only.",
|
|
64
66
|
stacklevel=3,
|
|
65
67
|
)
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
if platform is None:
|
|
69
|
-
raise ValueError("missing required 'platform' attribute in PFunction")
|
|
70
|
-
|
|
71
|
-
if quote.size != 33:
|
|
68
|
+
quote_arr = quote.to_numpy().astype(np.uint8, copy=False)
|
|
69
|
+
if quote_arr.size != 33:
|
|
72
70
|
raise ValueError("mock quote must be 33 bytes (1 header + 32 pk)")
|
|
73
|
-
|
|
71
|
+
attest = quote_arr[1:33].astype(np.uint8, copy=True)
|
|
72
|
+
return TensorValue(attest)
|
|
@@ -14,15 +14,26 @@
|
|
|
14
14
|
|
|
15
15
|
"""PHE (Partially Homomorphic Encryption) backend implementation using lightPHE."""
|
|
16
16
|
|
|
17
|
-
from
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import json
|
|
20
|
+
from typing import Any, ClassVar
|
|
18
21
|
|
|
19
22
|
import numpy as np
|
|
20
23
|
from lightphe import LightPHE
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
from mplang.core
|
|
24
|
-
from mplang.
|
|
25
|
-
from mplang.kernels.
|
|
24
|
+
from lightphe.models.Ciphertext import Ciphertext
|
|
25
|
+
|
|
26
|
+
from mplang.v1.core import DType, PFunction
|
|
27
|
+
from mplang.v1.kernels.base import kernel_def
|
|
28
|
+
from mplang.v1.kernels.value import (
|
|
29
|
+
TensorValue,
|
|
30
|
+
Value,
|
|
31
|
+
ValueDecodeError,
|
|
32
|
+
ValueProtoBuilder,
|
|
33
|
+
ValueProtoReader,
|
|
34
|
+
register_value,
|
|
35
|
+
)
|
|
36
|
+
from mplang.v1.protos.v1alpha1 import value_pb2 as _value_pb2
|
|
26
37
|
|
|
27
38
|
# This controls the decimal precision used in lightPHE for float operations
|
|
28
39
|
# we force it to 0 to only support integer operations
|
|
@@ -30,8 +41,12 @@ from mplang.kernels.base import kernel_def
|
|
|
30
41
|
PRECISION = 0
|
|
31
42
|
|
|
32
43
|
|
|
33
|
-
|
|
34
|
-
|
|
44
|
+
@register_value
|
|
45
|
+
class PublicKey(Value):
|
|
46
|
+
"""PHE Public Key Value type."""
|
|
47
|
+
|
|
48
|
+
KIND: ClassVar[str] = "mplang.phe.PublicKey"
|
|
49
|
+
WIRE_VERSION: ClassVar[int] = 1
|
|
35
50
|
|
|
36
51
|
def __init__(
|
|
37
52
|
self,
|
|
@@ -62,12 +77,56 @@ class PublicKey:
|
|
|
62
77
|
"""Maximum float value that can be encoded."""
|
|
63
78
|
return float(self.max_value / (2**self.fxp_bits))
|
|
64
79
|
|
|
80
|
+
def to_proto(self) -> _value_pb2.ValueProto:
|
|
81
|
+
"""Serialize PublicKey to wire format."""
|
|
82
|
+
return (
|
|
83
|
+
ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
|
|
84
|
+
.set_attr("scheme", self.scheme)
|
|
85
|
+
.set_attr("key_size", self.key_size)
|
|
86
|
+
.set_attr("max_value", self.max_value)
|
|
87
|
+
.set_attr("fxp_bits", self.fxp_bits)
|
|
88
|
+
.set_attr("modulus", str(self.modulus) if self.modulus is not None else "")
|
|
89
|
+
.set_payload(json.dumps(self.key_data).encode("utf-8"))
|
|
90
|
+
.build()
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
@classmethod
|
|
94
|
+
def from_proto(cls, proto: _value_pb2.ValueProto) -> PublicKey:
|
|
95
|
+
"""Deserialize PublicKey from wire format."""
|
|
96
|
+
reader = ValueProtoReader(proto)
|
|
97
|
+
if reader.version != cls.WIRE_VERSION:
|
|
98
|
+
raise ValueDecodeError(f"Unsupported PublicKey version {reader.version}")
|
|
99
|
+
|
|
100
|
+
# Read metadata from runtime_attrs
|
|
101
|
+
scheme = reader.get_attr("scheme")
|
|
102
|
+
key_size = reader.get_attr("key_size")
|
|
103
|
+
max_value = reader.get_attr("max_value")
|
|
104
|
+
fxp_bits = reader.get_attr("fxp_bits")
|
|
105
|
+
modulus_str = reader.get_attr("modulus")
|
|
106
|
+
modulus = None if modulus_str == "" else int(modulus_str)
|
|
107
|
+
|
|
108
|
+
# JSON deserialize the public key dict
|
|
109
|
+
key_data = json.loads(reader.payload.decode("utf-8"))
|
|
110
|
+
|
|
111
|
+
return cls(
|
|
112
|
+
key_data=key_data,
|
|
113
|
+
scheme=scheme,
|
|
114
|
+
key_size=key_size,
|
|
115
|
+
max_value=max_value,
|
|
116
|
+
fxp_bits=fxp_bits,
|
|
117
|
+
modulus=modulus,
|
|
118
|
+
)
|
|
119
|
+
|
|
65
120
|
def __repr__(self) -> str:
|
|
66
121
|
return f"PublicKey(scheme={self.scheme}, key_size={self.key_size}, max_value={self.max_value}, fxp_bits={self.fxp_bits})"
|
|
67
122
|
|
|
68
123
|
|
|
69
|
-
|
|
70
|
-
|
|
124
|
+
@register_value
|
|
125
|
+
class PrivateKey(Value):
|
|
126
|
+
"""PHE Private Key Value type."""
|
|
127
|
+
|
|
128
|
+
KIND: ClassVar[str] = "mplang.phe.PrivateKey"
|
|
129
|
+
WIRE_VERSION: ClassVar[int] = 1
|
|
71
130
|
|
|
72
131
|
def __init__(
|
|
73
132
|
self,
|
|
@@ -100,12 +159,63 @@ class PrivateKey:
|
|
|
100
159
|
"""Maximum float value that can be encoded."""
|
|
101
160
|
return float(self.max_value / (2**self.fxp_bits))
|
|
102
161
|
|
|
162
|
+
def to_proto(self) -> _value_pb2.ValueProto:
|
|
163
|
+
"""Serialize PrivateKey to wire format."""
|
|
164
|
+
# JSON serialize both key dicts (contain int values)
|
|
165
|
+
# Store both keys in a single dict to avoid needing length metadata
|
|
166
|
+
keys_dict = {"sk": self.sk_data, "pk": self.pk_data}
|
|
167
|
+
|
|
168
|
+
return (
|
|
169
|
+
ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
|
|
170
|
+
.set_attr("scheme", self.scheme)
|
|
171
|
+
.set_attr("key_size", self.key_size)
|
|
172
|
+
.set_attr("max_value", self.max_value)
|
|
173
|
+
.set_attr("fxp_bits", self.fxp_bits)
|
|
174
|
+
.set_attr("modulus", str(self.modulus) if self.modulus is not None else "")
|
|
175
|
+
.set_payload(json.dumps(keys_dict).encode("utf-8"))
|
|
176
|
+
.build()
|
|
177
|
+
)
|
|
178
|
+
|
|
179
|
+
@classmethod
|
|
180
|
+
def from_proto(cls, proto: _value_pb2.ValueProto) -> PrivateKey:
|
|
181
|
+
"""Deserialize PrivateKey from wire format."""
|
|
182
|
+
reader = ValueProtoReader(proto)
|
|
183
|
+
if reader.version != cls.WIRE_VERSION:
|
|
184
|
+
raise ValueDecodeError(f"Unsupported PrivateKey version {reader.version}")
|
|
185
|
+
|
|
186
|
+
# Read metadata from runtime_attrs
|
|
187
|
+
scheme = reader.get_attr("scheme")
|
|
188
|
+
key_size = reader.get_attr("key_size")
|
|
189
|
+
max_value = reader.get_attr("max_value")
|
|
190
|
+
fxp_bits = reader.get_attr("fxp_bits")
|
|
191
|
+
modulus_str = reader.get_attr("modulus")
|
|
192
|
+
modulus = None if modulus_str == "" else int(modulus_str)
|
|
193
|
+
|
|
194
|
+
# JSON deserialize both key dicts
|
|
195
|
+
keys_dict = json.loads(reader.payload.decode("utf-8"))
|
|
196
|
+
sk_data = keys_dict["sk"]
|
|
197
|
+
pk_data = keys_dict["pk"]
|
|
198
|
+
|
|
199
|
+
return cls(
|
|
200
|
+
sk_data=sk_data,
|
|
201
|
+
pk_data=pk_data,
|
|
202
|
+
scheme=scheme,
|
|
203
|
+
key_size=key_size,
|
|
204
|
+
max_value=max_value,
|
|
205
|
+
fxp_bits=fxp_bits,
|
|
206
|
+
modulus=modulus,
|
|
207
|
+
)
|
|
208
|
+
|
|
103
209
|
def __repr__(self) -> str:
|
|
104
210
|
return f"PrivateKey(scheme={self.scheme}, key_size={self.key_size}, max_value={self.max_value}, fxp_bits={self.fxp_bits})"
|
|
105
211
|
|
|
106
212
|
|
|
107
|
-
|
|
108
|
-
|
|
213
|
+
@register_value
|
|
214
|
+
class CipherText(Value):
|
|
215
|
+
"""PHE CipherText Value type."""
|
|
216
|
+
|
|
217
|
+
KIND: ClassVar[str] = "mplang.phe.CipherText"
|
|
218
|
+
WIRE_VERSION: ClassVar[int] = 1
|
|
109
219
|
|
|
110
220
|
def __init__(
|
|
111
221
|
self,
|
|
@@ -142,6 +252,106 @@ class CipherText:
|
|
|
142
252
|
"""Maximum float value that can be encoded."""
|
|
143
253
|
return float(self.max_value / (2**self.fxp_bits))
|
|
144
254
|
|
|
255
|
+
def to_proto(self) -> _value_pb2.ValueProto:
|
|
256
|
+
"""Serialize CipherText to wire format.
|
|
257
|
+
|
|
258
|
+
WARNING: This serialization is tightly coupled to lightphe.Ciphertext
|
|
259
|
+
internal attributes (value, algorithm_name, keys). Any changes to these
|
|
260
|
+
attributes in future lightphe versions will break serialization.
|
|
261
|
+
|
|
262
|
+
TODO: Check if lightphe provides official serialization methods and
|
|
263
|
+
migrate to them if available. Consider adding version compatibility checks.
|
|
264
|
+
"""
|
|
265
|
+
# JSON serialize ciphertext components
|
|
266
|
+
# ct_data is a list of lightPHE Ciphertext objects
|
|
267
|
+
# Each Ciphertext has: value, algorithm_name, keys
|
|
268
|
+
# We need to serialize the list of ciphertexts
|
|
269
|
+
if not isinstance(self.ct_data, list):
|
|
270
|
+
raise ValueError(f"ct_data should be a list, got {type(self.ct_data)}")
|
|
271
|
+
|
|
272
|
+
ct_list = []
|
|
273
|
+
for ct in self.ct_data:
|
|
274
|
+
if not isinstance(ct, Ciphertext):
|
|
275
|
+
raise TypeError(
|
|
276
|
+
f"ct_data must contain lightphe.Ciphertext objects, got {type(ct).__name__}"
|
|
277
|
+
)
|
|
278
|
+
ct_list.append({
|
|
279
|
+
"value": ct.value,
|
|
280
|
+
"algorithm_name": ct.algorithm_name,
|
|
281
|
+
"keys": ct.keys,
|
|
282
|
+
})
|
|
283
|
+
|
|
284
|
+
# Combine ct_data and pk_data into single dict
|
|
285
|
+
payload_dict = {
|
|
286
|
+
"ct_list": ct_list,
|
|
287
|
+
"pk": self.pk_data if self.pk_data is not None else None,
|
|
288
|
+
}
|
|
289
|
+
|
|
290
|
+
return (
|
|
291
|
+
ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
|
|
292
|
+
.set_attr("semantic_dtype", str(self.semantic_dtype))
|
|
293
|
+
.set_attr("semantic_shape", list(self.semantic_shape))
|
|
294
|
+
.set_attr("scheme", self.scheme)
|
|
295
|
+
.set_attr("key_size", self.key_size)
|
|
296
|
+
.set_attr("max_value", self.max_value)
|
|
297
|
+
.set_attr("fxp_bits", self.fxp_bits)
|
|
298
|
+
.set_attr("modulus", str(self.modulus) if self.modulus is not None else "")
|
|
299
|
+
.set_payload(json.dumps(payload_dict).encode("utf-8"))
|
|
300
|
+
.build()
|
|
301
|
+
)
|
|
302
|
+
|
|
303
|
+
@classmethod
|
|
304
|
+
def from_proto(cls, proto: _value_pb2.ValueProto) -> CipherText:
|
|
305
|
+
"""Deserialize CipherText from wire format."""
|
|
306
|
+
reader = ValueProtoReader(proto)
|
|
307
|
+
if reader.version != cls.WIRE_VERSION:
|
|
308
|
+
raise ValueDecodeError(f"Unsupported CipherText version {reader.version}")
|
|
309
|
+
|
|
310
|
+
# Read metadata from runtime_attrs
|
|
311
|
+
semantic_dtype_str = reader.get_attr("semantic_dtype")
|
|
312
|
+
semantic_shape = reader.get_attr("semantic_shape")
|
|
313
|
+
scheme = reader.get_attr("scheme")
|
|
314
|
+
key_size = reader.get_attr("key_size")
|
|
315
|
+
max_value = reader.get_attr("max_value")
|
|
316
|
+
fxp_bits = reader.get_attr("fxp_bits")
|
|
317
|
+
modulus_str = reader.get_attr("modulus")
|
|
318
|
+
modulus = None if modulus_str == "" else int(modulus_str)
|
|
319
|
+
|
|
320
|
+
# JSON deserialize ciphertext and public key
|
|
321
|
+
payload_dict = json.loads(reader.payload.decode("utf-8"))
|
|
322
|
+
ct_list = payload_dict["ct_list"]
|
|
323
|
+
pk_data = payload_dict["pk"]
|
|
324
|
+
|
|
325
|
+
# Reconstruct ct_data: list of Ciphertext objects
|
|
326
|
+
ct_data = []
|
|
327
|
+
for ct_dict in ct_list:
|
|
328
|
+
if ct_dict["keys"] is None or ct_dict["algorithm_name"] is None:
|
|
329
|
+
raise ValueDecodeError(
|
|
330
|
+
"Invalid CipherText: missing keys or algorithm_name in serialized data"
|
|
331
|
+
)
|
|
332
|
+
ct_data.append(
|
|
333
|
+
Ciphertext(
|
|
334
|
+
algorithm_name=ct_dict["algorithm_name"],
|
|
335
|
+
keys=ct_dict["keys"],
|
|
336
|
+
value=ct_dict["value"],
|
|
337
|
+
)
|
|
338
|
+
)
|
|
339
|
+
|
|
340
|
+
# Parse dtype string back to DType
|
|
341
|
+
dtype = DType.from_any(semantic_dtype_str)
|
|
342
|
+
|
|
343
|
+
return cls(
|
|
344
|
+
ct_data=ct_data,
|
|
345
|
+
semantic_dtype=dtype,
|
|
346
|
+
semantic_shape=tuple(semantic_shape),
|
|
347
|
+
scheme=scheme,
|
|
348
|
+
key_size=key_size,
|
|
349
|
+
pk_data=pk_data,
|
|
350
|
+
max_value=max_value,
|
|
351
|
+
fxp_bits=fxp_bits,
|
|
352
|
+
modulus=modulus,
|
|
353
|
+
)
|
|
354
|
+
|
|
145
355
|
def __repr__(self) -> str:
|
|
146
356
|
return f"CipherText(dtype={self.semantic_dtype}, shape={self.semantic_shape}, scheme={self.scheme})"
|
|
147
357
|
|
|
@@ -257,33 +467,15 @@ def _range_decode_mixed(
|
|
|
257
467
|
return _range_decode_integer(encoded_value, max_value, modulus)
|
|
258
468
|
|
|
259
469
|
|
|
260
|
-
def _convert_to_numpy(obj: TensorLike) -> np.ndarray:
|
|
261
|
-
"""Convert a TensorLike object to numpy array."""
|
|
262
|
-
if isinstance(obj, np.ndarray):
|
|
263
|
-
return obj
|
|
264
|
-
|
|
265
|
-
# Try to use .numpy() method if available
|
|
266
|
-
if hasattr(obj, "numpy"):
|
|
267
|
-
numpy_method = getattr(obj, "numpy", None)
|
|
268
|
-
if callable(numpy_method):
|
|
269
|
-
try:
|
|
270
|
-
return np.asarray(numpy_method())
|
|
271
|
-
except Exception:
|
|
272
|
-
pass
|
|
273
|
-
|
|
274
|
-
return np.asarray(obj)
|
|
275
|
-
|
|
276
|
-
|
|
277
470
|
@kernel_def("phe.keygen")
|
|
278
471
|
def _phe_keygen(pfunc: PFunction) -> Any:
|
|
279
472
|
scheme = pfunc.attrs.get("scheme", "paillier")
|
|
280
473
|
# use small key_size to speed up tests
|
|
281
474
|
# in production use at least 2048 bits or 3072 bits for better security
|
|
282
475
|
key_size = pfunc.attrs.get("key_size", 2048)
|
|
283
|
-
max_value
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
fxp_bits = pfunc.attrs.get("fxp_bits", 12)
|
|
476
|
+
# Accept very large max_value; allow decimal string input, kept simple like other attrs
|
|
477
|
+
max_value = int(pfunc.attrs.get("max_value", 2**32))
|
|
478
|
+
fxp_bits = int(pfunc.attrs.get("fxp_bits", 12))
|
|
287
479
|
|
|
288
480
|
# Validate scheme
|
|
289
481
|
if scheme.lower() not in ["paillier"]:
|
|
@@ -334,14 +526,16 @@ def _phe_keygen(pfunc: PFunction) -> Any:
|
|
|
334
526
|
|
|
335
527
|
|
|
336
528
|
@kernel_def("phe.encrypt")
|
|
337
|
-
def _phe_encrypt(
|
|
529
|
+
def _phe_encrypt(
|
|
530
|
+
pfunc: PFunction, plaintext: TensorValue, public_key: PublicKey
|
|
531
|
+
) -> Any:
|
|
338
532
|
# Validate public_key type
|
|
339
533
|
if not isinstance(public_key, PublicKey):
|
|
340
534
|
raise ValueError("Second argument must be a PublicKey instance")
|
|
341
535
|
|
|
342
536
|
try:
|
|
343
537
|
# Convert plaintext to numpy to get semantic type info
|
|
344
|
-
plaintext_np =
|
|
538
|
+
plaintext_np = plaintext.to_numpy()
|
|
345
539
|
semantic_dtype = DType.from_numpy(plaintext_np.dtype)
|
|
346
540
|
semantic_shape = plaintext_np.shape
|
|
347
541
|
|
|
@@ -403,14 +597,14 @@ def _phe_encrypt(pfunc: PFunction, plaintext: Any, public_key: PublicKey) -> Any
|
|
|
403
597
|
|
|
404
598
|
|
|
405
599
|
@kernel_def("phe.mul")
|
|
406
|
-
def _phe_mul(pfunc: PFunction, ciphertext: CipherText, plaintext:
|
|
600
|
+
def _phe_mul(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorValue) -> Any:
|
|
407
601
|
# Validate that first argument is a CipherText
|
|
408
602
|
if not isinstance(ciphertext, CipherText):
|
|
409
603
|
raise ValueError("First argument must be a CipherText instance")
|
|
410
604
|
|
|
411
605
|
try:
|
|
412
606
|
# Convert plaintext to numpy
|
|
413
|
-
plaintext_np =
|
|
607
|
+
plaintext_np = plaintext.to_numpy()
|
|
414
608
|
|
|
415
609
|
# Check if plaintext is floating point type - multiplication not supported
|
|
416
610
|
if np.issubdtype(plaintext_np.dtype, np.floating):
|
|
@@ -443,7 +637,8 @@ def _phe_mul(pfunc: PFunction, ciphertext: CipherText, plaintext: Any) -> Any:
|
|
|
443
637
|
# Use numpy to create a properly broadcasted index mapping
|
|
444
638
|
# Create a dummy array with same shape as ciphertext, fill with indices
|
|
445
639
|
dummy_ct = (
|
|
446
|
-
np
|
|
640
|
+
np
|
|
641
|
+
.arange(np.prod(ciphertext.semantic_shape))
|
|
447
642
|
.reshape(ciphertext.semantic_shape)
|
|
448
643
|
.astype(np.int64)
|
|
449
644
|
)
|
|
@@ -511,7 +706,7 @@ def _phe_add(pfunc: PFunction, lhs: Any, rhs: Any) -> Any:
|
|
|
511
706
|
elif isinstance(rhs, CipherText):
|
|
512
707
|
return _phe_add_ct2pt(rhs, lhs)
|
|
513
708
|
else:
|
|
514
|
-
return
|
|
709
|
+
return TensorValue(lhs.to_numpy() + rhs.to_numpy())
|
|
515
710
|
except ValueError:
|
|
516
711
|
raise
|
|
517
712
|
except Exception as e: # pragma: no cover
|
|
@@ -550,7 +745,8 @@ def _phe_add_ct2ct(ct1: CipherText, ct2: CipherText) -> CipherText:
|
|
|
550
745
|
# Broadcast ct1 if needed
|
|
551
746
|
if ct1.semantic_shape != result_shape:
|
|
552
747
|
dummy_ct1 = (
|
|
553
|
-
np
|
|
748
|
+
np
|
|
749
|
+
.arange(np.prod(ct1.semantic_shape))
|
|
554
750
|
.reshape(ct1.semantic_shape)
|
|
555
751
|
.astype(np.int64)
|
|
556
752
|
)
|
|
@@ -563,7 +759,8 @@ def _phe_add_ct2ct(ct1: CipherText, ct2: CipherText) -> CipherText:
|
|
|
563
759
|
# Broadcast ct2 if needed
|
|
564
760
|
if ct2.semantic_shape != result_shape:
|
|
565
761
|
dummy_ct2 = (
|
|
566
|
-
np
|
|
762
|
+
np
|
|
763
|
+
.arange(np.prod(ct2.semantic_shape))
|
|
567
764
|
.reshape(ct2.semantic_shape)
|
|
568
765
|
.astype(np.int64)
|
|
569
766
|
)
|
|
@@ -593,9 +790,9 @@ def _phe_add_ct2ct(ct1: CipherText, ct2: CipherText) -> CipherText:
|
|
|
593
790
|
)
|
|
594
791
|
|
|
595
792
|
|
|
596
|
-
def _phe_add_ct2pt(ciphertext: CipherText, plaintext:
|
|
793
|
+
def _phe_add_ct2pt(ciphertext: CipherText, plaintext: TensorValue) -> CipherText:
|
|
597
794
|
# Convert plaintext to numpy
|
|
598
|
-
plaintext_np =
|
|
795
|
+
plaintext_np = plaintext.to_numpy()
|
|
599
796
|
plaintext_dtype = DType.from_numpy(plaintext_np.dtype)
|
|
600
797
|
|
|
601
798
|
# Check for mixed precision issue: floating point ciphertext + integer plaintext
|
|
@@ -636,7 +833,8 @@ def _phe_add_ct2pt(ciphertext: CipherText, plaintext: TensorLike) -> CipherText:
|
|
|
636
833
|
# Broadcast ciphertext if needed
|
|
637
834
|
if ciphertext.semantic_shape != result_shape:
|
|
638
835
|
dummy_ct = (
|
|
639
|
-
np
|
|
836
|
+
np
|
|
837
|
+
.arange(np.prod(ciphertext.semantic_shape))
|
|
640
838
|
.reshape(ciphertext.semantic_shape)
|
|
641
839
|
.astype(np.int64)
|
|
642
840
|
)
|
|
@@ -802,12 +1000,17 @@ def _phe_decrypt(
|
|
|
802
1000
|
# Convert to target dtype
|
|
803
1001
|
if target_dtype.kind in "iu": # integer types
|
|
804
1002
|
# Convert floats back to integers for integer semantic types
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
1003
|
+
# decoded_data are numeric (ints or floats); normalize to Python int
|
|
1004
|
+
ints = [round(v) if isinstance(v, float) else v for v in decoded_data]
|
|
1005
|
+
if np.issubdtype(target_dtype, np.unsignedinteger):
|
|
1006
|
+
# Reduce modulo 2^k for unsigned to preserve ring semantics
|
|
1007
|
+
width = np.iinfo(target_dtype).bits
|
|
1008
|
+
mod = 1 << width
|
|
1009
|
+
processed_data = [v % mod for v in ints]
|
|
1010
|
+
else:
|
|
1011
|
+
# Signed integers: clamp to dtype range
|
|
1012
|
+
info = np.iinfo(target_dtype)
|
|
1013
|
+
processed_data = [max(info.min, min(info.max, v)) for v in ints]
|
|
811
1014
|
else: # float types
|
|
812
1015
|
processed_data = decoded_data
|
|
813
1016
|
|
|
@@ -816,14 +1019,14 @@ def _phe_decrypt(
|
|
|
816
1019
|
ciphertext.semantic_shape
|
|
817
1020
|
)
|
|
818
1021
|
|
|
819
|
-
return [plaintext_np]
|
|
1022
|
+
return [TensorValue(plaintext_np)]
|
|
820
1023
|
|
|
821
1024
|
except Exception as e:
|
|
822
1025
|
raise RuntimeError(f"Failed to decrypt data: {e}") from e
|
|
823
1026
|
|
|
824
1027
|
|
|
825
1028
|
@kernel_def("phe.dot")
|
|
826
|
-
def _phe_dot(pfunc: PFunction, ciphertext: CipherText, plaintext:
|
|
1029
|
+
def _phe_dot(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorValue) -> Any:
|
|
827
1030
|
"""Execute homomorphic dot product with zero-value optimization.
|
|
828
1031
|
|
|
829
1032
|
Supports various dot product operations:
|
|
@@ -844,7 +1047,7 @@ def _phe_dot(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorLike) ->
|
|
|
844
1047
|
|
|
845
1048
|
try:
|
|
846
1049
|
# Convert plaintext to numpy
|
|
847
|
-
plaintext_np =
|
|
1050
|
+
plaintext_np = plaintext.to_numpy()
|
|
848
1051
|
|
|
849
1052
|
# Check if plaintext is floating point type - dot product not supported
|
|
850
1053
|
if np.issubdtype(plaintext_np.dtype, np.floating):
|
|
@@ -1109,7 +1312,7 @@ def _phe_dot(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorLike) ->
|
|
|
1109
1312
|
|
|
1110
1313
|
|
|
1111
1314
|
@kernel_def("phe.gather")
|
|
1112
|
-
def _phe_gather(pfunc: PFunction, ciphertext: CipherText, indices:
|
|
1315
|
+
def _phe_gather(pfunc: PFunction, ciphertext: CipherText, indices: TensorValue) -> Any:
|
|
1113
1316
|
"""Execute gather operation on CipherText.
|
|
1114
1317
|
|
|
1115
1318
|
Supports gathering from multidimensional CipherText using multidimensional indices.
|
|
@@ -1126,7 +1329,7 @@ def _phe_gather(pfunc: PFunction, ciphertext: CipherText, indices: Any) -> Any:
|
|
|
1126
1329
|
|
|
1127
1330
|
try:
|
|
1128
1331
|
# Convert indices to numpy
|
|
1129
|
-
indices_np =
|
|
1332
|
+
indices_np = indices.to_numpy()
|
|
1130
1333
|
|
|
1131
1334
|
if not np.issubdtype(indices_np.dtype, np.integer):
|
|
1132
1335
|
raise ValueError("Indices must be of integer type")
|
|
@@ -1224,7 +1427,10 @@ def _phe_gather(pfunc: PFunction, ciphertext: CipherText, indices: Any) -> Any:
|
|
|
1224
1427
|
|
|
1225
1428
|
@kernel_def("phe.scatter")
|
|
1226
1429
|
def _phe_scatter(
|
|
1227
|
-
pfunc: PFunction,
|
|
1430
|
+
pfunc: PFunction,
|
|
1431
|
+
ciphertext: CipherText,
|
|
1432
|
+
indices: TensorValue,
|
|
1433
|
+
updated: CipherText,
|
|
1228
1434
|
) -> Any:
|
|
1229
1435
|
"""Execute scatter operation on CipherText.
|
|
1230
1436
|
|
|
@@ -1252,7 +1458,7 @@ def _phe_scatter(
|
|
|
1252
1458
|
|
|
1253
1459
|
try:
|
|
1254
1460
|
# Convert indices to numpy
|
|
1255
|
-
indices_np =
|
|
1461
|
+
indices_np = indices.to_numpy()
|
|
1256
1462
|
|
|
1257
1463
|
if not np.issubdtype(indices_np.dtype, np.integer):
|
|
1258
1464
|
raise ValueError("Indices must be of integer type")
|