mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,665 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """BFV (Brakerski-Fan-Vercauteren) dialect for the EDSL.
16
+
17
+ BFV is a Fully Homomorphic Encryption (FHE) scheme that supports exact arithmetic
18
+ on integers. A key feature of BFV in this dialect is its SIMD (Single Instruction,
19
+ Multiple Data) capability, where a single ciphertext encrypts a vector of integers
20
+ (packed into "slots").
21
+
22
+ Design principles:
23
+ - **SIMD-first**: The fundamental unit of data is a packed vector (Plaintext/Ciphertext),
24
+ not a scalar.
25
+ - **Explicit Management**: Relinearization and rotation are explicit operations to
26
+ give users control over noise and performance.
27
+ - **Type Safety**: Distinguishes between Plaintext (encoded polynomial) and
28
+ Ciphertext (encrypted polynomial).
29
+
30
+ Type System Rationale:
31
+ The BFV dialect models data as `Encrypted[Vector[T]]`, where:
32
+ - `Vector[T]`: Represents the logical layout of data in SIMD slots.
33
+ - `T`: Must be an IntegerType (e.g., i64, u32). BFV does not support floating point.
34
+ If you need to encrypt floats, you must quantize them to integers first, or use CKKS.
35
+ - `Encrypted[...]`: Represents the cryptographic wrapper.
36
+
37
+ Why `Vector[T]` and not `Vector[BigInt]`?
38
+ - **Optimization**: Knowing the exact integer width (e.g., i32 vs i64) allows the
39
+ compiler to choose optimal encryption parameters (Plaintext Modulus `t`).
40
+ - **Semantics**: Preserves signed/unsigned semantics and bitwidth constraints.
41
+
42
+ Architecture:
43
+ Tensor[Integer, (N,)] (1D Vector)
44
+ ↓ encode(encoder)
45
+ Plaintext (Packed Polynomial) -> Wraps Vector[Integer, N]
46
+ ↓ encrypt(pk)
47
+ Ciphertext (Encrypted Polynomial) -> Wraps Vector[Integer, N]
48
+ ↓ add/mul (SIMD operations)
49
+ Ciphertext
50
+ ↓ decrypt(sk)
51
+ Plaintext
52
+ ↓ decode(encoder)
53
+ Tensor[Integer, (N,)]
54
+
55
+ Example:
56
+ ```python
57
+ from mplang.v2.dialects import tensor, bfv
58
+ import mplang.v2.edsl.typing as elt
59
+ import numpy as np
60
+
61
+ # 1. Setup
62
+ # poly_modulus_degree=4096 means 4096 slots
63
+ pk, sk = bfv.keygen(poly_modulus_degree=4096)
64
+ relin_keys = bfv.make_relin_keys(sk)
65
+ encoder = bfv.create_encoder(poly_modulus_degree=4096)
66
+
67
+ # 2. Data (Vectors)
68
+ v1 = tensor.constant(np.array([1, 2, 3, 4], dtype=np.int64))
69
+ v2 = tensor.constant(np.array([10, 20, 30, 40], dtype=np.int64))
70
+
71
+ # 3. Encode & Encrypt (SIMD Packing)
72
+ pt1 = bfv.encode(v1, encoder)
73
+ ct1 = bfv.encrypt(pt1, pk)
74
+
75
+ pt2 = bfv.encode(v2, encoder)
76
+ ct2 = bfv.encrypt(pt2, pk)
77
+
78
+ # 4. Computation
79
+ # Element-wise multiplication of the underlying vectors
80
+ ct_prod = bfv.mul(ct1, ct2)
81
+ # Relinearize to reduce ciphertext size after multiplication
82
+ ct_prod = bfv.relinearize(ct_prod, relin_keys)
83
+
84
+ # 5. Decrypt
85
+ pt_res = bfv.decrypt(ct_prod, sk)
86
+ res = bfv.decode(pt_res, encoder) # Returns Tensor
87
+ ```
88
+ """
89
+
90
+ from __future__ import annotations
91
+
92
+ from typing import Any, ClassVar, Literal, cast
93
+
94
+ import mplang.v2.edsl as el
95
+ import mplang.v2.edsl.typing as elt
96
+ from mplang.v2.edsl import serde
97
+
98
# ==============================================================================
# --- Type Definitions
# ==============================================================================

# Closed set of key roles carried in KeyType.kind. "Public"/"Private" are
# produced by keygen; "Relin" and "Galois" are derived from the private key by
# make_relin_keys / make_galois_keys.
KeyKind = Literal["Public", "Private", "Relin", "Galois"]
103
+
104
+
105
@serde.register_class
class KeyType(elt.BaseType):
    """Abstract type of a BFV key (public, private, relinearization or Galois).

    Two key types compare equal when both their ``kind`` and their
    ``poly_modulus_degree`` match; ``scheme`` is fixed to ``"bfv"``.
    """

    _serde_kind: ClassVar[str] = "bfv.KeyType"

    def __init__(self, kind: KeyKind, poly_modulus_degree: int = 4096):
        self.scheme = "bfv"
        self.kind = kind
        self.poly_modulus_degree = poly_modulus_degree

    def __str__(self) -> str:
        kind, degree = self.kind, self.poly_modulus_degree
        return f"BFV{kind}Key[N={degree}]"

    def __eq__(self, other: object) -> bool:
        if isinstance(other, KeyType):
            return (
                self.kind == other.kind
                and self.poly_modulus_degree == other.poly_modulus_degree
            )
        return False

    def __hash__(self) -> int:
        identity = ("BFVKeyType", self.kind, self.poly_modulus_degree)
        return hash(identity)

    # --- Serde methods ---
    def to_json(self) -> dict[str, Any]:
        """Serialize the key type as ``{"kind", "poly_modulus_degree"}``."""
        payload: dict[str, Any] = {}
        payload["kind"] = self.kind
        payload["poly_modulus_degree"] = self.poly_modulus_degree
        return payload

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> KeyType:
        """Rebuild a key type from the dict produced by :meth:`to_json`."""
        kind = data["kind"]
        degree = data["poly_modulus_degree"]
        return cls(kind=kind, poly_modulus_degree=degree)
140
+
141
+
142
@serde.register_class
class PlaintextType(elt.BaseType):
    """Type of a BFV plaintext: a polynomial packing a vector of integers.

    ``vector_type`` records the logical SIMD slot layout (element type and
    slot count). Equality and hashing depend only on ``vector_type``.
    """

    _serde_kind: ClassVar[str] = "bfv.PlaintextType"

    def __init__(self, vector_type: elt.VectorType):
        self.vector_type = vector_type

    @property
    def slots(self) -> int:
        """Number of SIMD slots, i.e. the size of ``vector_type``."""
        return self.vector_type.size

    def __str__(self) -> str:
        return f"BFVPlaintext[{self.vector_type}]"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, PlaintextType) and self.vector_type == other.vector_type

    def __hash__(self) -> int:
        identity = ("BFVPlaintextType", self.vector_type)
        return hash(identity)

    # --- Serde methods ---
    def to_json(self) -> dict[str, Any]:
        """Serialize as ``{"vector_type": <serde JSON of the vector type>}``."""
        encoded = serde.to_json(self.vector_type)
        return {"vector_type": encoded}

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> PlaintextType:
        """Rebuild from :meth:`to_json` output.

        Raises:
            TypeError: If the decoded ``vector_type`` is not a VectorType.
        """
        decoded = serde.from_json(data["vector_type"])
        if isinstance(decoded, elt.VectorType):
            return cls(vector_type=decoded)
        raise TypeError(f"Expected VectorType, got {type(decoded)}")
180
+
181
+
182
@serde.register_class
class CiphertextType(elt.BaseType, elt.EncryptedTrait):
    """Type of a BFV ciphertext, i.e. an encrypted packed plaintext.

    It carries the same ``vector_type`` (slot layout) as the plaintext it
    encrypts. Two ciphertext types are equal iff their vector types are.
    """

    _serde_kind: ClassVar[str] = "bfv.CiphertextType"

    def __init__(self, vector_type: elt.VectorType):
        self._scheme = "bfv"
        self.vector_type = vector_type

    @property
    def scheme(self) -> str:
        """Encryption scheme name; always ``"bfv"`` for this type."""
        return self._scheme

    @property
    def poly_modulus_degree(self) -> int:
        """Slot count of ``vector_type``, reported as the degree N."""
        return self.vector_type.size

    def __str__(self) -> str:
        return f"BFVCiphertext[{self.vector_type}]"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, CiphertextType) and self.vector_type == other.vector_type

    def __hash__(self) -> int:
        identity = ("BFVCiphertextType", self.vector_type)
        return hash(identity)

    # --- Serde methods ---
    def to_json(self) -> dict[str, Any]:
        """Serialize as ``{"vector_type": <serde JSON of the vector type>}``."""
        encoded = serde.to_json(self.vector_type)
        return {"vector_type": encoded}

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> CiphertextType:
        """Rebuild from :meth:`to_json` output.

        Raises:
            TypeError: If the decoded ``vector_type`` is not a VectorType.
        """
        decoded = serde.from_json(data["vector_type"])
        if isinstance(decoded, elt.VectorType):
            return cls(vector_type=decoded)
        raise TypeError(f"Expected VectorType, got {type(decoded)}")
221
+
222
+
223
# Opaque types
@serde.register_class
class EncoderType(elt.BaseType):
    """Type of a BFV BatchEncoder, identified solely by its polynomial degree N."""

    _serde_kind: ClassVar[str] = "bfv.EncoderType"

    def __init__(self, poly_modulus_degree: int):
        self.poly_modulus_degree = poly_modulus_degree

    def __str__(self) -> str:
        return f"BFVEncoder[N={self.poly_modulus_degree}]"

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, EncoderType)
            and self.poly_modulus_degree == other.poly_modulus_degree
        )

    def __hash__(self) -> int:
        identity = ("BFVEncoder", self.poly_modulus_degree)
        return hash(identity)

    # --- Serde methods ---
    def to_json(self) -> dict[str, Any]:
        """Serialize as ``{"poly_modulus_degree": N}``."""
        return dict(poly_modulus_degree=self.poly_modulus_degree)

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> EncoderType:
        """Rebuild from :meth:`to_json` output."""
        degree = data["poly_modulus_degree"]
        return cls(poly_modulus_degree=degree)
251
+
252
+
253
# ==============================================================================
# --- Key Management
# ==============================================================================

# keygen produces a (public, private) pair, hence its tuple result type; the
# other two primitives each derive a single key object from the private key.
keygen_p = el.Primitive[tuple[el.Object, el.Object]]("bfv.keygen")
make_relin_keys_p = el.Primitive[el.Object]("bfv.make_relin_keys")
make_galois_keys_p = el.Primitive[el.Object]("bfv.make_galois_keys")
260
+
261
+
262
@keygen_p.def_abstract_eval
def _keygen_ae(
    *,
    poly_modulus_degree: int = 4096,
    plain_modulus: int = 1032193,
) -> tuple[KeyType, KeyType]:
    """Abstract eval for ``bfv.keygen``: a (Public, Private) key-type pair.

    ``plain_modulus`` is accepted because ``keygen`` passes it as a keyword,
    but it does not affect the resulting key types.
    """
    public_key = KeyType("Public", poly_modulus_degree)
    private_key = KeyType("Private", poly_modulus_degree)
    return public_key, private_key
273
+
274
+
275
@make_relin_keys_p.def_abstract_eval
def _make_relin_keys_ae(sk: KeyType) -> KeyType:
    """Abstract eval: relinearization keys derived from a BFV private key.

    Raises:
        TypeError: If ``sk`` is not a private BFV key type.
    """
    is_private_key = isinstance(sk, KeyType) and sk.kind == "Private"
    if not is_private_key:
        raise TypeError(f"Expected BFV PrivateKey, got {sk}")
    return KeyType("Relin", sk.poly_modulus_degree)
281
+
282
+
283
@make_galois_keys_p.def_abstract_eval
def _make_galois_keys_ae(sk: KeyType) -> KeyType:
    """Abstract eval: Galois (rotation) keys derived from a BFV private key.

    Raises:
        TypeError: If ``sk`` is not a private BFV key type.
    """
    is_private_key = isinstance(sk, KeyType) and sk.kind == "Private"
    if not is_private_key:
        raise TypeError(f"Expected BFV PrivateKey, got {sk}")
    return KeyType("Galois", sk.poly_modulus_degree)
289
+
290
+
291
# ==============================================================================
# --- Encoding / Decoding (SIMD Packing)
# ==============================================================================

create_encoder_p = el.Primitive[el.Object]("bfv.create_encoder")
encode_p = el.Primitive[el.Object]("bfv.encode")
# NOTE(review): batch_encode's trace (_batch_encode_trace) returns a tuple of
# plaintext objects, which the `el.Object` type parameter does not reflect.
batch_encode_p = el.Primitive[el.Object]("bfv.batch_encode")
decode_p = el.Primitive[el.Object]("bfv.decode")
299
+
300
+
301
@create_encoder_p.def_abstract_eval
def _create_encoder_ae(*, poly_modulus_degree: int = 4096) -> EncoderType:
    """Abstract eval for ``bfv.create_encoder``: an encoder typed by its degree."""
    encoder_type = EncoderType(poly_modulus_degree)
    return encoder_type
304
+
305
+
306
@encode_p.def_abstract_eval
def _encode_ae(tensor: elt.TensorType, encoder: EncoderType) -> PlaintextType:
    """Pack a 1D Tensor of integers into a BFV Plaintext.

    Args:
        tensor: Rank-1 tensor with an integer element type.
        encoder: BFV encoder type; its ``poly_modulus_degree`` is the slot
            count of the resulting plaintext.

    Returns:
        ``PlaintextType(Vector[element_type, encoder.poly_modulus_degree])``.

    Raises:
        TypeError: If ``encoder`` is not an EncoderType, ``tensor`` is not a
            TensorType, or the element type is not an IntegerType.
        ValueError: If ``tensor`` is not 1D, or its static length exceeds the
            encoder's slot count.
    """
    if not isinstance(encoder, EncoderType):
        raise TypeError(f"Expected BFVEncoder, got {encoder}")

    if not isinstance(tensor, elt.TensorType):
        raise TypeError(f"Expected Tensor input, got {tensor}")

    # Check 1D
    if tensor.rank != 1:
        raise ValueError(
            f"BFV encode currently only supports 1D Tensors, got rank {tensor.rank}"
        )

    # Check Integer type
    if not isinstance(tensor.element_type, elt.IntegerType):
        raise TypeError(
            f"BFV supports integer arithmetic only. Expected Tensor[Integer], got Tensor[{tensor.element_type}]"
        )

    slot_count = encoder.poly_modulus_degree
    # The plaintext has exactly `slot_count` slots, so a longer input cannot be
    # packed. An unknown (None) length can only be checked at runtime.
    length = tensor.shape[0]
    if length is not None and length > slot_count:
        raise ValueError(
            f"BFV encode input has {length} elements but the encoder only "
            f"provides {slot_count} slots"
        )

    return PlaintextType(elt.Vector(tensor.element_type, slot_count))
330
+
331
+
332
def _infer_batch_encode_output_types(
    tensor: elt.TensorType,
    encoder: elt.BaseType,
    key: elt.BaseType,
) -> tuple[PlaintextType, ...]:
    """Infer the plaintext types produced by ``bfv.batch_encode``.

    Each row (first dimension) of the 2D input yields one PlaintextType of
    ``Vector[element_type, encoder.poly_modulus_degree]``.

    Args:
        tensor: 2D integer tensor whose first dimension must be static.
        encoder: Must be an EncoderType.
        key: Not inspected by this function.

    Returns:
        One PlaintextType per row of ``tensor``.

    Raises:
        TypeError: If ``encoder``/``tensor`` have the wrong type, or the element
            type is not an IntegerType.
        ValueError: If ``tensor`` is not 2D, its row count is unknown or
            negative, or its rows are longer than the encoder's slot count.
    """
    if not isinstance(encoder, EncoderType):
        raise TypeError(f"Expected BFVEncoder, got {encoder}")

    if not isinstance(tensor, elt.TensorType):
        raise TypeError(f"Expected Tensor input, got {tensor}")

    if tensor.rank != 2:
        raise ValueError(
            f"BFV batch_encode input must be 2D Tensor, got rank {tensor.rank}"
        )

    num_rows = tensor.shape[0]
    # A negative size can only be a dynamic-dimension sentinel; letting it
    # through would make range() silently produce zero outputs.
    if num_rows is None or num_rows < 0:
        raise ValueError(
            "BFV batch_encode requires static first dimension for tensor input"
        )

    # Check Integer type
    if not isinstance(tensor.element_type, elt.IntegerType):
        raise TypeError(
            f"BFV supports integer arithmetic only. Expected Tensor[Integer], got Tensor[{tensor.element_type}]"
        )

    slot_count = encoder.poly_modulus_degree
    # Each row is packed into one plaintext with `slot_count` slots.
    row_length = tensor.shape[1]
    if row_length is not None and row_length > slot_count:
        raise ValueError(
            f"BFV batch_encode rows have {row_length} elements but the encoder "
            f"only provides {slot_count} slots"
        )

    return tuple(
        PlaintextType(elt.Vector(tensor.element_type, slot_count))
        for _ in range(num_rows)
    )
364
+
365
+
366
@batch_encode_p.def_trace
def _batch_encode_trace(
    tensor: el.Object,
    encoder: el.Object,
    key: el.Object,
) -> tuple[el.Object, ...]:
    """Trace ``bfv.batch_encode`` as one multi-output graph operation.

    Output types come from ``_infer_batch_encode_output_types``; one
    TraceObject is returned per graph output value.

    Raises:
        RuntimeError: If the current context is not a Tracer.
    """
    from mplang.v2.edsl.tracer import TraceObject, Tracer

    tracer = el.get_current_context()
    if not isinstance(tracer, Tracer):
        raise RuntimeError("batch_encode must be called within a Tracer context")

    enc_t = encoder.type
    key_t = key.type
    out_types = _infer_batch_encode_output_types(tensor.type, enc_t, key_t)

    graph_inputs = [
        cast(TraceObject, arg)._graph_value for arg in [tensor, encoder, key]
    ]

    # Record a single op whose outputs are the per-row plaintexts.
    out_values = tracer.graph.add_op(
        batch_encode_p.name,
        graph_inputs,
        out_types,
        attrs={},
    )

    # strict zip guards against add_op returning a different number of values
    # than the inferred output types.
    wrapped: list[el.Object] = []
    for value, _out_type in zip(out_values, out_types, strict=True):
        wrapped.append(TraceObject(value, tracer))
    return tuple(wrapped)
399
+
400
+
401
@decode_p.def_abstract_eval
def _decode_ae(plain: PlaintextType, encoder: EncoderType) -> elt.TensorType:
    """Unpack a BFV Plaintext back into a 1D Tensor.

    The result is a rank-1 tensor with one element per plaintext slot, using
    the plaintext's own element type.

    Raises:
        TypeError: If ``encoder`` is not an EncoderType or ``plain`` is not a
            PlaintextType.
    """
    if not isinstance(encoder, EncoderType):
        raise TypeError(f"Expected BFVEncoder, got {encoder}")
    if not isinstance(plain, PlaintextType):
        raise TypeError(f"Expected BFVPlaintext, got {plain}")

    element_type = plain.vector_type.element_type
    return elt.TensorType(element_type, (plain.slots,))
413
+
414
+
415
# ==============================================================================
# --- Encryption / Decryption
# ==============================================================================

# encrypt: (Plaintext, Public key) -> Ciphertext.
encrypt_p = el.Primitive[el.Object]("bfv.encrypt")
# decrypt: (Ciphertext, Private key) -> Plaintext.
decrypt_p = el.Primitive[el.Object]("bfv.decrypt")
421
+
422
+
423
@encrypt_p.def_abstract_eval
def _encrypt_ae(plain: PlaintextType, pk: KeyType) -> CiphertextType:
    """Abstract eval for ``bfv.encrypt``: the ciphertext keeps the plaintext's vector type.

    Raises:
        TypeError: If ``plain`` is not a plaintext or ``pk`` is not a public key.
    """
    if not isinstance(plain, PlaintextType):
        raise TypeError(f"Expected BFVPlaintext, got {plain}")
    is_public_key = isinstance(pk, KeyType) and pk.kind == "Public"
    if not is_public_key:
        raise TypeError(f"Expected BFV PublicKey, got {pk}")
    return CiphertextType(plain.vector_type)
430
+
431
+
432
@decrypt_p.def_abstract_eval
def _decrypt_ae(ct: CiphertextType, sk: KeyType) -> PlaintextType:
    """Abstract eval for ``bfv.decrypt``: the plaintext keeps the ciphertext's vector type.

    Raises:
        TypeError: If ``ct`` is not a ciphertext or ``sk`` is not a private key.
    """
    if not isinstance(ct, CiphertextType):
        raise TypeError(f"Expected BFVCiphertext, got {ct}")
    is_private_key = isinstance(sk, KeyType) and sk.kind == "Private"
    if not is_private_key:
        raise TypeError(f"Expected BFV PrivateKey, got {sk}")
    return PlaintextType(ct.vector_type)
439
+
440
+
441
# ==============================================================================
# --- Arithmetic Operations
# ==============================================================================

# add/sub/mul accept ciphertext or plaintext operands (at least one must be a
# ciphertext, enforced by _check_arithmetic_operands); relinearize takes a
# ciphertext plus relinearization keys.
add_p = el.Primitive[el.Object]("bfv.add")
sub_p = el.Primitive[el.Object]("bfv.sub")
mul_p = el.Primitive[el.Object]("bfv.mul")
relinearize_p = el.Primitive[el.Object]("bfv.relinearize")
449
+
450
+
451
def _check_arithmetic_operands(lhs: Any, rhs: Any) -> None:
    """Validate the operand types of a binary BFV arithmetic op.

    Both operands must be BFV ciphertext or plaintext types, at least one of
    them must be a ciphertext, and both must have the same slot count.

    Raises:
        TypeError: If any of these conditions is violated.
    """
    valid_types = (CiphertextType, PlaintextType)
    if not (isinstance(lhs, valid_types) and isinstance(rhs, valid_types)):
        raise TypeError(
            f"Operands must be BFVCiphertext or BFVPlaintext, got {lhs}, {rhs}"
        )
    # At least one must be ciphertext
    if not (isinstance(lhs, CiphertextType) or isinstance(rhs, CiphertextType)):
        raise TypeError("At least one operand must be a Ciphertext")
    # SIMD ops combine slot by slot, so both operands must have the same slot
    # count. The result type is copied from the ciphertext operand alone, so a
    # mismatch would otherwise not be caught at trace time.
    lhs_slots = lhs.vector_type.size
    rhs_slots = rhs.vector_type.size
    if lhs_slots != rhs_slots:
        raise TypeError(
            f"Operands must have the same number of slots, got {lhs_slots} and {rhs_slots}"
        )
461
+
462
+
463
@add_p.def_abstract_eval
def _add_ae(lhs: Any, rhs: Any) -> CiphertextType:
    """Abstract eval for ``bfv.add`` (ct+ct, ct+pt or pt+ct)."""
    _check_arithmetic_operands(lhs, rhs)
    # The result copies the ciphertext operand's vector type; lhs wins when
    # both operands are ciphertexts.
    if isinstance(lhs, CiphertextType):
        source = lhs
    else:
        source = rhs
    return CiphertextType(source.vector_type)
469
+
470
+
471
@sub_p.def_abstract_eval
def _sub_ae(lhs: Any, rhs: Any) -> CiphertextType:
    """Abstract eval for ``bfv.sub`` (ct-ct, ct-pt or pt-ct)."""
    _check_arithmetic_operands(lhs, rhs)
    # Result type follows the ciphertext operand (lhs if both are ciphertexts).
    if isinstance(lhs, CiphertextType):
        source = lhs
    else:
        source = rhs
    return CiphertextType(source.vector_type)
476
+
477
+
478
@mul_p.def_abstract_eval
def _mul_ae(lhs: Any, rhs: Any) -> CiphertextType:
    """Abstract eval for ``bfv.mul`` (ct*ct, ct*pt or pt*ct).

    Multiplication increases noise and, for ct*ct, the ciphertext size, but
    the abstract type is still a CiphertextType.
    """
    _check_arithmetic_operands(lhs, rhs)
    # Result type follows the ciphertext operand (lhs if both are ciphertexts).
    if isinstance(lhs, CiphertextType):
        source = lhs
    else:
        source = rhs
    return CiphertextType(source.vector_type)
485
+
486
+
487
@relinearize_p.def_abstract_eval
def _relinearize_ae(ct: CiphertextType, rk: KeyType) -> CiphertextType:
    """Abstract eval for ``bfv.relinearize``; the vector type is unchanged.

    Raises:
        TypeError: If ``ct`` is not a ciphertext or ``rk`` is not a Relin key.
    """
    if not isinstance(ct, CiphertextType):
        raise TypeError(f"Expected BFVCiphertext, got {ct}")
    has_relin_key = isinstance(rk, KeyType) and rk.kind == "Relin"
    if not has_relin_key:
        raise TypeError(f"Expected BFV RelinKeys, got {rk}")
    return CiphertextType(ct.vector_type)
494
+
495
+
496
# ==============================================================================
# --- Rotation
# ==============================================================================

# Row rotation of ciphertext slots; `steps` is passed as a keyword attribute
# (see rotate()).
rotate_p = el.Primitive[el.Object]("bfv.rotate")
501
+
502
+
503
@rotate_p.def_abstract_eval
def _rotate_ae(ct: CiphertextType, gk: KeyType, *, steps: int) -> CiphertextType:
    """Abstract eval for ``bfv.rotate`` (cyclic rotation within each row).

    With BFV batching, the slots form two rows of ``slot_count / 2``, so a row
    rotation is only defined for ``-slot_count/2 < steps < slot_count/2``.

    Raises:
        TypeError: If ``ct`` is not a ciphertext or ``gk`` is not a Galois key.
        ValueError: If ``steps`` is outside the valid rotation range.
    """
    if not isinstance(ct, CiphertextType):
        raise TypeError(f"Expected BFVCiphertext, got {ct}")
    if not isinstance(gk, KeyType) or gk.kind != "Galois":
        raise TypeError(f"Expected BFV GaloisKeys, got {gk}")
    # Enforce the documented range at trace time instead of deferring the
    # failure to the backend.
    half = ct.poly_modulus_degree // 2
    if not -half < steps < half:
        raise ValueError(
            f"BFV rotate steps must be in range ({-half}, {half}), got {steps}"
        )
    return CiphertextType(ct.vector_type)
510
+
511
+
512
# Swaps the two rows of the BFV SIMD slot layout (row 0 <-> row 1).
rotate_columns_p = el.Primitive[el.Object]("bfv.rotate_columns")
513
+
514
+
515
@rotate_columns_p.def_abstract_eval
def _rotate_columns_ae(ct: CiphertextType, gk: KeyType) -> CiphertextType:
    """Swap the two rows in SIMD batching (row 0 <-> row 1).

    Raises:
        TypeError: If ``ct`` is not a ciphertext or ``gk`` is not a Galois key.
    """
    if not isinstance(ct, CiphertextType):
        raise TypeError(f"Expected BFVCiphertext, got {ct}")
    has_galois_key = isinstance(gk, KeyType) and gk.kind == "Galois"
    if not has_galois_key:
        raise TypeError(f"Expected BFV GaloisKeys, got {gk}")
    return CiphertextType(ct.vector_type)
523
+
524
+
525
# ==============================================================================
# --- User API
# ==============================================================================


def keygen(
    poly_modulus_degree: int = 4096,
    plain_modulus: int = 1032193,
) -> tuple[el.Object, el.Object]:
    """Generate a BFV (public key, secret key) pair.

    Args:
        poly_modulus_degree: Degree of the polynomial modulus (N); determines
            the slot count. Must be a power of 2 (e.g., 4096, 8192).
        plain_modulus: Explicit plaintext modulus (integer). Default is 1032193.

    Returns:
        (PublicKey, SecretKey)
    """
    attrs = {
        "poly_modulus_degree": poly_modulus_degree,
        "plain_modulus": plain_modulus,
    }
    return keygen_p.bind(**attrs)
548
+
549
+
550
def make_relin_keys(secret_key: el.Object) -> el.Object:
    """Derive relinearization keys (used by :func:`relinearize`) from a secret key."""
    relin_keys = make_relin_keys_p.bind(secret_key)
    return relin_keys
553
+
554
+
555
def make_galois_keys(secret_key: el.Object) -> el.Object:
    """Derive Galois keys (used by :func:`rotate` / :func:`rotate_columns`) from a secret key."""
    galois_keys = make_galois_keys_p.bind(secret_key)
    return galois_keys
558
+
559
+
560
def create_encoder(poly_modulus_degree: int = 4096) -> el.Object:
    """Create a BatchEncoder for SIMD packing with ``poly_modulus_degree`` slots."""
    encoder = create_encoder_p.bind(poly_modulus_degree=poly_modulus_degree)
    return encoder
563
+
564
+
565
def encode(tensor: el.Object, encoder: el.Object) -> el.Object:
    """Pack a 1D integer tensor into a BFV plaintext using ``encoder``."""
    plaintext = encode_p.bind(tensor, encoder)
    return plaintext
568
+
569
+
570
def batch_encode(
    tensor: el.Object,
    encoder: el.Object,
    key: el.Object,
) -> tuple[el.Object, ...]:
    """Pack a 2D Tensor of integers into BFV Plaintexts (Batched).

    Args:
        tensor: 2D integer tensor with a static first dimension.
        encoder: BFV encoder.
        key: Key object forwarded to the primitive.

    Returns:
        A tuple holding one plaintext per row (first dimension) of ``tensor``.
        The traced primitive returns a tuple, not a single object, so the
        return annotation now reflects that.
    """
    return cast(tuple[el.Object, ...], batch_encode_p.bind(tensor, encoder, key))
577
+
578
+
579
def decode(plain: el.Object, encoder: el.Object) -> el.Object:
    """Unpack a BFV plaintext back into a 1D tensor using ``encoder``."""
    tensor = decode_p.bind(plain, encoder)
    return tensor
582
+
583
+
584
def encrypt(plain: el.Object, public_key: el.Object) -> el.Object:
    """Encrypt a plaintext under ``public_key``, yielding a ciphertext."""
    ciphertext = encrypt_p.bind(plain, public_key)
    return ciphertext
587
+
588
+
589
def decrypt(ciphertext: el.Object, secret_key: el.Object) -> el.Object:
    """Decrypt a ciphertext with ``secret_key``, yielding a plaintext."""
    plaintext = decrypt_p.bind(ciphertext, secret_key)
    return plaintext
592
+
593
+
594
def add(lhs: el.Object, rhs: el.Object) -> el.Object:
    """Slot-wise (SIMD) homomorphic addition; at least one operand is a ciphertext."""
    total = add_p.bind(lhs, rhs)
    return total
597
+
598
+
599
def sub(lhs: el.Object, rhs: el.Object) -> el.Object:
    """Slot-wise (SIMD) homomorphic subtraction; at least one operand is a ciphertext."""
    difference = sub_p.bind(lhs, rhs)
    return difference
602
+
603
+
604
def mul(lhs: el.Object, rhs: el.Object) -> el.Object:
    """Slot-wise (SIMD) homomorphic multiplication.

    Multiplying two ciphertexts grows the ciphertext size (e.g. 2 -> 3); call
    :func:`relinearize` afterwards to bring it back to size 2 before further
    multiplications.
    """
    product = mul_p.bind(lhs, rhs)
    return product
611
+
612
+
613
def relinearize(ciphertext: el.Object, relin_keys: el.Object) -> el.Object:
    """Shrink a ciphertext back to its normal size after multiplication."""
    reduced = relinearize_p.bind(ciphertext, relin_keys)
    return reduced
616
+
617
+
618
def rotate(ciphertext: el.Object, steps: int, galois_keys: el.Object) -> el.Object:
    """Cyclically rotate the encrypted slots within each row.

    Args:
        ciphertext: Ciphertext whose slots are rotated.
        steps: Rotation amount; positive rotates left, negative rotates right.
            Must lie in (-slot_count/2, slot_count/2).
        galois_keys: Galois keys enabling the rotation.
    """
    # Operands are positional; `steps` travels as a keyword attribute.
    operands = (ciphertext, galois_keys)
    return rotate_p.bind(*operands, steps=steps)
628
+
629
+
630
def rotate_columns(ciphertext: el.Object, galois_keys: el.Object) -> el.Object:
    """Swap the two rows of the SIMD slot layout (row 0 <-> row 1).

    With slot_count = N, BFV batching arranges slots as:
      - Row 0: slots 0 .. N/2 - 1
      - Row 1: slots N/2 .. N - 1

    Args:
        ciphertext: Ciphertext whose rows are swapped.
        galois_keys: Galois keys enabling the rotation.
    """
    swapped = rotate_columns_p.bind(ciphertext, galois_keys)
    return swapped
644
+
645
+
646
# Public API of the BFV dialect. batch_encode was previously missing even
# though it is a documented, non-underscored user function.
__all__ = [
    "CiphertextType",
    "EncoderType",
    "KeyType",
    "PlaintextType",
    "add",
    "batch_encode",
    "create_encoder",
    "decode",
    "decrypt",
    "encode",
    "encrypt",
    "keygen",
    "make_galois_keys",
    "make_relin_keys",
    "mul",
    "relinearize",
    "rotate",
    "rotate_columns",
    "sub",
]