mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -15,18 +15,25 @@
15
15
 
16
16
  from jax.tree_util import PyTreeDef, tree_flatten
17
17
 
18
- from mplang.core.dtype import UINT8, UINT64
19
- from mplang.core.mpobject import MPObject # Needed for constant() triad return typing
20
- from mplang.core.pfunc import PFunction
21
- from mplang.core.table import TableLike, TableType
22
- from mplang.core.tensor import ScalarType, Shape, TensorLike, TensorType
23
- from mplang.ops.base import stateless_mod
24
- from mplang.utils import table_utils
25
-
26
- _BUILTIN_MOD = stateless_mod("builtin")
27
-
28
-
29
- @_BUILTIN_MOD.simple_op()
18
+ from mplang.v1.core import (
19
+ UINT8,
20
+ UINT64,
21
+ MPObject,
22
+ PFunction,
23
+ ScalarType,
24
+ Shape,
25
+ TableLike,
26
+ TableType,
27
+ TensorLike,
28
+ TensorType,
29
+ )
30
+ from mplang.v1.ops.base import stateless_mod
31
+ from mplang.v1.utils import table_utils
32
+
33
+ _BASIC_MOD = stateless_mod("basic")
34
+
35
+
36
+ @_BASIC_MOD.simple_op()
30
37
  def identity(x: TensorType) -> TensorType:
31
38
  """Return the input type unchanged.
32
39
 
@@ -40,7 +47,7 @@ def identity(x: TensorType) -> TensorType:
40
47
  return x
41
48
 
42
49
 
43
- @_BUILTIN_MOD.simple_op()
50
+ @_BASIC_MOD.simple_op()
44
51
  def read(*, path: str, ty: TensorType) -> TensorType:
45
52
  """Declare reading a value of type ``ty`` from ``path`` (type-only).
46
53
 
@@ -63,7 +70,7 @@ def read(*, path: str, ty: TensorType) -> TensorType:
63
70
  return ty
64
71
 
65
72
 
66
- @_BUILTIN_MOD.simple_op()
73
+ @_BASIC_MOD.simple_op()
67
74
  def write(x: TensorType, *, path: str) -> TensorType:
68
75
  """Declare writing the input value to ``path`` and return the same type.
69
76
 
@@ -77,7 +84,7 @@ def write(x: TensorType, *, path: str) -> TensorType:
77
84
  return x
78
85
 
79
86
 
80
- @_BUILTIN_MOD.op_def()
87
+ @_BASIC_MOD.op_def()
81
88
  def constant(
82
89
  data: TensorLike | ScalarType | TableLike,
83
90
  ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
@@ -89,7 +96,7 @@ def constant(
89
96
 
90
97
  Returns:
91
98
  Tuple[PFunction, list[MPObject], PyTreeDef]:
92
- - PFunction: ``fn_type='builtin.constant'`` with one output whose type
99
+ - PFunction: ``fn_type='basic.constant'`` with one output whose type
93
100
  matches ``data``; payload serialized via ``data_bytes`` with
94
101
  ``data_format`` ('bytes[numpy]' or 'bytes[csv]').
95
102
  - list[MPObject]: Empty (no inputs captured).
@@ -101,8 +108,9 @@ def constant(
101
108
  out_type: TableType | TensorType
102
109
 
103
110
  if isinstance(data, TableLike):
104
- data_bytes = table_utils.dataframe_to_csv(data)
105
- data_format = "bytes[csv]"
111
+ format = "parquet"
112
+ data_bytes = table_utils.encode_table(data, format=format)
113
+ data_format = f"bytes[{format}]"
106
114
  out_type = TableType.from_tablelike(data)
107
115
  elif isinstance(data, ScalarType):
108
116
  out_type = TensorType.from_obj(data)
@@ -120,7 +128,7 @@ def constant(
120
128
  data_format = "bytes[numpy]"
121
129
 
122
130
  pfunc = PFunction(
123
- fn_type="builtin.constant",
131
+ fn_type="basic.constant",
124
132
  ins_info=(),
125
133
  outs_info=(out_type,),
126
134
  data_bytes=data_bytes,
@@ -130,7 +138,7 @@ def constant(
130
138
  return pfunc, [], treedef
131
139
 
132
140
 
133
- @_BUILTIN_MOD.simple_op()
141
+ @_BASIC_MOD.simple_op()
134
142
  def rank() -> TensorType:
135
143
  """Return the scalar UINT64 tensor type for the current party rank.
136
144
 
@@ -140,7 +148,7 @@ def rank() -> TensorType:
140
148
  return TensorType(UINT64, ())
141
149
 
142
150
 
143
- @_BUILTIN_MOD.simple_op()
151
+ @_BASIC_MOD.simple_op()
144
152
  def prand(*, shape: Shape = ()) -> TensorType:
145
153
  """Declare a private random UINT64 tensor with the given shape.
146
154
 
@@ -153,7 +161,7 @@ def prand(*, shape: Shape = ()) -> TensorType:
153
161
  return TensorType(UINT64, shape)
154
162
 
155
163
 
156
- @_BUILTIN_MOD.simple_op()
164
+ @_BASIC_MOD.simple_op()
157
165
  def debug_print(
158
166
  x: TensorType | TableType, *, prefix: str = ""
159
167
  ) -> TableType | TensorType:
@@ -169,7 +177,7 @@ def debug_print(
169
177
  return x
170
178
 
171
179
 
172
- @_BUILTIN_MOD.simple_op()
180
+ @_BASIC_MOD.simple_op()
173
181
  def pack(x: TensorType | TableType) -> TensorType:
174
182
  """Serialize a tensor/table into a byte vector (type-only).
175
183
 
@@ -189,7 +197,7 @@ def pack(x: TensorType | TableType) -> TensorType:
189
197
  return TensorType(UINT8, (-1,))
190
198
 
191
199
 
192
- @_BUILTIN_MOD.simple_op()
200
+ @_BASIC_MOD.simple_op()
193
201
  def unpack(b: TensorType, *, out_ty: TensorType | TableType) -> TensorType | TableType:
194
202
  """Deserialize a byte vector into the explicit output type.
195
203
 
@@ -215,7 +223,7 @@ def unpack(b: TensorType, *, out_ty: TensorType | TableType) -> TensorType | Tab
215
223
  return out_ty
216
224
 
217
225
 
218
- @_BUILTIN_MOD.simple_op()
226
+ @_BASIC_MOD.simple_op()
219
227
  def table_to_tensor(table: TableType, *, number_rows: int) -> TensorType:
220
228
  """Convert a homogeneous-typed table to a dense 2D tensor.
221
229
 
@@ -248,7 +256,7 @@ def table_to_tensor(table: TableType, *, number_rows: int) -> TensorType:
248
256
  return TensorType(first, shape) # type: ignore[arg-type]
249
257
 
250
258
 
251
- @_BUILTIN_MOD.simple_op()
259
+ @_BASIC_MOD.simple_op()
252
260
  def tensor_to_table(tensor: TensorType, *, column_names: list[str]) -> TableType:
253
261
  """Convert a rank-2 tensor into a table with named columns.
254
262
 
@@ -0,0 +1,262 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Crypto frontend operations: operation signatures, types, and high-level semantics.
17
+
18
+ Scope and contracts:
19
+ - This module defines portable API shapes; it does not implement cryptography.
20
+ - Backends execute the operations and must meet the security semantics required
21
+ by the deployment (confidentiality, authenticity, correctness, etc.).
22
+ - The enc/dec API in this frontend uses a conventional 12-byte nonce prefix
23
+ (ciphertext = nonce || payload), and dec expects that format. Other security
24
+ properties (e.g., AEAD) are backend responsibilities.
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ from jax.tree_util import PyTreeDef, tree_flatten
30
+
31
+ from mplang.v1.core import UINT8, TensorType
32
+ from mplang.v1.core.mpobject import MPObject
33
+ from mplang.v1.core.pfunc import PFunction
34
+ from mplang.v1.ops.base import stateless_mod
35
+
36
+ _CRYPTO_MOD = stateless_mod("crypto")  # op-module handle for the "crypto" namespace; PFunctions below use fn_type "crypto.*"
37
+
38
+
39
+ def _get_algo_overhead(algo: str) -> int:
40
+ """Get ciphertext overhead for a given encryption algorithm.
41
+
42
+ Args:
43
+ algo: Encryption algorithm identifier
44
+
45
+ Returns:
46
+ int: Number of overhead bytes added to plaintext length
47
+ """
48
+ overhead_map = {
49
+ "aes-ctr": 16, # nonce only (legacy compatibility)
50
+ "aes-gcm": 28, # nonce(12) + tag(16) for AES-GCM
51
+ "sm4-gcm": 28, # nonce(12) + tag(16) for SM4-GCM
52
+ }
53
+
54
+ if algo not in overhead_map:
55
+ # return unknown overhead as -1
56
+ return -1
57
+ return overhead_map[algo]
58
+
59
+
60
@_CRYPTO_MOD.simple_op()
def keygen(*, length: int = 32) -> TensorType:
    """Declare a vector of random bytes, e.g. for a symmetric key.

    API: keygen(length: int = 32) -> key: u8[length]

    Only the output type is defined here; the backend supplies the bytes.

    Args:
        length: Number of random bytes to produce; must be positive.

    Returns:
        TensorType: A ``UINT8`` tensor of shape ``(length,)``.

    Raises:
        ValueError: If ``length <= 0``.
    """
    if length <= 0:
        raise ValueError("length must be > 0")
    key_shape = (length,)
    return TensorType(UINT8, key_shape)
73
+
74
+
75
@_CRYPTO_MOD.op_def()
def enc(
    plaintext: MPObject, key: MPObject, algo: str = "aes-ctr"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Declare symmetric encryption whose output size depends on ``algo``.

    API: enc(plaintext: u8[N], key: u8[M], *, algo: str = "aes-ctr") -> ciphertext: u8[N + overhead]

    Supported algorithms and overhead:
      - "aes-ctr": 16 bytes (nonce only, legacy compatibility)
      - "aes-gcm": 28 bytes (nonce + 16-byte authentication tag)
      - "sm4-gcm": 28 bytes (nonce + 16-byte authentication tag)

    Any other ``algo`` is accepted and yields a dynamic ciphertext length
    (``-1``). ``algo`` is stored in the PFunction attributes for the backend.

    NOTE(review): the API line shows ``algo`` as keyword-only, but the
    signature accepts it positionally as well.

    Args:
        plaintext: 1-D ``UINT8`` tensor to encrypt.
        key: Key tensor. Its type is recorded but not validated here.
        algo: Algorithm identifier.

    Returns:
        (PFunction, inputs, treedef) for a single ``UINT8`` ciphertext output.

    Raises:
        TypeError: If ``plaintext`` is not ``UINT8`` or is not 1-D.
    """
    if plaintext.dtype != UINT8:
        raise TypeError("enc expects UINT8 plaintext")
    if len(plaintext.shape) != 1:
        raise TypeError("enc expects 1-D plaintext")

    overhead = _get_algo_overhead(algo)
    length = plaintext.shape[0]
    # A negative plaintext length (dynamic dim) or an unknown algorithm (-1)
    # makes the ciphertext size unknowable here, so fall back to -1.
    sizes_known = length >= 0 and overhead >= 0
    out_ty = TensorType(UINT8, ((length + overhead) if sizes_known else -1,))
    outs_info = (out_ty,)

    pfunc = PFunction(
        fn_type="crypto.enc",
        ins_info=(TensorType.from_obj(plaintext), TensorType.from_obj(key)),
        outs_info=outs_info,
        algo=algo,
    )
    _, treedef = tree_flatten(out_ty)
    return pfunc, [plaintext, key], treedef
114
+
115
+
116
@_CRYPTO_MOD.op_def()
def dec(
    ciphertext: MPObject, key: MPObject, algo: str = "aes-ctr"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Declare symmetric decryption whose output size depends on ``algo``.

    API: dec(ciphertext: u8[N + overhead], key: u8[M], *, algo: str = "aes-ctr") -> plaintext: u8[N]

    Supported algorithms and overhead:
      - "aes-ctr": 16 bytes (nonce only, legacy compatibility)
      - "aes-gcm": 28 bytes (nonce + 16-byte authentication tag)
      - "sm4-gcm": 28 bytes (nonce + 16-byte authentication tag)

    ``algo`` is stored in the PFunction attributes. The backend is responsible
    for parsing the ciphertext layout according to ``algo``.

    Args:
        ciphertext: 1-D ``UINT8`` tensor to decrypt.
        key: Key tensor. Its type is recorded but not validated here.
        algo: Algorithm identifier. Unknown values yield a dynamic length.

    Returns:
        (PFunction, inputs, treedef) for a single ``UINT8`` plaintext output.

    Raises:
        TypeError: If ``ciphertext`` is not 1-D ``UINT8``, or if its static
            length is shorter than the algorithm overhead.
    """
    if ciphertext.dtype != UINT8:
        raise TypeError("dec expects UINT8 ciphertext")
    if len(ciphertext.shape) != 1:
        raise TypeError("dec expects 1-D ciphertext")

    overhead = _get_algo_overhead(algo)
    length = ciphertext.shape[0]
    # Size checks and arithmetic only make sense when both values are known.
    sizes_known = length >= 0 and overhead >= 0

    if sizes_known and length < overhead:
        raise TypeError(
            f"dec expects ciphertext with at least {overhead} bytes for algo='{algo}', but got {length} bytes"
        )

    plain_len = (length - overhead) if sizes_known else -1
    out_ty = TensorType(UINT8, (plain_len,))
    outs_info = (out_ty,)

    pfunc = PFunction(
        fn_type="crypto.dec",
        ins_info=(TensorType.from_obj(ciphertext), TensorType.from_obj(key)),
        outs_info=outs_info,
        algo=algo,
    )
    _, treedef = tree_flatten(out_ty)
    return pfunc, [ciphertext, key], treedef
164
+
165
+
166
@_CRYPTO_MOD.op_def()
def kem_keygen(suite: str = "x25519") -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Declare KEM-style keypair generation returning ``(sk, pk)`` bytes.

    API: kem_keygen(suite: str = "x25519") -> (sk: u8[32], pk: u8[32])

    Args:
        suite: Key-agreement suite, stored in the PFunction attributes for
            the backend. Only ``"x25519"`` gets static key sizes here.

    Returns:
        (PFunction, inputs, treedef): a ``crypto.kem_keygen`` PFunction with
        no inputs and two ``UINT8`` outputs, an empty input list, and the
        treedef of the ``(sk, pk)`` output tuple.
    """
    # x25519 keys are 32 bytes; every other suite gets dynamic (-1) lengths.
    key_len = 32 if suite == "x25519" else -1
    sk_ty = TensorType(UINT8, (key_len,))
    pk_ty = TensorType(UINT8, (key_len,))
    outs_info = (sk_ty, pk_ty)

    pfunc = PFunction(
        fn_type="crypto.kem_keygen",
        ins_info=(),
        outs_info=outs_info,
        suite=suite,
    )
    _, treedef = tree_flatten(outs_info)
    return pfunc, [], treedef
191
+
192
+
193
@_CRYPTO_MOD.op_def()
def kem_derive(
    sk: MPObject, peer_pk: MPObject, suite: str = "x25519"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Declare KEM-style shared-secret derivation returning secret bytes.

    API: kem_derive(sk: u8[32], peer_pk: u8[32], suite: str = "x25519") -> secret: u8[32]

    The suite parameter is stored in the PFunction attributes for backend use.

    Args:
        sk: Local secret key, a 1-D ``UINT8`` tensor.
        peer_pk: Peer public key, a 1-D ``UINT8`` tensor.
        suite: Key-agreement suite. For ``"x25519"`` the output is 32 bytes;
            other suites yield a dynamic (-1) output length.

    Returns:
        (PFunction, inputs, treedef) for a single ``UINT8`` secret output.

    Raises:
        TypeError: If either input is not a 1-D ``UINT8`` tensor, or if, for
            ``"x25519"``, a key has a statically known length other than 32.
    """
    if sk.dtype != UINT8:
        raise TypeError("kem_derive expects UINT8 secret key")
    if peer_pk.dtype != UINT8:
        raise TypeError("kem_derive expects UINT8 peer public key")
    if len(sk.shape) != 1 or len(peer_pk.shape) != 1:
        raise TypeError("kem_derive expects 1-D inputs")

    if suite == "x25519":
        # A negative dim is this module's "length unknown at trace time"
        # convention (see enc/dec). Only reject statically known lengths
        # that differ from 32; dynamic lengths are left to the backend.
        for key_len in (sk.shape[0], peer_pk.shape[0]):
            if key_len >= 0 and key_len != 32:
                raise TypeError("kem_derive expects 32-byte keys for suite 'x25519'")
        secret_ty = TensorType(UINT8, (32,))
    else:
        # Unknown suite, return dynamic length
        secret_ty = TensorType(UINT8, (-1,))
    outs_info = (secret_ty,)

    ins_info = (TensorType.from_obj(sk), TensorType.from_obj(peer_pk))
    pfunc = PFunction(
        fn_type="crypto.kem_derive",
        ins_info=ins_info,
        outs_info=outs_info,
        suite=suite,
    )
    _, treedef = tree_flatten(secret_ty)
    return pfunc, [sk, peer_pk], treedef
229
+
230
+
231
@_CRYPTO_MOD.op_def()
def hkdf(
    secret: MPObject, info: str, hash: str = "SHA-256"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Declare HKDF-style key derivation.

    API: hkdf(secret: u8[N], info: str, hash: str = "SHA-256") -> key: u8[32]

    Args:
        secret: Input keying material, a 1-D ``UINT8`` tensor.
        info: Context string, stored in the PFunction attributes.
        hash: Hash name, stored in the PFunction attributes. ``"SHA-256"`` and
            ``"SM3"`` give a 32-byte output; any other value gives a dynamic
            (-1) length.

    Returns:
        (PFunction, inputs, treedef) for a single ``UINT8`` key output.

    Raises:
        TypeError: If ``secret`` is not a 1-D ``UINT8`` tensor.
    """
    if secret.dtype != UINT8:
        raise TypeError("hkdf expects UINT8 secret")
    if len(secret.shape) != 1:
        raise TypeError("hkdf expects 1-D secret")

    key_len = 32 if hash in ("SHA-256", "SM3") else -1
    key_ty = TensorType(UINT8, (key_len,))

    pfunc = PFunction(
        fn_type="crypto.hkdf",
        ins_info=(TensorType.from_obj(secret),),
        outs_info=(key_ty,),
        hash=hash,
        info=info,
    )
    _, treedef = tree_flatten(key_ty)
    return pfunc, [secret], treedef
mplang/v1/ops/fhe.py ADDED
@@ -0,0 +1,272 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from mplang.v1.core import UINT8, TensorType
16
+ from mplang.v1.ops.base import stateless_mod
17
+
18
+ _fhe_MOD = stateless_mod("fhe")  # op-module handle for the "fhe" namespace; NOTE(review): sibling modules use upper-case names (_CRYPTO_MOD, _BASIC_MOD)
19
+
20
+
21
@_fhe_MOD.simple_op()
def keygen(
    *,
    scheme: str = "CKKS",
    poly_modulus_degree: int = 8192,
    coeff_mod_bit_sizes: tuple[int, ...] | None = None,
    global_scale: int | None = None,
    plain_modulus: int | None = None,
) -> tuple[TensorType, TensorType, TensorType]:
    """Generate an FHE key pair for Vector backend: returns (private_context, public_context, evaluation_context).

    Args:
        scheme: FHE scheme to use ("CKKS" for approximate, "BFV" for exact integer)
        poly_modulus_degree: Polynomial modulus degree (default: 8192)
        coeff_mod_bit_sizes: Coefficient modulus bit sizes for CKKS (optional)
        global_scale: Global scale for CKKS (optional)
        plain_modulus: Plain modulus for BFV (optional); must be None for CKKS

    Returns:
        Tuple of (private_context, public_context, evaluation_context) represented as UINT8[(-1, 0)]

    Raises:
        ValueError: If ``scheme`` is neither "CKKS" nor "BFV", or if
            ``plain_modulus`` is given together with the CKKS scheme.

    Contexts are represented with a sentinel TensorType UINT8[(-1, 0)] to indicate
    non-structural, backend-only handles.

    Note: Vector backend only supports 1D data.
    NOTE(review): an earlier version of this docstring pointed to
    ``mplang.ops.fhe`` for multi-dimensional tensors, but this module is itself
    ``mplang.v1.ops.fhe``; confirm the intended alternative.

    The remaining numeric parameters are not inspected here; presumably
    ``simple_op`` records them as op attributes for the backend (confirm).
    """
    if scheme not in ("CKKS", "BFV"):
        raise ValueError("Unsupported scheme. Choose either 'CKKS' or 'BFV'.")
    # Raise instead of assert: assert statements are stripped under `python -O`,
    # which would silently accept a plain_modulus that CKKS ignores.
    if scheme == "CKKS" and plain_modulus is not None:
        raise ValueError("plain_modulus is not used in CKKS scheme.")
    context_spec = TensorType(UINT8, (-1, 0))
    return context_spec, context_spec, context_spec
54
+
55
+
56
@_fhe_MOD.simple_op()
def encrypt(plaintext: TensorType, context: TensorType) -> TensorType:
    """Declare FHE Vector-backend encryption; the ciphertext keeps the plaintext type.

    Args:
        plaintext: Data to encrypt; only scalars (shape ``()``) and 1-D
            vectors (shape ``(n,)``) are accepted.
        context: FHE context (private or public). It is part of the op
            signature but is not inspected at the type level.

    Returns:
        TensorType: The plaintext type unchanged (same dtype and shape).

    Raises:
        ValueError: If ``plaintext`` has more than one dimension.

    NOTE(review): the error message below refers to ``mplang.ops.fhe``, but
    this module is ``mplang.v1.ops.fhe``; confirm the intended alternative.
    """
    del context  # only participates in the op signature
    if len(plaintext.shape) <= 1:
        return plaintext
    raise ValueError(
        f"FHE Vector backend only supports 1D data. Got shape {plaintext.shape}. "
        "Use mplang.ops.fhe for multi-dimensional tensors."
    )
80
+
81
+
82
@_fhe_MOD.simple_op()
def decrypt(ciphertext: TensorType, context: TensorType) -> TensorType:
    """Declare FHE Vector-backend decryption; the plaintext keeps the ciphertext type.

    Args:
        ciphertext: Encrypted data (scalar or 1-D vector).
        context: FHE context; should be the private context holding the
            secret key. Not inspected at the type level.

    Returns:
        TensorType: The ciphertext type unchanged.

    Note: Ciphertext encrypted with a public context can be decrypted with the
    corresponding private context.
    """
    del context  # only participates in the op signature
    plaintext_type = ciphertext
    return plaintext_type
98
+
99
+
100
def _check_matching_operands(operand1: TensorType, operand2: TensorType) -> None:
    """Raise ValueError unless both operands share the same dtype and shape.

    Shared by the element-wise binary ops (add/sub/mul). An explicit raise is
    used instead of ``assert`` so the documented ``ValueError`` contract holds
    and the check survives ``python -O``.
    """
    if operand1.dtype != operand2.dtype:
        raise ValueError(
            f"Operand dtypes must match, got {operand1.dtype} and {operand2.dtype}."
        )
    if operand1.shape != operand2.shape:
        raise ValueError(
            f"Operand shapes must match, got {operand1.shape} and {operand2.shape}."
        )


@_fhe_MOD.simple_op()
def add(operand1: TensorType, operand2: TensorType) -> TensorType:
    """Add two FHE operands (ciphertext + ciphertext or ciphertext + plaintext).

    Args:
        operand1: First operand (ciphertext or plaintext, scalar or 1D vector)
        operand2: Second operand (ciphertext or plaintext, scalar or 1D vector)

    Returns:
        Result of homomorphic addition (same type as ``operand1``)

    Raises:
        ValueError: If operands have incompatible shapes or dtypes

    Note: At least one operand must be ciphertext. Both operands must have
    the same shape (no broadcasting in Vector backend).
    """
    _check_matching_operands(operand1, operand2)
    return operand1


@_fhe_MOD.simple_op()
def sub(operand1: TensorType, operand2: TensorType) -> TensorType:
    """Subtract two FHE operands (ciphertext - ciphertext or ciphertext - plaintext).

    Args:
        operand1: First operand (ciphertext or plaintext, scalar or 1D vector)
        operand2: Second operand (ciphertext or plaintext, scalar or 1D vector)

    Returns:
        Result of homomorphic subtraction (same type as ``operand1``)

    Raises:
        ValueError: If operands have incompatible shapes or dtypes

    Note: At least one operand must be ciphertext. Both operands must have
    the same shape (no broadcasting in Vector backend).
    """
    _check_matching_operands(operand1, operand2)
    return operand1


@_fhe_MOD.simple_op()
def mul(operand1: TensorType, operand2: TensorType) -> TensorType:
    """Multiply two FHE operands (ciphertext * ciphertext or ciphertext * plaintext).

    Args:
        operand1: First operand (ciphertext or plaintext, scalar or 1D vector)
        operand2: Second operand (ciphertext or plaintext, scalar or 1D vector)

    Returns:
        Result of homomorphic multiplication (same type as ``operand1``)

    Raises:
        ValueError: If operands have incompatible shapes or dtypes

    Note: At least one operand must be ciphertext. Both operands must have
    the same shape (no broadcasting in Vector backend).
    For BFV scheme, plaintext operands must be integers.
    """
    _check_matching_operands(operand1, operand2)
    return operand1
177
+
178
+
179
@_fhe_MOD.simple_op()
def dot(operand1: TensorType, operand2: TensorType) -> TensorType:
    """Declare a homomorphic dot product of two 1-D operands.

    Either operand may be ciphertext or plaintext.

    Args:
        operand1: First operand; must be a 1-D vector.
        operand2: Second operand; must be a 1-D vector of the same length.

    Returns:
        TensorType: A scalar (shape ``()``) with ``operand1``'s dtype.

    Raises:
        ValueError: If either operand is not 1-D, or if their lengths differ.

    Note: Scalars are rejected; use ``mul()`` for scalar multiplication.
    """
    for label, operand in (("operand1", operand1), ("operand2", operand2)):
        if len(operand.shape) != 1:
            raise ValueError(
                f"Dot product requires 1D vectors, got shape {operand.shape} for {label}"
            )

    len1, len2 = operand1.shape[0], operand2.shape[0]
    if len1 != len2:
        raise ValueError(f"Dot product dimension mismatch: {len1} vs {len2}")

    # The reduction collapses the single axis, leaving a scalar.
    return TensorType(operand1.dtype, ())
211
+
212
+
213
@_fhe_MOD.simple_op()
def polyval(ciphertext: TensorType, coeffs: TensorType) -> TensorType:
    """Evaluate polynomial on encrypted data with plaintext coefficients.

    Args:
        ciphertext: Encrypted data (scalar or 1D vector)
        coeffs: Plaintext polynomial coefficients as 1D array [c0, c1, c2, ...]
            representing c0 + c1*x + c2*x^2 + ...

    Returns:
        Result of polynomial evaluation with same shape and dtype as ciphertext

    Raises:
        ValueError: If coefficients array is not 1D, or if its statically
            known length is smaller than 2

    Note: Polynomial must have degree >= 1 (at least 2 coefficients required).
    Constant polynomials (degree 0, single coefficient) are NOT supported due to
    TenSEAL limitation. For constant values, use: ct * 0 + constant instead.
    For BFV scheme, coefficients must be integers.

    Common use case - Sigmoid approximation:
        sigmoid_coeffs = [0.5, 0.15012, 0.0, -0.0018027]
        result = polyval(ciphertext, sigmoid_coeffs)
    """
    if len(coeffs.shape) != 1:
        raise ValueError(
            f"Polynomial coefficients must be 1D array, got shape {coeffs.shape}"
        )
    num_coeffs = coeffs.shape[0]
    # A negative dim means the length is unknown at trace time (-1 convention),
    # so only a statically known length can be checked here.
    if 0 <= num_coeffs < 2:
        raise ValueError(
            "Polynomial must have at least 2 coefficients (degree >= 1), "
            f"got {num_coeffs}"
        )
    return ciphertext
243
+
244
+
245
@_fhe_MOD.simple_op()
def negate(ciphertext: TensorType) -> TensorType:
    """Declare element-wise negation (unary minus) of encrypted data.

    Args:
        ciphertext: Encrypted data (scalar or 1-D vector).

    Returns:
        TensorType: The input type unchanged (same shape and dtype).

    Note: Equivalent to multiplying by -1.
    """
    # Negation preserves dtype and shape, so the result type is the input type.
    negated_type = ciphertext
    return negated_type
258
+
259
+
260
@_fhe_MOD.simple_op()
def square(ciphertext: TensorType) -> TensorType:
    """Declare element-wise squaring of encrypted data.

    Args:
        ciphertext: Encrypted data (scalar or 1-D vector).

    Returns:
        TensorType: The input type unchanged (same shape and dtype).

    Note: More efficient than mul(ciphertext, ciphertext) in some FHE schemes.
    """
    # Squaring is element-wise, so the result type equals the input type.
    squared_type = ciphertext
    return squared_type