mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,689 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Crypto dialect for the EDSL.
16
+
17
+ Provides cryptographic primitives including ECC, Hashing, and Symmetric Encryption.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ from typing import Any, ClassVar
23
+
24
+ import mplang.v2.edsl as el
25
+ import mplang.v2.edsl.typing as elt
26
+ from mplang.v2.edsl import serde
27
+
28
+ # ==============================================================================
29
+ # --- Type Definitions
30
+ # ==============================================================================
31
+
32
+
33
@serde.register_class
class PointType(elt.BaseType):
    """Type of an elliptic-curve point living on a named curve.

    Two point types are equal, and hash identically, exactly when they name
    the same curve.
    """

    _serde_kind: ClassVar[str] = "crypto.PointType"

    def __init__(self, curve: str = "secp256k1"):
        self.curve = curve

    def __str__(self) -> str:
        return "Point[{}]".format(self.curve)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, PointType) and self.curve == other.curve

    def __hash__(self) -> int:
        key = ("PointType", self.curve)
        return hash(key)

    # --- Serde: the JSON payload only carries the curve name ---
    def to_json(self) -> dict[str, Any]:
        return dict(curve=self.curve)

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> PointType:
        curve_name = data["curve"]
        return cls(curve=curve_name)
60
+
61
+
62
@serde.register_class
class ScalarType(elt.BaseType):
    """Type of an ECC scalar, i.e. an integer modulo the order of a named curve.

    Two scalar types are equal, and hash identically, exactly when they name
    the same curve.
    """

    _serde_kind: ClassVar[str] = "crypto.ScalarType"

    def __init__(self, curve: str = "secp256k1"):
        self.curve = curve

    def __str__(self) -> str:
        return "Scalar[{}]".format(self.curve)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, ScalarType) and self.curve == other.curve

    def __hash__(self) -> int:
        key = ("ScalarType", self.curve)
        return hash(key)

    # --- Serde: the JSON payload only carries the curve name ---
    def to_json(self) -> dict[str, Any]:
        return dict(curve=self.curve)

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> ScalarType:
        curve_name = data["curve"]
        return cls(curve=curve_name)
89
+
90
+
91
@serde.register_class
class PrivateKeyType(elt.BaseType):
    """Type of a KEM private key belonging to a named suite.

    Two private-key types are equal, and hash identically, exactly when they
    name the same suite.
    """

    _serde_kind: ClassVar[str] = "crypto.PrivateKeyType"

    def __init__(self, suite: str = "x25519"):
        self.suite = suite

    def __str__(self) -> str:
        return "PrivateKey[{}]".format(self.suite)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, PrivateKeyType) and self.suite == other.suite

    def __hash__(self) -> int:
        key = ("PrivateKeyType", self.suite)
        return hash(key)

    # --- Serde: the JSON payload only carries the suite name ---
    def to_json(self) -> dict[str, Any]:
        return dict(suite=self.suite)

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> PrivateKeyType:
        suite_name = data["suite"]
        return cls(suite=suite_name)
118
+
119
+
120
@serde.register_class
class PublicKeyType(elt.BaseType):
    """Type of a KEM public key belonging to a named suite.

    Two public-key types are equal, and hash identically, exactly when they
    name the same suite.
    """

    _serde_kind: ClassVar[str] = "crypto.PublicKeyType"

    def __init__(self, suite: str = "x25519"):
        self.suite = suite

    def __str__(self) -> str:
        return "PublicKey[{}]".format(self.suite)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, PublicKeyType) and self.suite == other.suite

    def __hash__(self) -> int:
        key = ("PublicKeyType", self.suite)
        return hash(key)

    # --- Serde: the JSON payload only carries the suite name ---
    def to_json(self) -> dict[str, Any]:
        return dict(suite=self.suite)

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> PublicKeyType:
        suite_name = data["suite"]
        return cls(suite=suite_name)
147
+
148
+
149
@serde.register_class
class SymmetricKeyType(elt.BaseType):
    """Type of a symmetric encryption key.

    ``suite`` records where the key came from. It is the KEM suite for keys
    produced by ``kem_derive``, or ``"hkdf-<hash>"`` for keys produced by
    ``hkdf``. Two such types are equal, and hash identically, exactly when
    their suites match.
    """

    _serde_kind: ClassVar[str] = "crypto.SymmetricKeyType"

    def __init__(self, suite: str = "x25519"):
        self.suite = suite

    def __str__(self) -> str:
        return "SymmetricKey[{}]".format(self.suite)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, SymmetricKeyType) and self.suite == other.suite

    def __hash__(self) -> int:
        key = ("SymmetricKeyType", self.suite)
        return hash(key)

    # --- Serde: the JSON payload only carries the suite name ---
    def to_json(self) -> dict[str, Any]:
        return dict(suite=self.suite)

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> SymmetricKeyType:
        suite_name = data["suite"]
        return cls(suite=suite_name)
176
+
177
+
178
+ # ==============================================================================
179
+ # --- Primitives
180
+ # ==============================================================================
181
+
182
# Nearly every crypto primitive yields a single ``el.Object``. KEM key
# generation is the one exception and yields a (private_key, public_key) pair.
_ObjectPrimitive = el.Primitive[el.Object]

# Elliptic-curve arithmetic and (de)serialization
generator_p = _ObjectPrimitive("crypto.ec_generator")
mul_p = _ObjectPrimitive("crypto.ec_mul")
add_p = _ObjectPrimitive("crypto.ec_add")
sub_p = _ObjectPrimitive("crypto.ec_sub")
point_to_bytes_p = _ObjectPrimitive("crypto.ec_point_to_bytes")
random_scalar_p = _ObjectPrimitive("crypto.ec_random_scalar")
scalar_from_int_p = _ObjectPrimitive("crypto.ec_scalar_from_int")

# Hashing, symmetric encryption and selection
hash_p = _ObjectPrimitive("crypto.hash")
hash_batch_p = _ObjectPrimitive("crypto.hash_batch")
sym_encrypt_p = _ObjectPrimitive("crypto.sym_encrypt")
sym_decrypt_p = _ObjectPrimitive("crypto.sym_decrypt")
select_p = _ObjectPrimitive("crypto.select")

# KEM (Key Encapsulation Mechanism)
kem_keygen_p = el.Primitive[tuple[el.Object, el.Object]]("crypto.kem_keygen")
kem_derive_p = _ObjectPrimitive("crypto.kem_derive")

# HKDF (key derivation)
hkdf_p = _ObjectPrimitive("crypto.hkdf")

# Runtime randomness
random_bytes_p = _ObjectPrimitive("crypto.random_bytes")
207
+
208
+
209
+ # ==============================================================================
210
+ # --- Abstract Evaluation (Type Inference)
211
+ # ==============================================================================
212
+
213
+
214
@generator_p.def_abstract_eval
def _generator_ae(curve: str = "secp256k1") -> PointType:
    """Type of the generator point: a point on ``curve``."""
    generator_type = PointType(curve=curve)
    return generator_type
217
+
218
+
219
@mul_p.def_abstract_eval
def _mul_ae(point: PointType, scalar: ScalarType) -> PointType:
    """Type of ``point * scalar``: a point on the input point's curve.

    NOTE(review): the scalar's curve is not compared with the point's curve.
    """
    return PointType(curve=point.curve)
222
+
223
+
224
@add_p.def_abstract_eval
def _add_ae(p1: PointType, p2: PointType) -> PointType:
    """Type of ``p1 + p2``: a point on ``p1``'s curve (``p2`` is not checked)."""
    return PointType(curve=p1.curve)
227
+
228
+
229
@sub_p.def_abstract_eval
def _sub_ae(p1: PointType, p2: PointType) -> PointType:
    """Type of ``p1 - p2``: a point on ``p1``'s curve (``p2`` is not checked)."""
    return PointType(curve=p1.curve)
232
+
233
+
234
@point_to_bytes_p.def_abstract_eval
def _pt_to_bytes_ae(point: elt.BaseType) -> elt.TensorType:
    """Type of a serialized point.

    A single point becomes ``u8[65]``. A tensor of points is handled
    element-wise: ``Tensor[Point, shape]`` becomes ``u8[*shape, 65]``.
    """
    # 65 bytes = uncompressed SEC1 encoding (0x04 || X || Y) of a 256-bit curve.
    leading_dims = tuple(point.shape) if isinstance(point, elt.TensorType) else ()
    return elt.TensorType(elt.u8, (*leading_dims, 65))
240
+
241
+
242
@random_scalar_p.def_abstract_eval
def _random_scalar_ae(curve: str = "secp256k1") -> ScalarType:
    """Type of a freshly sampled scalar: a scalar for ``curve``."""
    scalar_type = ScalarType(curve=curve)
    return scalar_type
245
+
246
+
247
@scalar_from_int_p.def_abstract_eval
def _scalar_from_int_ae(
    val: elt.TensorType | elt.IntegerType, curve: str = "secp256k1"
) -> ScalarType:
    """Type of an integer converted to a scalar for ``curve``.

    The type of ``val`` is not inspected here.
    """
    scalar_type = ScalarType(curve=curve)
    return scalar_type
252
+
253
+
254
@hash_p.def_abstract_eval
def _hash_ae(data: elt.BaseType) -> elt.TensorType:
    """Type of a whole-input hash: always one 32-byte digest, ``u8[32]``."""
    digest_len = 32
    return elt.TensorType(elt.u8, (digest_len,))
258
+
259
+
260
@hash_batch_p.def_abstract_eval
def _hash_batch_ae(data: elt.BaseType) -> elt.TensorType:
    """Type of row-wise hashing.

    The last axis holds the bytes of each row and is replaced by a 32-byte
    digest axis: ``(..., D) -> (..., 32)``. Inputs of rank below 2 (e.g.
    ``(D,)``) are treated as a single row and give ``(32,)``. This is kept
    deliberately lenient so ``(D,)`` behaves like the ``(1, D)`` case.

    Raises:
        TypeError: If ``data`` is not a tensor type.
    """
    if isinstance(data, elt.TensorType):
        dims = data.shape  # tuple[int | None, ...]
        batch_dims = dims[:-1] if len(dims) >= 2 else ()
        return elt.TensorType(elt.u8, (*batch_dims, 32))
    raise TypeError(f"hash_batch requires TensorType, got {data}")
278
+
279
+
280
@sym_encrypt_p.def_abstract_eval
def _sym_encrypt_ae(
    key: elt.BaseType, plaintext: elt.BaseType, *, algo: str = "aes-gcm"
) -> elt.TensorType:
    """Type of the ciphertext produced by symmetric encryption.

    The ciphertext length is not known at trace time, so the result is a 1-D
    uint8 tensor of dynamic length (``-1``). Neither the key type nor ``algo``
    is checked here; ``algo`` is validated by the backend implementation.

    Args:
        key: Symmetric encryption key.
        plaintext: Data to encrypt.
        algo: Encryption algorithm (keyword-only).

    Returns:
        ``Tensor[u8, (-1,)]``.
    """
    dynamic_len = -1
    return elt.TensorType(elt.u8, (dynamic_len,))
297
+
298
+
299
@sym_decrypt_p.def_abstract_eval
def _sym_decrypt_ae(
    key: elt.BaseType,
    ciphertext: elt.BaseType,
    *,
    target_type: elt.BaseType,
    algo: str = "aes-gcm",
) -> elt.BaseType:
    """Type of the plaintext recovered by symmetric decryption.

    A ciphertext does not carry its plaintext type, so the caller supplies it
    as ``target_type``, and that type is returned unchanged. ``algo`` is not
    checked here; it is validated by the backend implementation.

    Args:
        key: Symmetric decryption key.
        ciphertext: Encrypted data.
        target_type: Expected plaintext type (keyword-only).
        algo: Decryption algorithm (keyword-only).

    Returns:
        ``target_type``.
    """
    plaintext_type = target_type
    return plaintext_type
320
+
321
+
322
@select_p.def_abstract_eval
def _select_ae(
    cond: elt.BaseType, true_val: elt.BaseType, false_val: elt.BaseType
) -> elt.BaseType:
    """Type of ``select(cond, true_val, false_val)``, taken from ``true_val``.

    NOTE(review): ``cond`` and ``false_val`` are not inspected, so mismatched
    branch types are not caught at trace time.
    """
    result_type = true_val
    return result_type
327
+
328
+
329
@kem_keygen_p.def_abstract_eval
def _kem_keygen_ae(suite: str = "x25519") -> tuple[PrivateKeyType, PublicKeyType]:
    """Types of a generated KEM key pair: (private key, public key) of ``suite``."""
    private_type = PrivateKeyType(suite=suite)
    public_type = PublicKeyType(suite=suite)
    return private_type, public_type
332
+
333
+
334
@kem_derive_p.def_abstract_eval
def _kem_derive_ae(
    private_key: PrivateKeyType, public_key: PublicKeyType
) -> SymmetricKeyType:
    """Type of a KEM-derived shared key.

    The suite is inherited from ``private_key``. If that type has no ``suite``
    attribute, it falls back to ``"x25519"``.

    NOTE(review): the public key's suite is not compared with the private key's.
    """
    return SymmetricKeyType(getattr(private_key, "suite", "x25519"))
340
+
341
+
342
@hkdf_p.def_abstract_eval
def _hkdf_ae(
    secret: elt.BaseType, *, info: str, hash_algo: str = "sha256"
) -> SymmetricKeyType:
    """Abstract evaluation for HKDF key derivation.

    Args:
        secret: Input key material (SymmetricKeyType from kem_derive or TensorType[u8]).
            Its type is not inspected here.
        info: Context string for domain separation (required, non-empty, keyword-only).
        hash_algo: Hash algorithm name (keyword-only). It is normalized by
            lowercasing and stripping '-' and '_', so "SHA-256", "sha_256" and
            "sha256" are equivalent. Support for a given algorithm is not
            checked here.

    Returns:
        SymmetricKeyType with suite="hkdf-{normalized hash_algo}".

    Raises:
        ValueError: If ``info`` is not a string or is empty (domain separation
            is required).
        TypeError: If ``hash_algo`` is not a string or is empty.
    """
    # Validate info and hash_algo at trace time. A non-string info is reported
    # as ValueError, not TypeError; callers may already depend on that type.
    if not isinstance(info, str) or not info:
        raise ValueError(
            "HKDF requires non-empty 'info' parameter for domain separation. "
            "The info string binds the derived key to a specific protocol/context. "
            "Recommended format: 'namespace/component/purpose/version'"
        )
    if not isinstance(hash_algo, str) or not hash_algo:
        raise TypeError("hash_algo must be a non-empty string")

    # Normalize: lowercase, no hyphens or underscores.
    hash_algo_normalized = hash_algo.lower().replace("-", "").replace("_", "")

    # The composite suite records the derivation method, not the source KEM suite.
    return SymmetricKeyType(suite=f"hkdf-{hash_algo_normalized}")
375
+
376
+
377
@random_bytes_p.def_abstract_eval
def _random_bytes_ae(length: int) -> elt.TensorType:
    """Type of ``random_bytes(length)``: a ``u8[length]`` tensor.

    Raises:
        ValueError: If ``length`` is negative. A negative size cannot be
            generated; ``-1`` would otherwise be mistaken for a dynamic
            dimension.
    """
    if length < 0:
        raise ValueError(f"random_bytes length must be non-negative, got {length}")
    return elt.TensorType(elt.u8, (length,))
380
+
381
+
382
+ # ==============================================================================
383
+ # --- Helper Functions (Ops)
384
+ # ==============================================================================
385
+
386
+
387
def ec_generator(curve: str = "secp256k1") -> el.Object:
    """Return the base (generator) point G of ``curve``."""
    generator = generator_p.bind(curve=curve)
    return generator
390
+
391
+
392
def ec_mul(point: el.Object, scalar: el.Object) -> el.Object:
    """Multiply ``point`` by ``scalar`` (scalar multiplication on the curve)."""
    product = mul_p.bind(point, scalar)
    return product
395
+
396
+
397
def ec_add(p1: el.Object, p2: el.Object) -> el.Object:
    """Add two curve points: ``p1 + p2``."""
    total = add_p.bind(p1, p2)
    return total
400
+
401
+
402
def ec_sub(p1: el.Object, p2: el.Object) -> el.Object:
    """Subtract two curve points: ``p1 - p2``."""
    difference = sub_p.bind(p1, p2)
    return difference
405
+
406
+
407
def ec_point_to_bytes(point: el.Object) -> el.Object:
    """Serialize a point (or a tensor of points) to 65-byte uint8 encodings."""
    encoded = point_to_bytes_p.bind(point)
    return encoded
410
+
411
+
412
def ec_random_scalar(curve: str = "secp256k1") -> el.Object:
    """Sample a random scalar for ``curve`` at runtime."""
    scalar = random_scalar_p.bind(curve=curve)
    return scalar
415
+
416
+
417
def ec_scalar_from_int(val: el.Object, curve: str = "secp256k1") -> el.Object:
    """Turn an integer tensor ``val`` into a scalar for ``curve``."""
    scalar = scalar_from_int_p.bind(val, curve=curve)
    return scalar
420
+
421
+
422
def hash_bytes(data: el.Object) -> el.Object:
    """Hash ``data`` as one blob (SHA256), giving a 32-byte uint8 tensor."""
    digest = hash_p.bind(data)
    return digest
425
+
426
+
427
def hash_batch(data: el.Object) -> el.Object:
    """Hash every row of ``data`` independently.

    The last axis is the payload being hashed; it is replaced by a 32-byte
    digest axis:

        (N, D)    -> (N, 32)
        (B, N, D) -> (B, N, 32)
        (D,)      -> (32,)
    """
    digests = hash_batch_p.bind(data)
    return digests
435
+
436
+
437
def sym_encrypt(
    key: el.Object, plaintext: el.Object, *, algo: str = "aes-gcm"
) -> el.Object:
    """Encrypt ``plaintext`` under a symmetric ``key``.

    Args:
        key: Symmetric key (a SymmetricKeyType value, or raw bytes).
        plaintext: Any serializable object to encrypt.
        algo: Cipher name. Only "aes-gcm" is currently supported. The name is
            checked by the backend when the graph executes, not at trace time.

    Returns:
        The ciphertext as a ``Tensor[u8, (-1,)]`` (length unknown at trace time).

    Raises:
        ValueError: At runtime, if the backend does not support ``algo``.
    """
    ciphertext = sym_encrypt_p.bind(key, plaintext, algo=algo)
    return ciphertext
455
+
456
+
457
def sym_decrypt(
    key: el.Object,
    ciphertext: el.Object,
    target_type: elt.BaseType,
    *,
    algo: str = "aes-gcm",
) -> el.Object:
    """Decrypt ``ciphertext`` with a symmetric ``key``.

    Args:
        key: Symmetric key (a SymmetricKeyType value, or raw bytes).
        ciphertext: The encrypted payload.
        target_type: Type the decrypted plaintext is declared to have. It is
            needed because the ciphertext carries no type information at
            trace time.
        algo: Cipher name. It must match the one used for encryption. Only
            "aes-gcm" is currently supported, and the name is checked by the
            backend when the graph executes.

    Returns:
        The plaintext, typed as ``target_type``.

    Raises:
        ValueError: At runtime, if the backend does not support ``algo``.
    """
    plaintext = sym_decrypt_p.bind(
        key, ciphertext, target_type=target_type, algo=algo
    )
    return plaintext
481
+
482
+
483
def select(cond: el.Object, true_val: el.Object, false_val: el.Object) -> el.Object:
    """Pick ``true_val`` or ``false_val`` according to ``cond``."""
    chosen = select_p.bind(cond, true_val, false_val)
    return chosen
486
+
487
+
488
def kem_keygen(suite: str = "x25519") -> tuple[el.Object, el.Object]:
    """Create a fresh KEM key pair.

    Args:
        suite: KEM suite name (e.g. "x25519", "kyber768").

    Returns:
        ``(private_key, public_key)``.
    """
    key_pair = kem_keygen_p.bind(suite=suite)
    return key_pair
498
+
499
+
500
def kem_derive(private_key: el.Object, public_key: el.Object) -> el.Object:
    """Agree on a shared symmetric key (ECDH) from our private and their public key.

    Args:
        private_key: This party's private key.
        public_key: The peer's public key.

    Returns:
        A symmetric key usable with ``sym_encrypt`` / ``sym_decrypt``. Its type
        carries the private key's suite.
    """
    shared_key = kem_derive_p.bind(private_key, public_key)
    return shared_key
511
+
512
+
513
def hkdf(secret: el.Object, info: str, *, hash_algo: str = "sha256") -> el.Object:
    """Derive a cryptographic key from input key material using HKDF.

    HKDF (HMAC-based Key Derivation Function) is specified in RFC 5869 and
    required by NIST SP 800-56C Rev.2 for deriving symmetric keys from
    key agreement schemes like ECDH. Per NIST: "The shared secret output
    from a key-agreement scheme SHALL NOT be used directly as a cryptographic
    key. A key-derivation function (KDF) SHALL be used."

    Args:
        secret: Input key material (IKM). Accepts:
            - SymmetricKeyValue: Typically from crypto.kem_derive (ECDH output)
            - TensorType[u8, (N,)]: Raw bytes (N-byte secret)
        info: Application-specific context string for domain separation.
            REQUIRED and must be non-empty. Different info values produce
            cryptographically independent keys even from the same secret.
            Recommended format: "namespace/component/purpose/version"
            Example: "mplang/device/tee/v2"
        hash_algo: Hash function to use. At trace time it is normalized by
            lowercasing and removing '-' and '_', so "SHA-256" and "sha256"
            are equivalent.
            Currently supported: "sha256" (default)
            Future support planned: "sha512", "sha3256", "blake2b"
            Default "sha256" provides 128-bit security level.

    Returns:
        SymmetricKeyValue with:
            - suite: "hkdf-{normalized hash_algo}" (e.g., "hkdf-sha256")
            - key_bytes: 32-byte derived key suitable for AES-256-GCM

    Security considerations:
        - Output length: Fixed at 32 bytes (256 bits) for AES-256 keys
        - Salt: Uses salt=None (acceptable for ECDH output per NIST guidance)
        - Info: Provides protocol/context binding (domain separation)
        - Deterministic: Same (secret, info, hash_algo) always produces same key

    Raises:
        ValueError: At trace (abstract evaluation) time, if ``info`` is empty
            or not a string.
        TypeError: At trace time, if ``hash_algo`` is empty or not a string.
        NotImplementedError: At execution time (backend), if hash_algo is not
            "sha256". Unsupported algorithms are not rejected at trace time.

    Examples:
        >>> # Standard TEE session establishment
        >>> sk_local, pk_local = crypto.kem_keygen("x25519")
        >>> sk_remote, pk_remote = crypto.kem_keygen("x25519")
        >>> # ECDH on both sides
        >>> shared_local = crypto.kem_derive(sk_local, pk_remote)
        >>> shared_remote = crypto.kem_derive(sk_remote, pk_local)
        >>> # HKDF for domain separation
        >>> sess_local = crypto.hkdf(shared_local, "mplang/device/tee/v2")
        >>> sess_remote = crypto.hkdf(shared_remote, "mplang/device/tee/v2")
        >>> # sess_local and sess_remote have identical key_bytes
        >>> # but suite="hkdf-sha256" (not "x25519")
        >>>
        >>> # Derive multiple independent keys from one master secret
        >>> master_secret = crypto.kem_derive(sk_local, pk_remote)
        >>> encryption_key = crypto.hkdf(master_secret, "app/encryption/v1")
        >>> mac_key = crypto.hkdf(master_secret, "app/mac/v1")
        >>> # encryption_key != mac_key due to different info strings
    """
    # info and hash_algo travel as primitive attributes; validation and
    # normalization happen in the abstract-eval rule at trace time.
    return hkdf_p.bind(secret, info=info, hash_algo=hash_algo)
574
+
575
+
576
def random_bytes(length: int) -> el.Object:
    """Draw ``length`` cryptographically secure random bytes at runtime.

    Args:
        length: How many bytes to produce.

    Returns:
        A ``(length,)`` uint8 tensor.
    """
    buffer = random_bytes_p.bind(length=length)
    return buffer
586
+
587
+
588
def random_tensor(shape: tuple[int, ...], dtype: elt.ScalarType) -> el.Object:
    """Generate a cryptographically secure random tensor at runtime.

    This helper composes `random_bytes` with `tensor.run_jax`. It draws enough
    random bytes to fill ``shape`` and then reinterprets them as ``dtype``.

    Args:
        shape: Output tensor shape (e.g., (100,) or (10, 16)).
        dtype: Element type (e.g., elt.u64, elt.i32, elt.f32).

    Returns:
        Tensor[dtype, shape] with CSPRNG values. Their distribution depends
        on the dtype:

        - Integer dtypes: every bit pattern is equally likely, so values are
          uniform over the dtype's full range.
        - Boolean dtype: each element uses the low bit of its own random byte,
          so True and False are equally likely.
        - Floating-point dtypes: values are raw random bit patterns, NOT a
          uniform sample from an interval. They may include NaN and Inf.

    Example:
        >>> # Generate 100 random uint64 values
        >>> x = crypto.random_tensor((100,), elt.u64)
        >>> # Generate 10x16 random int32 matrix
        >>> y = crypto.random_tensor((10, 16), elt.i32)
    """
    import math
    from typing import cast

    from mplang.v2.dialects import dtypes, tensor

    # Byte budget: one itemsize worth of random bytes per output element.
    np_dtype = dtypes.to_numpy(dtype)
    total_elements = math.prod(shape)
    raw = random_bytes(total_elements * np_dtype.itemsize)

    if np_dtype.kind == "b":
        # Reinterpreting a random byte as bool treats every non-zero byte as
        # True (about 255/256 bias). Keep only the low bit so the result is
        # unbiased.
        def _to_tensor(b: Any) -> Any:
            return (b & 1).astype(bool).reshape(shape)

    else:
        jax_dtype = dtypes.to_jax(dtype)

        def _to_tensor(b: Any) -> Any:
            # Bit-level reinterpretation of the bytes (no value conversion).
            return b.view(jax_dtype).reshape(shape)

    return cast(el.Object, tensor.run_jax(_to_tensor, raw))
626
+
627
+
628
def random_bits(n: int) -> el.Object:
    """Generate n cryptographically secure random bits at runtime.

    Each bit is stored as a uint8 with value 0 or 1 (unpacked representation).

    Args:
        n: Number of random bits to generate (must be >= 0).

    Returns:
        (n,) uint8 Tensor with values 0 or 1.

    Raises:
        ValueError: If ``n`` is negative.

    Example:
        >>> # Generate 1024 random bits for OT selection
        >>> choice_bits = crypto.random_bits(1024)
    """
    if n < 0:
        # Without this check, -7 <= n <= -1 silently yields an empty tensor and
        # smaller n requests a negative number of random bytes.
        raise ValueError(f"random_bits requires a non-negative bit count, got {n}")

    from typing import cast

    import jax.numpy as jnp

    from mplang.v2.dialects import tensor

    # ceil(n / 8) bytes are enough to cover n bits.
    num_bytes = (n + 7) // 8
    raw = random_bytes(num_bytes)

    def _unpack_and_slice(b: Any, n: int = n) -> Any:
        # Expand each byte into 8 bits (least-significant first), then drop the
        # padding bits of the final byte.
        bits = jnp.unpackbits(b, bitorder="little")
        return bits[:n]

    return cast(el.Object, tensor.run_jax(_unpack_and_slice, raw))
658
+
659
+
660
+ # --- Bytes <-> Point Conversions ---
661
+
662
bytes_to_point_p = el.Primitive[el.Object]("crypto.ec_bytes_to_point")


@bytes_to_point_p.def_abstract_eval
def _bytes_to_point_ae(b: elt.TensorType) -> PointType:
    """Type of a deserialized point: always a secp256k1 point.

    The input byte tensor is not inspected at trace time.

    NOTE(review): unlike the other ec_* primitives this one takes no ``curve``
    attribute, so the curve is fixed here.
    """
    return PointType(curve="secp256k1")
668
+
669
+
670
def ec_bytes_to_point(b: el.Object) -> el.Object:
    """
    Deserialize bytes to an ECC point.

    Args:
        b: A (65,) uint8 Tensor representing an uncompressed point in SEC1 format.
           The first byte must be 0x04, followed by 32 bytes for X and 32 bytes for Y.
           The bytes must encode a point that actually lies on the curve.

    Returns:
        An ECC point object corresponding to the input bytes. At trace time it
        is always typed as a secp256k1 point.

    Raises:
        ValueError: If the input is not a valid 65-byte uncompressed point
            representation. The bytes are not inspected at trace time, so this
            can only surface when the graph executes.

    Example:
        >>> # Round-trip a valid point through its byte encoding
        >>> g = crypto.ec_generator()
        >>> p = crypto.ec_mul(g, crypto.ec_random_scalar())
        >>> p_bytes = crypto.ec_point_to_bytes(p)
        >>> p_again = crypto.ec_bytes_to_point(p_bytes)
    """
    return bytes_to_point_p.bind(b)