mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,723 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """PHE (Partially Homomorphic Encryption) dialect for the EDSL.
16
+
17
+ Design principles:
18
+ - Separate encoding from encryption for semantic clarity
19
+ - Element-level primitives operate on encoded integers
20
+ - Reuse `tensor.elementwise` to lift primitives across tensors
21
+ - Provide ergonomic wrappers for common workflows
22
+
23
+ Architecture:
24
+ Source Type (f64, i32, etc.)
25
+ ↓ encode(encoder)
26
+ Encoded Integer (i64)
27
+ ↓ encrypt(pk)
28
+ Ciphertext (CiphertextType)
29
+ ↓ homomorphic operations
30
+ Ciphertext (CiphertextType)
31
+ ↓ decrypt(sk)
32
+ Encoded Integer (i64)
33
+ ↓ decode(encoder)
34
+ Source Type (f64, i32, etc.)
35
+
36
+ Example:
37
+ ```python
38
+ from mplang.v2.dialects import tensor, phe
39
+ import mplang.v2.edsl.typing as elt
40
+ import numpy as np
41
+
42
+ # 1. Generate keys (cryptographic only)
43
+ pk, sk = phe.keygen()
44
+
45
+ # 2. Create encoder (encoding parameters)
46
+ encoder = phe.create_encoder(dtype=elt.f64, fxp_bits=16)
47
+
48
+ # 3. Encode data
49
+ x = tensor.constant(np.array([1.0, 2.0, 3.0]))
50
+ y = tensor.constant(np.array([4.0, 5.0, 6.0]))
51
+ x_enc = phe.encode(x, encoder) # f64 → i64
52
+ y_enc = phe.encode(y, encoder) # f64 → i64
53
+
54
+ # 4. Encrypt
55
+ ct_x = phe.encrypt(x_enc, pk) # i64 → CiphertextType
56
+ ct_y = phe.encrypt(y_enc, pk) # i64 → CiphertextType
57
+
58
+ # 5. Homomorphic operations
59
+ ct_sum = phe.add(ct_x, ct_y) # CiphertextType + CiphertextType
60
+
61
+ # 6. Decrypt and decode
62
+ sum_enc = phe.decrypt(ct_sum, sk) # CiphertextType → i64
63
+ result = phe.decode(sum_enc, encoder) # i64 → f64
64
+ ```
65
+
66
+ For convenience, auto wrappers combine encode+encrypt and decrypt+decode:
67
+ ```python
68
+ ct = phe.encrypt_auto(x, encoder, pk)
69
+ result = phe.decrypt_auto(ct, encoder, sk)
70
+ ```
71
+ """
72
+
73
+ from __future__ import annotations
74
+
75
+ from collections.abc import Callable
76
+ from typing import Any, NamedTuple
77
+
78
+ import mplang.v2.edsl as el
79
+ import mplang.v2.edsl.typing as elt
80
+ from mplang.v2.dialects import tensor
81
+
82
+ # ==============================================================================
83
+ # --- Type Definitions
84
+ # ==============================================================================
85
+
86
+
87
class KeyType(elt.BaseType):
    """Type of a PHE key: the scheme name plus a public/private flag.

    Two key types are equal exactly when both the scheme and the
    public/private flag agree; the hash is built from the same pair.
    """

    def __init__(self, scheme: str, is_public: bool):
        self.scheme = scheme
        self.is_public = is_public

    def __str__(self) -> str:
        # "PKey[...]" for public keys, "SKey[...]" for secret (private) keys.
        if self.is_public:
            prefix = "P"
        else:
            prefix = "S"
        return f"{prefix}Key[{self.scheme}]"

    def __eq__(self, other: object) -> bool:
        if isinstance(other, KeyType):
            return (self.scheme, self.is_public) == (other.scheme, other.is_public)
        return False

    def __hash__(self) -> int:
        identity = (self.scheme, self.is_public)
        return hash(identity)
105
+
106
+
107
class PlaintextType(elt.ScalarType):
    """Scalar type of an encoded integer that is ready for PHE encryption.

    Keeping encoded integers (by default 64-bit) distinct from ordinary
    integer types lets type checking reject encryption of raw integers and
    arithmetic that mixes encoded and raw values.
    """

    def __init__(self, bitwidth: int = 64):
        self.bitwidth = bitwidth

    def __str__(self) -> str:
        return f"PT[i{self.bitwidth}]"

    def __eq__(self, other: object) -> bool:
        # Only the bit width distinguishes two plaintext types.
        return isinstance(other, PlaintextType) and self.bitwidth == other.bitwidth

    def __hash__(self) -> int:
        key = ("PlaintextType", self.bitwidth)
        return hash(key)
128
+
129
+
130
class CiphertextType(elt.ScalarType, elt.EncryptedTrait):
    """Scalar type of one value encrypted under a PHE scheme.

    It derives from ``ScalarType``, so it can serve as a tensor element type.
    It also carries ``elt.EncryptedTrait``. Equality and hashing depend only
    on the scheme name.
    """

    def __init__(self, scheme: str):
        self._scheme = scheme

    @property
    def scheme(self) -> str:
        """Name of the PHE scheme this ciphertext type is tagged with."""
        return self._scheme

    def __str__(self) -> str:
        return f"CT[{self._scheme}]"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, CiphertextType) and self._scheme == other._scheme

    def __hash__(self) -> int:
        key = ("CiphertextType", self._scheme)
        return hash(key)
153
+
154
+
155
# Opaque type of PHE encoder objects. A single shared instance is used: the
# abstract-eval rules compare encoder types against it with ``==`` /
# ``!=``, so encoding parameters (dtype, fxp_bits, ...) are not part of the
# type itself.
EncoderType: elt.CustomType = elt.CustomType("Encoder")
157
+
158
+ # ==============================================================================
159
+ # --- Key Management Operations
160
+ # ==============================================================================
161
+
162
keygen_p = el.Primitive[tuple[el.Object, el.Object]]("phe.keygen")


@keygen_p.def_abstract_eval
def _keygen_ae(
    *,
    scheme: str = "paillier",
    key_size: int = 2048,
) -> tuple[KeyType, KeyType]:
    """Abstract evaluation of ``phe.keygen``: infer the key-pair types.

    Only ``scheme`` shows up in the inferred types. ``key_size`` is accepted
    but does not affect type inference.

    Args:
        scheme: PHE scheme name (default ``"paillier"``).
        key_size: Key size in bits (default 2048).

    Returns:
        ``(public_key_type, private_key_type)``, both tagged with ``scheme``.
    """
    public_type = KeyType(scheme, is_public=True)
    private_type = KeyType(scheme, is_public=False)
    return public_type, private_type
181
+
182
+
183
+ # ==============================================================================
184
+ # --- Encoder Operations
185
+ # ==============================================================================
186
+
187
create_encoder_p = el.Primitive[el.Object]("phe.create_encoder")
encode_p = el.Primitive[el.Object]("phe.encode")
decode_p = el.Primitive[el.Object]("phe.decode")


@create_encoder_p.def_abstract_eval
def _create_encoder_ae(
    *,
    dtype: elt.ScalarType,
    fxp_bits: int = 16,
    max_value: int | None = None,
) -> elt.CustomType:
    """Create PHE encoder for type conversion and fixed-point encoding.

    Args:
        dtype: Source data type (f32, f64, i32, i64, etc.)
        fxp_bits: Fixed-point fractional bits for float types (default: 16).
            Must be non-negative.
        max_value: Optional maximum value for range checking. Must be
            positive when given.

    Returns:
        EncoderType (the shared opaque encoder type; the parameters are not
        reflected in the type).

    Raises:
        TypeError: If ``dtype`` is not a ``ScalarType``.
        ValueError: If ``fxp_bits`` is negative or ``max_value`` is not
            positive.
    """
    if not isinstance(dtype, elt.ScalarType):
        raise TypeError(f"dtype must be ScalarType, got {type(dtype).__name__}")
    # A negative number of fractional bits has no fixed-point meaning; reject
    # it at trace time rather than deferring the failure to the backend.
    if fxp_bits < 0:
        raise ValueError(f"fxp_bits must be non-negative, got {fxp_bits}")
    # max_value bounds an absolute value, so only a positive bound is usable.
    if max_value is not None and max_value <= 0:
        raise ValueError(f"max_value must be positive, got {max_value}")
    return EncoderType
212
+
213
+
214
@encode_p.def_abstract_eval
def _encode_ae(value: elt.ScalarType, encoder: elt.CustomType) -> PlaintextType:
    """Encode scalar value to fixed-point integer representation.

    Args:
        value: Source value (f32, f64, i32, etc.); must be a raw scalar, i.e.
            neither encrypted nor already encoded.
        encoder: PHE encoder with encoding parameters

    Returns:
        Encoded integer (PlaintextType, 64-bit)

    Raises:
        TypeError: If encoder is not EncoderType, if value is not a
            ScalarType, or if value is already encrypted or encoded.
    """
    if encoder != EncoderType:
        raise TypeError(f"Expected Encoder, got {encoder}")
    if not isinstance(value, elt.ScalarType):
        raise TypeError(f"Can only encode ScalarType, got {value}")
    # CiphertextType and PlaintextType both subclass ScalarType, so the check
    # above lets them through. Encoding a ciphertext, or re-encoding an
    # already-encoded integer, is always a mistake; catch it at trace time.
    if isinstance(value, (elt.EncryptedTrait, PlaintextType)):
        raise TypeError(
            f"Can only encode raw (unencrypted, unencoded) scalars, got {value}"
        )
    # Return sufficient integer type for encoded values
    return PlaintextType(bitwidth=64)
234
+
235
+
236
@decode_p.def_abstract_eval
def _decode_ae(encoded: PlaintextType, encoder: elt.CustomType) -> elt.ScalarType:
    """Abstract evaluation of ``phe.decode``: encoded integer -> scalar.

    Args:
        encoded: Type of the encoded integer; must be ``PlaintextType``.
        encoder: Type of the encoder; must be ``EncoderType``.

    Returns:
        Always ``elt.f64`` for now. The encoder's dtype is passed as a
        primitive attribute and is not visible in its type, so the original
        source dtype is not recovered here.

    Raises:
        TypeError: If ``encoder`` is not ``EncoderType`` (checked first), or
            if ``encoded`` is not ``PlaintextType``.
    """
    if encoder != EncoderType:
        raise TypeError(f"Expected Encoder, got {encoder}")
    if isinstance(encoded, PlaintextType):
        # TODO: derive the result from the encoder's dtype once encoder
        # attributes are available during abstract evaluation.
        return elt.f64
    raise TypeError(f"Can only decode PlaintextType, got {encoded}")
257
+
258
+
259
+ # ==============================================================================
260
+ # --- Encryption/Decryption Operations (Integer only)
261
+ # ==============================================================================
262
+
263
encrypt_p = el.Primitive[el.Object]("phe.encrypt")
decrypt_p = el.Primitive[el.Object]("phe.decrypt")


@encrypt_p.def_abstract_eval
def _encrypt_ae(encoded: PlaintextType, pk: KeyType) -> CiphertextType:
    """Abstract evaluation of ``phe.encrypt``: encoded integer -> ciphertext.

    Args:
        encoded: Type of the value to encrypt; must be ``PlaintextType``.
        pk: Type of the key; must be a public ``KeyType``.

    Returns:
        A ``CiphertextType`` tagged with the public key's scheme.

    Raises:
        TypeError: If ``pk`` is not a public key (checked first), or if
            ``encoded`` is not ``PlaintextType``.
    """
    is_public_key = isinstance(pk, KeyType) and pk.is_public
    if not is_public_key:
        raise TypeError(f"Expected PublicKey, got {pk}")
    if isinstance(encoded, PlaintextType):
        return CiphertextType(pk.scheme)
    raise TypeError(f"Can only encrypt PlaintextType, got {encoded}")
286
+
287
+
288
@decrypt_p.def_abstract_eval
def _decrypt_ae(ct: CiphertextType, sk: KeyType) -> PlaintextType:
    """Decrypt ciphertext to encoded integer using PHE private key.

    Args:
        ct: Encrypted integer
        sk: PHE private key of the same scheme as ``ct``

    Returns:
        Decrypted encoded integer (PlaintextType, 64-bit)

    Raises:
        TypeError: If sk is not a PrivateKey, ct is not CiphertextType, or
            the ciphertext's scheme differs from the key's scheme.
    """
    if not isinstance(sk, KeyType) or sk.is_public:
        raise TypeError(f"Expected PrivateKey, got {sk}")
    if not isinstance(ct, CiphertextType):
        raise TypeError(f"Expected CiphertextType, got {ct}")
    # A ciphertext is tagged with the scheme of the key that encrypted it;
    # decrypting with a key of another scheme can never be correct, so reject
    # it during type inference.
    if ct.scheme != sk.scheme:
        raise TypeError(f"Scheme mismatch: cannot decrypt {ct} with {sk}")
    # We assume it decrypts to i64 (standard encoded integer)
    return PlaintextType(bitwidth=64)
308
+
309
+
310
+ # ==============================================================================
311
+ # --- Element-level Homomorphic Operations
312
+ # ==============================================================================
313
+
314
add_cc_p = el.Primitive[el.Object]("phe.add_cc")
add_cp_p = el.Primitive[el.Object]("phe.add_cp")
mul_cp_p = el.Primitive[el.Object]("phe.mul_cp")


@add_cc_p.def_abstract_eval
def _add_cc_ae(operand1: CiphertextType, operand2: CiphertextType) -> CiphertextType:
    """Abstract evaluation of ``phe.add_cc``: ciphertext + ciphertext.

    Both operands must be ciphertext types of the same scheme. The result
    is the first operand's type.

    Raises:
        TypeError: If either operand is not a ``CiphertextType``, or the two
            ciphertext types differ.
    """
    both_ciphertexts = isinstance(operand1, CiphertextType) and isinstance(
        operand2, CiphertextType
    )
    if not both_ciphertexts:
        raise TypeError(f"Expected CiphertextType operands, got {operand1}, {operand2}")
    if operand1 != operand2:
        raise TypeError(f"Scheme mismatch: {operand1} vs {operand2}")
    return operand1
329
+
330
+
331
@add_cp_p.def_abstract_eval
def _add_cp_ae(ciphertext: CiphertextType, plaintext: PlaintextType) -> CiphertextType:
    """Abstract evaluation of ``phe.add_cp``: ciphertext + encoded plaintext.

    Returns:
        The ciphertext operand's type, unchanged.

    Raises:
        TypeError: If ``ciphertext`` is not a ``CiphertextType`` (checked
            first), or if ``plaintext`` is not an encoded ``PlaintextType``.
    """
    if isinstance(ciphertext, CiphertextType):
        if isinstance(plaintext, PlaintextType):
            return ciphertext
        raise TypeError(
            f"Plaintext operand must be PlaintextType (encoded), got {plaintext}"
        )
    raise TypeError(f"Expected CiphertextType ciphertext, got {ciphertext}")
341
+
342
+
343
@mul_cp_p.def_abstract_eval
def _mul_cp_ae(ciphertext: CiphertextType, plaintext: PlaintextType) -> CiphertextType:
    """Abstract evaluation of ``phe.mul_cp``: ciphertext * encoded plaintext.

    Args:
        ciphertext: Type of the encrypted operand.
        plaintext: Type of the encoded integer scalar.

    Returns:
        The ciphertext operand's type, unchanged (the encrypted product).

    Raises:
        TypeError: If ``ciphertext`` is not a ``CiphertextType`` (checked
            first), or if ``plaintext`` is not an encoded ``PlaintextType``.
    """
    if not isinstance(ciphertext, CiphertextType):
        msg = f"Expected CiphertextType ciphertext, got {ciphertext}"
        raise TypeError(msg)
    if not isinstance(plaintext, PlaintextType):
        msg = f"Plaintext operand must be PlaintextType (encoded), got {plaintext}"
        raise TypeError(msg)
    return ciphertext
361
+
362
+
363
+ # ==============================================================================
364
+ # --- User-facing API
365
+ # ==============================================================================
366
+
367
+
368
def keygen(
    scheme: str = "paillier",
    key_size: int = 2048,
) -> tuple[el.Object, el.Object]:
    """Generate a PHE key pair.

    Only cryptographic parameters are chosen here. Encoding parameters
    (``fxp_bits``, ``max_value``) belong to :func:`create_encoder`.

    Args:
        scheme: PHE scheme name (default ``"paillier"``), passed through as a
            primitive attribute.
        key_size: Key size in bits (default 2048). Larger keys give more
            security but slower computation.

    Returns:
        ``(public_key, private_key)``.

    Example:
        >>> pk, sk = phe.keygen()
        >>> pk, sk = phe.keygen(key_size=4096)  # stronger, slower
    """
    attrs: dict[str, Any] = {"scheme": scheme, "key_size": key_size}
    return keygen_p.bind(**attrs)
393
+
394
+
395
def create_encoder(
    dtype: elt.ScalarType,
    fxp_bits: int = 16,
    max_value: int | None = None,
) -> el.Object:
    """Create a PHE encoder that maps values to and from encoded integers.

    Encoders are independent of keys. They only describe how source values
    become the integers that get encrypted, and how those integers are
    decoded back.

    Args:
        dtype: Source scalar type (e.g. ``elt.f64``, ``elt.i32``).
        fxp_bits: Fractional bits for fixed-point encoding of float types
            (default 16, i.e. a resolution of about 1/65536). More bits give
            more precision but a smaller representable range.
        max_value: Optional maximum absolute value for overflow checking.
            It is left out of the primitive's attributes when ``None``.

    Returns:
        An encoder object for :func:`encode` / :func:`decode`.

    Example:
        >>> enc_f64 = phe.create_encoder(dtype=elt.f64, fxp_bits=16)
        >>> enc_hp = phe.create_encoder(dtype=elt.f64, fxp_bits=32)
        >>> enc_i32 = phe.create_encoder(dtype=elt.i32)
    """
    # Only forward max_value when the caller supplied one.
    extra: dict[str, Any] = {} if max_value is None else {"max_value": max_value}
    return create_encoder_p.bind(dtype=dtype, fxp_bits=fxp_bits, **extra)
436
+
437
+
438
def _has_tensor_args(*objs: el.Object) -> bool:
    """Return True as soon as one of ``objs`` has a ``TensorType`` type."""
    for candidate in objs:
        if isinstance(candidate.type, elt.TensorType):
            return True
    return False
441
+
442
+
443
class OperandInfo(NamedTuple):
    """Dispatch summary for one operand of a user-facing PHE operation.

    Instances are produced by ``_inspect_operand``.
    """

    # True when the operand's type is a TensorType.
    is_tensor: bool
    # True when the operand (or its tensor element type) is a CiphertextType.
    is_encrypted: bool
    # Plaintext scalar / element type; None for encrypted operands.
    scalar_type: elt.BaseType | None
449
+
450
+
451
def _inspect_operand(obj: el.Object) -> OperandInfo:
    """Classify an operand as scalar/tensor and encrypted/plaintext.

    ``CiphertextType`` is tested before ``ScalarType`` on purpose.
    Ciphertext types subclass ``ScalarType`` and would otherwise be
    classified as plaintext.

    Raises:
        TypeError: If the operand (or its tensor element type) is not a
            scalar type.
    """
    ty = obj.type
    is_tensor = isinstance(ty, elt.TensorType)
    scalar = ty.element_type if is_tensor else ty

    if isinstance(scalar, CiphertextType):
        return OperandInfo(is_tensor, True, None)
    if isinstance(scalar, elt.ScalarType):
        return OperandInfo(is_tensor, False, scalar)
    if is_tensor:
        raise TypeError(
            f"PHE operations support Tensor[ScalarType] or Tensor[CiphertextType], got Tensor[{scalar}]"
        )
    raise TypeError(f"PHE operations expect Scalar or Tensor operands, got {ty}")
468
+
469
+
470
# Signature of a binary element-level binder such as ``add_cp_p.bind``.
BinaryFn = Callable[[el.Object, el.Object], el.Object]


def _apply_binary(fn: BinaryFn, lhs: el.Object, rhs: el.Object) -> el.Object:
    """Apply a scalar binary primitive, lifting it over tensors if needed.

    If either operand is a tensor, ``fn`` is mapped through
    ``tensor.elementwise``. Otherwise it is called on the operands directly.
    """
    lifted = _has_tensor_args(lhs, rhs)
    return tensor.elementwise(fn, lhs, rhs) if lifted else fn(lhs, rhs)
478
+
479
+
480
def _add_cp(ciphertext: el.Object, plaintext: el.Object) -> el.Object:
    """Homomorphic ciphertext + plaintext; the ciphertext must come first."""
    return _apply_binary(fn=add_cp_p.bind, lhs=ciphertext, rhs=plaintext)


def _mul_cp(ciphertext: el.Object, plaintext: el.Object) -> el.Object:
    """Homomorphic ciphertext * plaintext; the ciphertext must come first."""
    return _apply_binary(fn=mul_cp_p.bind, lhs=ciphertext, rhs=plaintext)
488
+
489
+
490
def encode(value: el.Object, encoder: el.Object) -> el.Object:
    """Encode a value into its fixed-point integer representation.

    Scalars are encoded directly. Tensors are encoded element by element via
    ``tensor.elementwise``.

    Args:
        value: Source value (scalar or tensor)
        encoder: PHE encoder (from create_encoder)

    Returns:
        Encoded integer (PlaintextType elements) with same structure as input

    Example:
        >>> x = tensor.constant(3.14)  # f64
        >>> encoder = phe.create_encoder(dtype=elt.f64, fxp_bits=16)
        >>> x_enc = phe.encode(x, encoder)  # 3.14 * 2**16 ≈ 205783
    """
    if _has_tensor_args(value):
        return tensor.elementwise(encode_p.bind, value, encoder)
    return encode_p.bind(value, encoder)
508
+
509
+
510
def decode(encoded: el.Object, encoder: el.Object) -> el.Object:
    """Decode fixed-point encoded integers back to plain values.

    Tensors are decoded element by element via ``tensor.elementwise``;
    scalars are decoded directly.

    Args:
        encoded: Encoded integer(s), e.g. the output of :func:`encode` or
            :func:`decrypt`.
        encoder: The same PHE encoder that was used for encoding.

    Returns:
        Decoded value with the same structure as ``encoded``.

    Example:
        >>> encoded = phe.encode(x, encoder)
        >>> result = phe.decode(encoded, encoder)
    """
    if not _has_tensor_args(encoded):
        return decode_p.bind(encoded, encoder)
    return tensor.elementwise(decode_p.bind, encoded, encoder)
527
+
528
+
529
def encrypt(encoded: el.Object, public_key: el.Object) -> el.Object:
    """Encrypt an already-encoded value with a PHE public key.

    The input must come from :func:`encode`; raw scalars are rejected during
    type inference. Tensors are encrypted element by element.

    Args:
        encoded: Encoded integer(s) produced by :func:`encode`.
        public_key: PHE public key.

    Returns:
        Ciphertext(s) with the same structure as ``encoded``.

    Example:
        >>> x_enc = phe.encode(x, encoder)
        >>> ct = phe.encrypt(x_enc, pk)  # PlaintextType -> CiphertextType
    """
    lift = _has_tensor_args(encoded)
    if lift:
        return tensor.elementwise(encrypt_p.bind, encoded, public_key)
    return encrypt_p.bind(encoded, public_key)
548
+
549
+
550
def decrypt(ciphertext: el.Object, private_key: el.Object) -> el.Object:
    """Decrypt ciphertext(s) back to encoded integers with a PHE private key.

    The result is still fixed-point encoded; pass it to :func:`decode` to
    recover values of the source type. Tensors are decrypted element by
    element.

    Args:
        ciphertext: Encrypted value(s).
        private_key: PHE private key.

    Returns:
        Encoded integer(s) (``PlaintextType``) with the same structure as
        ``ciphertext``.

    Example:
        >>> ct_sum = phe.add(ct1, ct2)
        >>> sum_enc = phe.decrypt(ct_sum, sk)  # CiphertextType -> PlaintextType
        >>> result = phe.decode(sum_enc, encoder)
    """
    args = (ciphertext, private_key)
    if _has_tensor_args(ciphertext):
        return tensor.elementwise(decrypt_p.bind, *args)
    return decrypt_p.bind(*args)
570
+
571
+
572
def encrypt_auto(
    value: el.Object, encoder: el.Object, public_key: el.Object
) -> el.Object:
    """Encode and encrypt in a single call.

    Equivalent to ``phe.encrypt(phe.encode(value, encoder), public_key)``.

    Args:
        value: Source value (scalar or tensor).
        encoder: PHE encoder.
        public_key: PHE public key.

    Returns:
        Encrypted value.

    Example:
        >>> ct = phe.encrypt_auto(x, encoder, pk)
    """
    return encrypt(encode(value, encoder), public_key)
592
+
593
+
594
def decrypt_auto(
    ciphertext: el.Object, encoder: el.Object, private_key: el.Object
) -> el.Object:
    """Decrypt and decode in a single call.

    Equivalent to ``phe.decode(phe.decrypt(ciphertext, private_key), encoder)``.

    Args:
        ciphertext: Encrypted value.
        encoder: The same PHE encoder used for encoding.
        private_key: PHE private key.

    Returns:
        Decrypted, decoded value.

    Example:
        >>> result = phe.decrypt_auto(ct, encoder, sk)
    """
    # Decryption yields encoded integers; decoding turns them back into values.
    encoded = decrypt(ciphertext, private_key)
    return decode(encoded, encoder)
614
+
615
+
616
def add(lhs: el.Object, rhs: el.Object) -> el.Object:
    """Homomorphic addition of scalars or tensors.

    Accepted operand combinations:
      * ciphertext + ciphertext -> ciphertext
      * ciphertext + plaintext  -> ciphertext
      * plaintext  + ciphertext -> ciphertext (operands are swapped so the
        ciphertext comes first)

    Plaintext operands must already be encoded (see :func:`encode`).

    Args:
        lhs: Left operand (encrypted or plaintext).
        rhs: Right operand (encrypted or plaintext).

    Returns:
        Encrypted sum.

    Raises:
        TypeError: If neither operand is encrypted, or if the operand types
            are incompatible.
    """
    lhs_encrypted = _inspect_operand(lhs).is_encrypted
    rhs_encrypted = _inspect_operand(rhs).is_encrypted

    if lhs_encrypted and rhs_encrypted:
        return _apply_binary(add_cc_p.bind, lhs, rhs)
    if lhs_encrypted:
        return _add_cp(lhs, rhs)
    if rhs_encrypted:
        return _add_cp(rhs, lhs)
    raise TypeError("phe.add requires at least one ciphertext operand")
648
+
649
+
650
def mul_plain(lhs: el.Object, rhs: el.Object) -> el.Object:
    """Homomorphic multiplication of a ciphertext by an encoded plaintext.

    Accepted operand combinations (in either order):
      * ciphertext * encoded plaintext -> ciphertext

    Notes:
      * Ciphertext * ciphertext is rejected; it would require FHE.
      * The plaintext must already be encoded (see :func:`encode`).
      * Multiplying two fixed-point encoded floats compounds the scale, so
        a truncation step may be needed to keep precision.

    Args:
        lhs: Left operand.
        rhs: Right operand.

    Returns:
        Encrypted product.

    Raises:
        TypeError: If both operands are encrypted, or neither is.

    Example:
        >>> ct = phe.encrypt(phe.encode(x, encoder), pk)
        >>> y_enc = phe.encode(y, encoder)
        >>> ct_prod = phe.mul_plain(ct, y_enc)
    """
    lhs_encrypted = _inspect_operand(lhs).is_encrypted
    rhs_encrypted = _inspect_operand(rhs).is_encrypted

    if lhs_encrypted and rhs_encrypted:
        raise TypeError(
            "phe.mul_plain supports ciphertext * plaintext only, not CT * CT. "
            "Ciphertext * ciphertext requires FHE."
        )
    if lhs_encrypted:
        return _mul_cp(lhs, rhs)
    if rhs_encrypted:
        return _mul_cp(rhs, lhs)
    raise TypeError("phe.mul_plain requires at least one ciphertext operand")
694
+
695
+
696
+ __all__ = [
697
+ "CiphertextType",
698
+ # Types
699
+ "EncoderType",
700
+ "KeyType",
701
+ "PlaintextType",
702
+ # User API
703
+ "add",
704
+ # Primitives
705
+ "add_cc_p",
706
+ "add_cp_p",
707
+ "create_encoder",
708
+ "create_encoder_p",
709
+ "decode",
710
+ "decode_p",
711
+ "decrypt",
712
+ "decrypt_auto",
713
+ "decrypt_p",
714
+ "encode",
715
+ "encode_p",
716
+ "encrypt",
717
+ "encrypt_auto",
718
+ "encrypt_p",
719
+ "keygen",
720
+ "keygen_p",
721
+ "mul_cp_p",
722
+ "mul_plain",
723
+ ]