mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191):
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -20,8 +20,9 @@ import warnings
20
20
  import numpy as np
21
21
  from numpy.typing import NDArray
22
22
 
23
- from mplang.core.pfunc import PFunction
24
- from mplang.kernels.base import cur_kctx, kernel_def
23
+ from mplang.v1.core import PFunction
24
+ from mplang.v1.kernels.base import cur_kctx, kernel_def
25
+ from mplang.v1.kernels.value import TensorValue
25
26
 
26
27
  __all__: list[str] = []
27
28
 
@@ -46,28 +47,26 @@ def _quote_from_pk(pk: np.ndarray) -> NDArray[np.uint8]:
46
47
 
47
48
 
48
49
  @kernel_def("mock_tee.quote_gen")
49
- def _tee_quote_gen(pfunc: PFunction, pk: object) -> NDArray[np.uint8]:
50
+ def _tee_quote_gen(pfunc: PFunction, pk: TensorValue) -> TensorValue:
50
51
  warnings.warn(
51
52
  "Insecure mock TEE kernel 'mock_tee.quote_gen' in use. NOT secure; for local testing only.",
52
53
  stacklevel=3,
53
54
  )
54
- pk = np.asarray(pk, dtype=np.uint8)
55
+ pk_arr = pk.to_numpy().astype(np.uint8, copy=False)
55
56
  # rng access ensures deterministic seeding per rank even if unused now
56
57
  _rng()
57
- return _quote_from_pk(pk)
58
+ quote = _quote_from_pk(pk_arr)
59
+ return TensorValue(np.array(quote, copy=True))
58
60
 
59
61
 
60
62
  @kernel_def("mock_tee.attest")
61
- def _tee_attest(pfunc: PFunction, quote: object) -> NDArray[np.uint8]:
63
+ def _tee_attest(pfunc: PFunction, quote: TensorValue) -> TensorValue:
62
64
  warnings.warn(
63
65
  "Insecure mock TEE kernel 'mock_tee.attest' in use. NOT secure; for local testing only.",
64
66
  stacklevel=3,
65
67
  )
66
- quote = np.asarray(quote, dtype=np.uint8)
67
- platform = pfunc.attrs.get("platform")
68
- if platform is None:
69
- raise ValueError("missing required 'platform' attribute in PFunction")
70
-
71
- if quote.size != 33:
68
+ quote_arr = quote.to_numpy().astype(np.uint8, copy=False)
69
+ if quote_arr.size != 33:
72
70
  raise ValueError("mock quote must be 33 bytes (1 header + 32 pk)")
73
- return quote[1:33].astype(np.uint8)
71
+ attest = quote_arr[1:33].astype(np.uint8, copy=True)
72
+ return TensorValue(attest)
@@ -14,15 +14,26 @@
14
14
 
15
15
  """PHE (Partially Homomorphic Encryption) backend implementation using lightPHE."""
16
16
 
17
- from typing import Any
17
+ from __future__ import annotations
18
+
19
+ import json
20
+ from typing import Any, ClassVar
18
21
 
19
22
  import numpy as np
20
23
  from lightphe import LightPHE
21
-
22
- from mplang.core.dtype import DType
23
- from mplang.core.mptype import TensorLike
24
- from mplang.core.pfunc import PFunction
25
- from mplang.kernels.base import kernel_def
24
+ from lightphe.models.Ciphertext import Ciphertext
25
+
26
+ from mplang.v1.core import DType, PFunction
27
+ from mplang.v1.kernels.base import kernel_def
28
+ from mplang.v1.kernels.value import (
29
+ TensorValue,
30
+ Value,
31
+ ValueDecodeError,
32
+ ValueProtoBuilder,
33
+ ValueProtoReader,
34
+ register_value,
35
+ )
36
+ from mplang.v1.protos.v1alpha1 import value_pb2 as _value_pb2
26
37
 
27
38
  # This controls the decimal precision used in lightPHE for float operations
28
39
  # we force it to 0 to only support integer operations
@@ -30,8 +41,12 @@ from mplang.kernels.base import kernel_def
30
41
  PRECISION = 0
31
42
 
32
43
 
33
- class PublicKey:
34
- """PHE Public Key that implements TensorLike protocol."""
44
+ @register_value
45
+ class PublicKey(Value):
46
+ """PHE Public Key Value type."""
47
+
48
+ KIND: ClassVar[str] = "mplang.phe.PublicKey"
49
+ WIRE_VERSION: ClassVar[int] = 1
35
50
 
36
51
  def __init__(
37
52
  self,
@@ -62,12 +77,56 @@ class PublicKey:
62
77
  """Maximum float value that can be encoded."""
63
78
  return float(self.max_value / (2**self.fxp_bits))
64
79
 
80
+ def to_proto(self) -> _value_pb2.ValueProto:
81
+ """Serialize PublicKey to wire format."""
82
+ return (
83
+ ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
84
+ .set_attr("scheme", self.scheme)
85
+ .set_attr("key_size", self.key_size)
86
+ .set_attr("max_value", self.max_value)
87
+ .set_attr("fxp_bits", self.fxp_bits)
88
+ .set_attr("modulus", str(self.modulus) if self.modulus is not None else "")
89
+ .set_payload(json.dumps(self.key_data).encode("utf-8"))
90
+ .build()
91
+ )
92
+
93
+ @classmethod
94
+ def from_proto(cls, proto: _value_pb2.ValueProto) -> PublicKey:
95
+ """Deserialize PublicKey from wire format."""
96
+ reader = ValueProtoReader(proto)
97
+ if reader.version != cls.WIRE_VERSION:
98
+ raise ValueDecodeError(f"Unsupported PublicKey version {reader.version}")
99
+
100
+ # Read metadata from runtime_attrs
101
+ scheme = reader.get_attr("scheme")
102
+ key_size = reader.get_attr("key_size")
103
+ max_value = reader.get_attr("max_value")
104
+ fxp_bits = reader.get_attr("fxp_bits")
105
+ modulus_str = reader.get_attr("modulus")
106
+ modulus = None if modulus_str == "" else int(modulus_str)
107
+
108
+ # JSON deserialize the public key dict
109
+ key_data = json.loads(reader.payload.decode("utf-8"))
110
+
111
+ return cls(
112
+ key_data=key_data,
113
+ scheme=scheme,
114
+ key_size=key_size,
115
+ max_value=max_value,
116
+ fxp_bits=fxp_bits,
117
+ modulus=modulus,
118
+ )
119
+
65
120
  def __repr__(self) -> str:
66
121
  return f"PublicKey(scheme={self.scheme}, key_size={self.key_size}, max_value={self.max_value}, fxp_bits={self.fxp_bits})"
67
122
 
68
123
 
69
- class PrivateKey:
70
- """PHE Private Key that implements TensorLike protocol."""
124
+ @register_value
125
+ class PrivateKey(Value):
126
+ """PHE Private Key Value type."""
127
+
128
+ KIND: ClassVar[str] = "mplang.phe.PrivateKey"
129
+ WIRE_VERSION: ClassVar[int] = 1
71
130
 
72
131
  def __init__(
73
132
  self,
@@ -100,12 +159,63 @@ class PrivateKey:
100
159
  """Maximum float value that can be encoded."""
101
160
  return float(self.max_value / (2**self.fxp_bits))
102
161
 
162
+ def to_proto(self) -> _value_pb2.ValueProto:
163
+ """Serialize PrivateKey to wire format."""
164
+ # JSON serialize both key dicts (contain int values)
165
+ # Store both keys in a single dict to avoid needing length metadata
166
+ keys_dict = {"sk": self.sk_data, "pk": self.pk_data}
167
+
168
+ return (
169
+ ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
170
+ .set_attr("scheme", self.scheme)
171
+ .set_attr("key_size", self.key_size)
172
+ .set_attr("max_value", self.max_value)
173
+ .set_attr("fxp_bits", self.fxp_bits)
174
+ .set_attr("modulus", str(self.modulus) if self.modulus is not None else "")
175
+ .set_payload(json.dumps(keys_dict).encode("utf-8"))
176
+ .build()
177
+ )
178
+
179
+ @classmethod
180
+ def from_proto(cls, proto: _value_pb2.ValueProto) -> PrivateKey:
181
+ """Deserialize PrivateKey from wire format."""
182
+ reader = ValueProtoReader(proto)
183
+ if reader.version != cls.WIRE_VERSION:
184
+ raise ValueDecodeError(f"Unsupported PrivateKey version {reader.version}")
185
+
186
+ # Read metadata from runtime_attrs
187
+ scheme = reader.get_attr("scheme")
188
+ key_size = reader.get_attr("key_size")
189
+ max_value = reader.get_attr("max_value")
190
+ fxp_bits = reader.get_attr("fxp_bits")
191
+ modulus_str = reader.get_attr("modulus")
192
+ modulus = None if modulus_str == "" else int(modulus_str)
193
+
194
+ # JSON deserialize both key dicts
195
+ keys_dict = json.loads(reader.payload.decode("utf-8"))
196
+ sk_data = keys_dict["sk"]
197
+ pk_data = keys_dict["pk"]
198
+
199
+ return cls(
200
+ sk_data=sk_data,
201
+ pk_data=pk_data,
202
+ scheme=scheme,
203
+ key_size=key_size,
204
+ max_value=max_value,
205
+ fxp_bits=fxp_bits,
206
+ modulus=modulus,
207
+ )
208
+
103
209
  def __repr__(self) -> str:
104
210
  return f"PrivateKey(scheme={self.scheme}, key_size={self.key_size}, max_value={self.max_value}, fxp_bits={self.fxp_bits})"
105
211
 
106
212
 
107
- class CipherText:
108
- """PHE CipherText that implements TensorLike protocol."""
213
+ @register_value
214
+ class CipherText(Value):
215
+ """PHE CipherText Value type."""
216
+
217
+ KIND: ClassVar[str] = "mplang.phe.CipherText"
218
+ WIRE_VERSION: ClassVar[int] = 1
109
219
 
110
220
  def __init__(
111
221
  self,
@@ -142,6 +252,106 @@ class CipherText:
142
252
  """Maximum float value that can be encoded."""
143
253
  return float(self.max_value / (2**self.fxp_bits))
144
254
 
255
+ def to_proto(self) -> _value_pb2.ValueProto:
256
+ """Serialize CipherText to wire format.
257
+
258
+ WARNING: This serialization is tightly coupled to lightphe.Ciphertext
259
+ internal attributes (value, algorithm_name, keys). Any changes to these
260
+ attributes in future lightphe versions will break serialization.
261
+
262
+ TODO: Check if lightphe provides official serialization methods and
263
+ migrate to them if available. Consider adding version compatibility checks.
264
+ """
265
+ # JSON serialize ciphertext components
266
+ # ct_data is a list of lightPHE Ciphertext objects
267
+ # Each Ciphertext has: value, algorithm_name, keys
268
+ # We need to serialize the list of ciphertexts
269
+ if not isinstance(self.ct_data, list):
270
+ raise ValueError(f"ct_data should be a list, got {type(self.ct_data)}")
271
+
272
+ ct_list = []
273
+ for ct in self.ct_data:
274
+ if not isinstance(ct, Ciphertext):
275
+ raise TypeError(
276
+ f"ct_data must contain lightphe.Ciphertext objects, got {type(ct).__name__}"
277
+ )
278
+ ct_list.append({
279
+ "value": ct.value,
280
+ "algorithm_name": ct.algorithm_name,
281
+ "keys": ct.keys,
282
+ })
283
+
284
+ # Combine ct_data and pk_data into single dict
285
+ payload_dict = {
286
+ "ct_list": ct_list,
287
+ "pk": self.pk_data if self.pk_data is not None else None,
288
+ }
289
+
290
+ return (
291
+ ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
292
+ .set_attr("semantic_dtype", str(self.semantic_dtype))
293
+ .set_attr("semantic_shape", list(self.semantic_shape))
294
+ .set_attr("scheme", self.scheme)
295
+ .set_attr("key_size", self.key_size)
296
+ .set_attr("max_value", self.max_value)
297
+ .set_attr("fxp_bits", self.fxp_bits)
298
+ .set_attr("modulus", str(self.modulus) if self.modulus is not None else "")
299
+ .set_payload(json.dumps(payload_dict).encode("utf-8"))
300
+ .build()
301
+ )
302
+
303
+ @classmethod
304
+ def from_proto(cls, proto: _value_pb2.ValueProto) -> CipherText:
305
+ """Deserialize CipherText from wire format."""
306
+ reader = ValueProtoReader(proto)
307
+ if reader.version != cls.WIRE_VERSION:
308
+ raise ValueDecodeError(f"Unsupported CipherText version {reader.version}")
309
+
310
+ # Read metadata from runtime_attrs
311
+ semantic_dtype_str = reader.get_attr("semantic_dtype")
312
+ semantic_shape = reader.get_attr("semantic_shape")
313
+ scheme = reader.get_attr("scheme")
314
+ key_size = reader.get_attr("key_size")
315
+ max_value = reader.get_attr("max_value")
316
+ fxp_bits = reader.get_attr("fxp_bits")
317
+ modulus_str = reader.get_attr("modulus")
318
+ modulus = None if modulus_str == "" else int(modulus_str)
319
+
320
+ # JSON deserialize ciphertext and public key
321
+ payload_dict = json.loads(reader.payload.decode("utf-8"))
322
+ ct_list = payload_dict["ct_list"]
323
+ pk_data = payload_dict["pk"]
324
+
325
+ # Reconstruct ct_data: list of Ciphertext objects
326
+ ct_data = []
327
+ for ct_dict in ct_list:
328
+ if ct_dict["keys"] is None or ct_dict["algorithm_name"] is None:
329
+ raise ValueDecodeError(
330
+ "Invalid CipherText: missing keys or algorithm_name in serialized data"
331
+ )
332
+ ct_data.append(
333
+ Ciphertext(
334
+ algorithm_name=ct_dict["algorithm_name"],
335
+ keys=ct_dict["keys"],
336
+ value=ct_dict["value"],
337
+ )
338
+ )
339
+
340
+ # Parse dtype string back to DType
341
+ dtype = DType.from_any(semantic_dtype_str)
342
+
343
+ return cls(
344
+ ct_data=ct_data,
345
+ semantic_dtype=dtype,
346
+ semantic_shape=tuple(semantic_shape),
347
+ scheme=scheme,
348
+ key_size=key_size,
349
+ pk_data=pk_data,
350
+ max_value=max_value,
351
+ fxp_bits=fxp_bits,
352
+ modulus=modulus,
353
+ )
354
+
145
355
  def __repr__(self) -> str:
146
356
  return f"CipherText(dtype={self.semantic_dtype}, shape={self.semantic_shape}, scheme={self.scheme})"
147
357
 
@@ -257,33 +467,15 @@ def _range_decode_mixed(
257
467
  return _range_decode_integer(encoded_value, max_value, modulus)
258
468
 
259
469
 
260
- def _convert_to_numpy(obj: TensorLike) -> np.ndarray:
261
- """Convert a TensorLike object to numpy array."""
262
- if isinstance(obj, np.ndarray):
263
- return obj
264
-
265
- # Try to use .numpy() method if available
266
- if hasattr(obj, "numpy"):
267
- numpy_method = getattr(obj, "numpy", None)
268
- if callable(numpy_method):
269
- try:
270
- return np.asarray(numpy_method())
271
- except Exception:
272
- pass
273
-
274
- return np.asarray(obj)
275
-
276
-
277
470
  @kernel_def("phe.keygen")
278
471
  def _phe_keygen(pfunc: PFunction) -> Any:
279
472
  scheme = pfunc.attrs.get("scheme", "paillier")
280
473
  # use small key_size to speed up tests
281
474
  # in production use at least 2048 bits or 3072 bits for better security
282
475
  key_size = pfunc.attrs.get("key_size", 2048)
283
- max_value = pfunc.attrs.get(
284
- "max_value", 2**32
285
- ) # Use larger range to avoid overflow
286
- fxp_bits = pfunc.attrs.get("fxp_bits", 12)
476
+ # Accept very large max_value; allow decimal string input, kept simple like other attrs
477
+ max_value = int(pfunc.attrs.get("max_value", 2**32))
478
+ fxp_bits = int(pfunc.attrs.get("fxp_bits", 12))
287
479
 
288
480
  # Validate scheme
289
481
  if scheme.lower() not in ["paillier"]:
@@ -334,14 +526,16 @@ def _phe_keygen(pfunc: PFunction) -> Any:
334
526
 
335
527
 
336
528
  @kernel_def("phe.encrypt")
337
- def _phe_encrypt(pfunc: PFunction, plaintext: Any, public_key: PublicKey) -> Any:
529
+ def _phe_encrypt(
530
+ pfunc: PFunction, plaintext: TensorValue, public_key: PublicKey
531
+ ) -> Any:
338
532
  # Validate public_key type
339
533
  if not isinstance(public_key, PublicKey):
340
534
  raise ValueError("Second argument must be a PublicKey instance")
341
535
 
342
536
  try:
343
537
  # Convert plaintext to numpy to get semantic type info
344
- plaintext_np = _convert_to_numpy(plaintext)
538
+ plaintext_np = plaintext.to_numpy()
345
539
  semantic_dtype = DType.from_numpy(plaintext_np.dtype)
346
540
  semantic_shape = plaintext_np.shape
347
541
 
@@ -403,14 +597,14 @@ def _phe_encrypt(pfunc: PFunction, plaintext: Any, public_key: PublicKey) -> Any
403
597
 
404
598
 
405
599
  @kernel_def("phe.mul")
406
- def _phe_mul(pfunc: PFunction, ciphertext: CipherText, plaintext: Any) -> Any:
600
+ def _phe_mul(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorValue) -> Any:
407
601
  # Validate that first argument is a CipherText
408
602
  if not isinstance(ciphertext, CipherText):
409
603
  raise ValueError("First argument must be a CipherText instance")
410
604
 
411
605
  try:
412
606
  # Convert plaintext to numpy
413
- plaintext_np = _convert_to_numpy(plaintext)
607
+ plaintext_np = plaintext.to_numpy()
414
608
 
415
609
  # Check if plaintext is floating point type - multiplication not supported
416
610
  if np.issubdtype(plaintext_np.dtype, np.floating):
@@ -443,7 +637,8 @@ def _phe_mul(pfunc: PFunction, ciphertext: CipherText, plaintext: Any) -> Any:
443
637
  # Use numpy to create a properly broadcasted index mapping
444
638
  # Create a dummy array with same shape as ciphertext, fill with indices
445
639
  dummy_ct = (
446
- np.arange(np.prod(ciphertext.semantic_shape))
640
+ np
641
+ .arange(np.prod(ciphertext.semantic_shape))
447
642
  .reshape(ciphertext.semantic_shape)
448
643
  .astype(np.int64)
449
644
  )
@@ -511,7 +706,7 @@ def _phe_add(pfunc: PFunction, lhs: Any, rhs: Any) -> Any:
511
706
  elif isinstance(rhs, CipherText):
512
707
  return _phe_add_ct2pt(rhs, lhs)
513
708
  else:
514
- return _convert_to_numpy(lhs) + _convert_to_numpy(rhs)
709
+ return TensorValue(lhs.to_numpy() + rhs.to_numpy())
515
710
  except ValueError:
516
711
  raise
517
712
  except Exception as e: # pragma: no cover
@@ -550,7 +745,8 @@ def _phe_add_ct2ct(ct1: CipherText, ct2: CipherText) -> CipherText:
550
745
  # Broadcast ct1 if needed
551
746
  if ct1.semantic_shape != result_shape:
552
747
  dummy_ct1 = (
553
- np.arange(np.prod(ct1.semantic_shape))
748
+ np
749
+ .arange(np.prod(ct1.semantic_shape))
554
750
  .reshape(ct1.semantic_shape)
555
751
  .astype(np.int64)
556
752
  )
@@ -563,7 +759,8 @@ def _phe_add_ct2ct(ct1: CipherText, ct2: CipherText) -> CipherText:
563
759
  # Broadcast ct2 if needed
564
760
  if ct2.semantic_shape != result_shape:
565
761
  dummy_ct2 = (
566
- np.arange(np.prod(ct2.semantic_shape))
762
+ np
763
+ .arange(np.prod(ct2.semantic_shape))
567
764
  .reshape(ct2.semantic_shape)
568
765
  .astype(np.int64)
569
766
  )
@@ -593,9 +790,9 @@ def _phe_add_ct2ct(ct1: CipherText, ct2: CipherText) -> CipherText:
593
790
  )
594
791
 
595
792
 
596
- def _phe_add_ct2pt(ciphertext: CipherText, plaintext: TensorLike) -> CipherText:
793
+ def _phe_add_ct2pt(ciphertext: CipherText, plaintext: TensorValue) -> CipherText:
597
794
  # Convert plaintext to numpy
598
- plaintext_np = _convert_to_numpy(plaintext)
795
+ plaintext_np = plaintext.to_numpy()
599
796
  plaintext_dtype = DType.from_numpy(plaintext_np.dtype)
600
797
 
601
798
  # Check for mixed precision issue: floating point ciphertext + integer plaintext
@@ -636,7 +833,8 @@ def _phe_add_ct2pt(ciphertext: CipherText, plaintext: TensorLike) -> CipherText:
636
833
  # Broadcast ciphertext if needed
637
834
  if ciphertext.semantic_shape != result_shape:
638
835
  dummy_ct = (
639
- np.arange(np.prod(ciphertext.semantic_shape))
836
+ np
837
+ .arange(np.prod(ciphertext.semantic_shape))
640
838
  .reshape(ciphertext.semantic_shape)
641
839
  .astype(np.int64)
642
840
  )
@@ -802,12 +1000,17 @@ def _phe_decrypt(
802
1000
  # Convert to target dtype
803
1001
  if target_dtype.kind in "iu": # integer types
804
1002
  # Convert floats back to integers for integer semantic types
805
- processed_data = [round(val) for val in decoded_data]
806
- # Handle overflow for smaller integer types
807
- info = np.iinfo(target_dtype)
808
- processed_data = [
809
- max(info.min, min(info.max, val)) for val in processed_data
810
- ]
1003
+ # decoded_data are numeric (ints or floats); normalize to Python int
1004
+ ints = [round(v) if isinstance(v, float) else v for v in decoded_data]
1005
+ if np.issubdtype(target_dtype, np.unsignedinteger):
1006
+ # Reduce modulo 2^k for unsigned to preserve ring semantics
1007
+ width = np.iinfo(target_dtype).bits
1008
+ mod = 1 << width
1009
+ processed_data = [v % mod for v in ints]
1010
+ else:
1011
+ # Signed integers: clamp to dtype range
1012
+ info = np.iinfo(target_dtype)
1013
+ processed_data = [max(info.min, min(info.max, v)) for v in ints]
811
1014
  else: # float types
812
1015
  processed_data = decoded_data
813
1016
 
@@ -816,14 +1019,14 @@ def _phe_decrypt(
816
1019
  ciphertext.semantic_shape
817
1020
  )
818
1021
 
819
- return [plaintext_np]
1022
+ return [TensorValue(plaintext_np)]
820
1023
 
821
1024
  except Exception as e:
822
1025
  raise RuntimeError(f"Failed to decrypt data: {e}") from e
823
1026
 
824
1027
 
825
1028
  @kernel_def("phe.dot")
826
- def _phe_dot(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorLike) -> Any:
1029
+ def _phe_dot(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorValue) -> Any:
827
1030
  """Execute homomorphic dot product with zero-value optimization.
828
1031
 
829
1032
  Supports various dot product operations:
@@ -844,7 +1047,7 @@ def _phe_dot(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorLike) ->
844
1047
 
845
1048
  try:
846
1049
  # Convert plaintext to numpy
847
- plaintext_np = _convert_to_numpy(plaintext)
1050
+ plaintext_np = plaintext.to_numpy()
848
1051
 
849
1052
  # Check if plaintext is floating point type - dot product not supported
850
1053
  if np.issubdtype(plaintext_np.dtype, np.floating):
@@ -1109,7 +1312,7 @@ def _phe_dot(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorLike) ->
1109
1312
 
1110
1313
 
1111
1314
  @kernel_def("phe.gather")
1112
- def _phe_gather(pfunc: PFunction, ciphertext: CipherText, indices: Any) -> Any:
1315
+ def _phe_gather(pfunc: PFunction, ciphertext: CipherText, indices: TensorValue) -> Any:
1113
1316
  """Execute gather operation on CipherText.
1114
1317
 
1115
1318
  Supports gathering from multidimensional CipherText using multidimensional indices.
@@ -1126,7 +1329,7 @@ def _phe_gather(pfunc: PFunction, ciphertext: CipherText, indices: Any) -> Any:
1126
1329
 
1127
1330
  try:
1128
1331
  # Convert indices to numpy
1129
- indices_np = _convert_to_numpy(indices)
1332
+ indices_np = indices.to_numpy()
1130
1333
 
1131
1334
  if not np.issubdtype(indices_np.dtype, np.integer):
1132
1335
  raise ValueError("Indices must be of integer type")
@@ -1224,7 +1427,10 @@ def _phe_gather(pfunc: PFunction, ciphertext: CipherText, indices: Any) -> Any:
1224
1427
 
1225
1428
  @kernel_def("phe.scatter")
1226
1429
  def _phe_scatter(
1227
- pfunc: PFunction, ciphertext: CipherText, indices: TensorLike, updated: CipherText
1430
+ pfunc: PFunction,
1431
+ ciphertext: CipherText,
1432
+ indices: TensorValue,
1433
+ updated: CipherText,
1228
1434
  ) -> Any:
1229
1435
  """Execute scatter operation on CipherText.
1230
1436
 
@@ -1252,7 +1458,7 @@ def _phe_scatter(
1252
1458
 
1253
1459
  try:
1254
1460
  # Convert indices to numpy
1255
- indices_np = _convert_to_numpy(indices)
1461
+ indices_np = indices.to_numpy()
1256
1462
 
1257
1463
  if not np.issubdtype(indices_np.dtype, np.integer):
1258
1464
  raise ValueError("Indices must be of integer type")