mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,723 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Crypto backend implementation using cryptography and coincurve."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import base64
20
+ import hashlib
21
+ import os
22
+ from dataclasses import dataclass
23
+ from typing import Any, ClassVar
24
+
25
+ import coincurve
26
+ import numpy as np
27
+ from cryptography.hazmat.primitives.ciphers.aead import AESGCM
28
+
29
+ from mplang.v2.backends.tensor_impl import TensorValue
30
+ from mplang.v2.dialects import crypto
31
+ from mplang.v2.edsl import serde
32
+ from mplang.v2.edsl.graph import Operation
33
+ from mplang.v2.runtime.interpreter import Interpreter
34
+ from mplang.v2.runtime.value import Value, WrapValue
35
+
36
+ # =============================================================================
37
+ # BytesValue - Wrapper for raw bytes (keys, hashes, ciphertexts)
38
+ # =============================================================================
39
+
40
+
41
+ @serde.register_class
42
+ class BytesValue(WrapValue[bytes]):
43
+ """Runtime value wrapping raw bytes.
44
+
45
+ Used for cryptographic data like:
46
+ - Hash outputs (32 bytes for SHA-256)
47
+ - Symmetric keys (32 bytes for AES-256)
48
+ - Ciphertexts (variable length)
49
+ - EC point serializations
50
+ """
51
+
52
+ _serde_kind: ClassVar[str] = "crypto_impl.BytesValue"
53
+
54
+ def _convert(self, data: Any) -> bytes:
55
+ if isinstance(data, BytesValue):
56
+ return data.unwrap()
57
+ if isinstance(data, bytes):
58
+ return data
59
+ if isinstance(data, (bytearray, memoryview)):
60
+ return bytes(data)
61
+ # Handle numpy arrays
62
+ if hasattr(data, "tobytes"):
63
+ return bytes(data.tobytes()) # type: ignore[union-attr]
64
+ raise TypeError(f"Cannot convert {type(data).__name__} to bytes")
65
+
66
+ def to_json(self) -> dict[str, Any]:
67
+ return {"data": base64.b64encode(self._data).decode("ascii")}
68
+
69
+ @classmethod
70
+ def from_json(cls, data: dict[str, Any]) -> BytesValue:
71
+ return cls(base64.b64decode(data["data"]))
72
+
73
+
74
+ # =============================================================================
75
+ # ECC Point Wrapper (secp256k1)
76
+ # =============================================================================
77
+
78
+
79
+ @serde.register_class
80
+ class ECPointValue(WrapValue[bytes]):
81
+ """Wrapper for coincurve.PublicKey representing an elliptic curve point.
82
+
83
+ This wraps the external coincurve library's PublicKey type to provide
84
+ proper serialization support via the Value base class.
85
+ """
86
+
87
+ _serde_kind: ClassVar[str] = "crypto_impl.ECPointValue"
88
+
89
+ def _convert(self, data: Any) -> bytes:
90
+ if isinstance(data, ECPointValue):
91
+ return data.unwrap()
92
+ if isinstance(data, bytes):
93
+ return data
94
+ if isinstance(data, coincurve.PublicKey):
95
+ return data.format(compressed=True)
96
+ raise TypeError(f"Expected bytes or coincurve.PublicKey, got {type(data)}")
97
+
98
+ @property
99
+ def key_bytes(self) -> bytes:
100
+ return self._data
101
+
102
+ def to_json(self) -> dict[str, Any]:
103
+ return {"data": base64.b64encode(self._data).decode("ascii")}
104
+
105
+ @classmethod
106
+ def from_json(cls, data: dict[str, Any]) -> ECPointValue:
107
+ return cls(base64.b64decode(data["data"]))
108
+
109
+ @property
110
+ def coincurve_key(self) -> coincurve.PublicKey:
111
+ """Get the underlying coincurve.PublicKey object."""
112
+ return coincurve.PublicKey(self._data)
113
+
114
+ @classmethod
115
+ def from_coincurve(cls, pk: coincurve.PublicKey) -> ECPointValue:
116
+ """Create ECPointValue from a coincurve.PublicKey."""
117
+ return cls(pk)
118
+
119
+
120
+ # --- ECC Impl (Coincurve) ---
121
+
122
# Order n of the secp256k1 base point G (SEC 2). All EC scalars in this
# module are reduced modulo this value.
N: int = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
124
+
125
+
126
@crypto.generator_p.def_impl
def generator_impl(interpreter: Interpreter, op: Operation) -> ECPointValue:
    """Return the secp256k1 base point G in SEC1 compressed form (0x02 || Gx)."""
    compressed_g = bytes.fromhex(
        "0279BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798"
    )
    return ECPointValue(compressed_g)
133
+
134
+
135
@crypto.mul_p.def_impl
def mul_impl(
    interpreter: Interpreter,
    op: Operation,
    point: ECPointValue | None,
    scalar: int | TensorValue,
) -> ECPointValue | None:
    """Scalar-multiply an EC point; ``None`` stands for the point at infinity.

    ``scalar`` is normally a Python int (from random_scalar / scalar_from_int).
    A TensorValue holding a single element and numpy integers are accepted
    too. The scalar is reduced modulo the group order ``N`` before use.

    Raises:
        TypeError: If ``scalar`` is not int-like or a TensorValue.
    """
    if isinstance(scalar, TensorValue):
        payload = scalar.unwrap()
        s_val = int(payload.item()) if hasattr(payload, "item") else int(payload)
    elif isinstance(scalar, (int, np.integer)):
        s_val = int(scalar)
    else:
        raise TypeError(
            f"mul_impl scalar must be int or TensorValue, got {type(scalar).__name__}"
        )

    reduced = s_val % N
    # k*P is the identity when k == 0 (mod N) or when P is already the identity.
    if reduced == 0 or point is None:
        return None

    # coincurve expects the scalar as 32 big-endian bytes.
    product = point.coincurve_key.multiply(reduced.to_bytes(32, "big"))
    return ECPointValue.from_coincurve(product)
172
+
173
+
174
@crypto.add_p.def_impl
def add_impl(
    interpreter: Interpreter,
    op: Operation,
    p1: ECPointValue | None,
    p2: ECPointValue | None,
) -> ECPointValue | None:
    """Add two secp256k1 points, using ``None`` for the point at infinity.

    ``None`` is the additive identity, so ``None + P == P``. When the sum
    itself is the point at infinity (``P + (-P)``), ``None`` is returned
    rather than propagating coincurve's error.
    """
    if p1 is None:
        return p2
    if p2 is None:
        return p1

    # Parse outside the try-block so a malformed encoding still raises
    # instead of being mistaken for the identity.
    k1 = p1.coincurve_key
    k2 = p2.coincurve_key
    try:
        result = k1.combine([k2])
    except ValueError:
        # libsecp256k1 cannot represent the point at infinity; combining
        # two valid keys only fails when they sum to it.
        return None
    return ECPointValue.from_coincurve(result)
187
+
188
+
189
@crypto.sub_p.def_impl
def sub_impl(
    interpreter: Interpreter,
    op: Operation,
    p1: ECPointValue | None,
    p2: ECPointValue | None,
) -> ECPointValue | None:
    """Compute ``p1 - p2`` on secp256k1, with ``None`` as the point at infinity.

    Implemented as ``p1 + (-p2)``. Returns ``None`` when the difference is the
    point at infinity (i.e. ``p1 == p2``) instead of letting coincurve raise.
    """
    if p2 is None:
        return p1

    # -P has the same x-coordinate and y' = p - y. The field prime p is odd
    # and y != 0 on secp256k1 (no 2-torsion), so y' has the opposite parity:
    # in SEC1 compressed form, negation just flips the 0x02/0x03 prefix.
    # This replaces a full scalar multiplication by N - 1.
    compressed = p2.coincurve_key.format(compressed=True)
    neg_bytes = bytes([compressed[0] ^ 0x01]) + compressed[1:]

    if p1 is None:
        return ECPointValue(neg_bytes)

    # Parse outside the try-block so malformed input still raises.
    k1 = p1.coincurve_key
    neg_p2 = coincurve.PublicKey(neg_bytes)
    try:
        result = k1.combine([neg_p2])
    except ValueError:
        # Combining valid keys only fails when the sum is the point at
        # infinity, which this module represents as None.
        return None
    return ECPointValue.from_coincurve(result)
209
+
210
+
211
@crypto.random_scalar_p.def_impl
def random_scalar_impl(interpreter: Interpreter, op: Operation) -> int:
    """Sample a uniformly random non-zero secp256k1 scalar in ``[1, N - 1]``.

    ``secrets.randbelow`` draws from the OS CSPRNG with rejection sampling,
    which avoids the (small) modulo bias of ``urandom(32) % N``. It also never
    yields 0, which ``mul`` would otherwise turn into the point at infinity.
    """
    import secrets

    return secrets.randbelow(N - 1) + 1
214
+
215
+
216
@crypto.scalar_from_int_p.def_impl
def scalar_from_int_impl(
    interpreter: Interpreter, op: Operation, val: TensorValue | int
) -> int:
    """Turn an integer-like runtime value into a Python int EC scalar.

    Accepted inputs:
      - a Python int or bool,
      - a numpy integer / bool scalar (as seen inside elementwise ops),
      - a TensorValue wrapping a single-element numpy array.

    The result is returned as-is; it is not reduced modulo ``N`` here.

    Raises:
        TypeError: For any other input type.
    """
    if isinstance(val, (int, bool, np.integer, np.bool_)):
        return int(val)
    if not isinstance(val, TensorValue):
        raise TypeError(
            f"scalar_from_int val must be TensorValue or int-like, "
            f"got {type(val).__name__}"
        )
    payload = val.unwrap()
    return int(payload.item()) if hasattr(payload, "item") else int(payload)
239
+
240
+
241
@crypto.point_to_bytes_p.def_impl
def point_to_bytes_impl(
    interpreter: Interpreter, op: Operation, point: ECPointValue | None
) -> TensorValue:
    """Serialize a point to a 65-element uint8 tensor (SEC1 uncompressed).

    The point at infinity (``None``) is encoded as 65 zero bytes so the
    output length matches the uncompressed format.
    """
    if point is None:
        return TensorValue(np.zeros(65, dtype=np.uint8))

    encoded = point.coincurve_key.format(compressed=False)
    # copy() detaches from the immutable bytes so the array is writable.
    return TensorValue(np.frombuffer(encoded, dtype=np.uint8).copy())
254
+
255
+
256
@crypto.bytes_to_point_p.def_impl
def bytes_to_point_impl(
    interpreter: Interpreter, op: Operation, b: TensorValue | BytesValue
) -> ECPointValue | None:
    """Deserialize bytes into an EC point (inverse of ``point_to_bytes``).

    A non-empty, all-zero buffer, which is the encoding ``point_to_bytes`` emits
    for the point at infinity, maps back to ``None``. Other input is wrapped
    without validation; it is parsed by coincurve only when the point is used.

    Raises:
        TypeError: If ``b`` is neither a TensorValue nor a BytesValue.
    """
    if isinstance(b, TensorValue):
        raw = b.unwrap().tobytes()
    elif isinstance(b, BytesValue):
        raw = b.unwrap()
    else:
        raise TypeError(
            f"bytes_to_point expects TensorValue or BytesValue, got {type(b)}"
        )

    # Round-trip the identity: an all-zero buffer is not a valid SEC1 point
    # and would otherwise fail later inside coincurve.
    if raw and not any(raw):
        return None

    return ECPointValue(raw)
270
+
271
+
272
+ # --- Sym / Hash Impl ---
273
+
274
+ # Supported symmetric encryption algorithms
275
+ _SUPPORTED_ALGOS = {"aes-gcm"}
276
+
277
+
278
def _validate_algo(algo: str, operation: str) -> None:
    """Raise unless ``algo`` is one of the supported symmetric algorithms.

    Args:
        algo: Algorithm name requested by the frontend.
        operation: Label used only in the error message
            (e.g. "encryption", "decryption").

    Raises:
        ValueError: If ``algo`` is not in ``_SUPPORTED_ALGOS``.
    """
    if algo in _SUPPORTED_ALGOS:
        return
    supported = ", ".join(sorted(_SUPPORTED_ALGOS))
    raise ValueError(
        f"Unsupported {operation} algorithm: {algo!r}. "
        f"Supported algorithms: {supported}"
    )
294
+
295
+
296
@crypto.hash_p.def_impl
def hash_impl(interpreter: Interpreter, op: Operation, data: Value) -> Value:
    """SHA-256 the whole input as one byte blob.

    ``data`` may be a BytesValue or a TensorValue. A tensor is serialized via
    ``tobytes()`` (C order), so its shape is not part of the hashed message.
    Returns a writable 32-element uint8 TensorValue.

    Raises:
        TypeError: If ``data`` is neither a BytesValue nor a TensorValue.
    """
    if isinstance(data, BytesValue):
        d = data.unwrap()
    elif isinstance(data, TensorValue):
        d = data.unwrap().tobytes()
    else:
        raise TypeError(
            f"hash expects BytesValue or TensorValue, got {type(data).__name__}"
        )

    digest = hashlib.sha256(d).digest()
    # np.frombuffer over immutable bytes is a read-only view; copy so the
    # result is writable, matching point_to_bytes / random_bytes.
    return TensorValue(np.frombuffer(digest, dtype=np.uint8).copy())
313
+
314
+
315
@crypto.hash_batch_p.def_impl
def hash_batch_impl(interpreter: Interpreter, op: Operation, data: Value) -> Value:
    """SHA-256 each row along the last axis of a tensor.

    - 0-D / 1-D input: the whole tensor is hashed once, giving shape ``(32,)``.
    - N-D input ``(B1, ..., Bk, D)``: every length-``D`` row is hashed
      independently, giving shape ``(B1, ..., Bk, 32)``.

    Rows are serialized with ``tobytes()`` (raw element bytes, C order).
    Empty batch axes and ``D == 0`` are handled. The returned array is writable.

    Raises:
        TypeError: If ``data`` is not a TensorValue.
    """
    if not isinstance(data, TensorValue):
        raise TypeError(f"hash_batch requires TensorValue, got {type(data)}")

    arr_in = data.unwrap()

    if arr_in.ndim <= 1:
        digest = hashlib.sha256(arr_in.tobytes()).digest()
        return TensorValue(np.frombuffer(digest, dtype=np.uint8).copy())

    batch_shape = arr_in.shape[:-1]
    row_len = arr_in.shape[-1]
    # Use an explicit row count: reshape(-1, 0) is ambiguous and raises
    # when the last axis is empty.
    num_items = int(np.prod(batch_shape))
    rows = arr_in.reshape(num_items, row_len)

    blob = b"".join(hashlib.sha256(row.tobytes()).digest() for row in rows)
    out = np.frombuffer(blob, dtype=np.uint8).reshape(*batch_shape, 32)
    # copy() so callers get a writable array rather than a view over bytes.
    return TensorValue(out.copy())
345
+
346
+
347
@crypto.sym_encrypt_p.def_impl
def sym_encrypt_impl(
    interpreter: Interpreter,
    op: Operation,
    key: SymmetricKeyValue | BytesValue | TensorValue,
    plaintext: Any,
) -> BytesValue:
    """Encrypt ``plaintext`` with AES-GCM and return ``nonce || ciphertext``.

    The plaintext is first serialized with ``serde.dumps``, so it may be a
    Value subclass (e.g. TensorValue), a numpy array, or a raw scalar (as seen
    inside elementwise ops). A fresh random 12-byte nonce is drawn per call and
    prepended to the AES-GCM output (ciphertext with the 16-byte tag appended).
    No associated data is authenticated.

    Args:
        interpreter: Runtime interpreter (unused).
        op: Operation; ``op.attrs["algo"]`` must name a supported algorithm.
        key: Key as a SymmetricKeyValue, BytesValue, or TensorValue (raw bytes
            via ``tobytes()``). AESGCM accepts 16-, 24- or 32-byte keys.
        plaintext: Any value ``serde.dumps`` can serialize.

    Raises:
        ValueError: If ``algo`` is unsupported (AESGCM also raises ValueError
            for an invalid key length).
        TypeError: If ``key`` has an unsupported type.
    """
    # The frontend is required to supply ``algo``; validate before any work.
    algo = op.attrs["algo"]
    _validate_algo(algo, "encryption")

    if isinstance(key, SymmetricKeyValue):
        k = key.key_bytes
    elif isinstance(key, BytesValue):
        k = key.unwrap()
    elif isinstance(key, TensorValue):
        k = key.unwrap().tobytes()
    else:
        raise TypeError(
            f"sym_encrypt key must be SymmetricKeyValue, BytesValue, or TensorValue, "
            f"got {type(key).__name__}"
        )

    # JSON-based serde (not pickle), so decryption cannot execute code.
    pt_bytes = serde.dumps(plaintext)

    aesgcm = AESGCM(k)
    # 96-bit random nonce per message; it must never repeat under one key.
    nonce = os.urandom(12)
    ct = aesgcm.encrypt(nonce, pt_bytes, None)

    return BytesValue(nonce + ct)
388
+
389
+
390
@crypto.sym_decrypt_p.def_impl
def sym_decrypt_impl(
    interpreter: Interpreter,
    op: Operation,
    key: SymmetricKeyValue | BytesValue | TensorValue,
    ciphertext: BytesValue,
    target_type: Any = None,
) -> Any:
    """Decrypt a ``nonce || ciphertext`` blob produced by ``sym_encrypt``.

    Returns the original plaintext value deserialized with ``serde.loads``.
    Its type matches what was encrypted: a Value subclass (TensorValue,
    BytesValue, ...), a numpy array, or a scalar when used in elementwise ops.

    Args:
        interpreter: Runtime interpreter (unused).
        op: Operation; ``op.attrs["algo"]`` must name a supported algorithm.
        key: Key as a SymmetricKeyValue, BytesValue, or TensorValue.
        ciphertext: BytesValue holding a 12-byte nonce followed by the AES-GCM
            output (ciphertext + 16-byte tag).
        target_type: Unused; kept for interface compatibility.

    Raises:
        ValueError: If ``algo`` is unsupported or the ciphertext is too short
            to contain a nonce and tag.
        TypeError: If ``key`` or ``ciphertext`` has an unsupported type.
        cryptography.exceptions.InvalidTag: If authentication fails (wrong key
            or tampered data).
    """
    algo = op.attrs["algo"]
    _validate_algo(algo, "decryption")

    if isinstance(key, SymmetricKeyValue):
        k = key.key_bytes
    elif isinstance(key, BytesValue):
        k = key.unwrap()
    elif isinstance(key, TensorValue):
        k = key.unwrap().tobytes()
    else:
        raise TypeError(
            f"sym_decrypt key must be SymmetricKeyValue, BytesValue, or TensorValue, "
            f"got {type(key).__name__}"
        )

    if not isinstance(ciphertext, BytesValue):
        raise TypeError(
            f"sym_decrypt ciphertext must be BytesValue, "
            f"got {type(ciphertext).__name__}"
        )
    ct_full = ciphertext.unwrap()

    nonce_len, tag_len = 12, 16
    # Fail with a clear message rather than AESGCM's nonce-length error or an
    # opaque tag failure on a truncated blob.
    if len(ct_full) < nonce_len + tag_len:
        raise ValueError(
            f"sym_decrypt ciphertext too short: got {len(ct_full)} bytes, "
            f"need at least {nonce_len + tag_len} (12-byte nonce + 16-byte tag)"
        )

    nonce = ct_full[:nonce_len]
    ct = ct_full[nonce_len:]

    aesgcm = AESGCM(k)
    pt_bytes = aesgcm.decrypt(nonce, ct, None)

    # JSON-based serde restores the original value type.
    return serde.loads(pt_bytes)
439
+
440
+
441
@crypto.select_p.def_impl
def select_impl(
    interpreter: Interpreter,
    op: Operation,
    cond: TensorValue | int,
    true_val: Value,
    false_val: Value,
) -> Value:
    """Return ``true_val`` when ``cond`` is non-zero, otherwise ``false_val``.

    ``cond`` is either a TensorValue holding a single element or a raw scalar
    (as produced inside elementwise ops).
    """
    if isinstance(cond, TensorValue):
        payload = cond.unwrap()
        flag = int(payload.item()) if hasattr(payload, "item") else int(payload)
    else:
        flag = int(cond)

    if flag:
        return true_val
    return false_val
460
+
461
+
462
+ # ==============================================================================
463
+ # --- KEM (Key Encapsulation Mechanism) Implementations
464
+ # ==============================================================================
465
+
466
+
467
@serde.register_class
@dataclass
class PrivateKeyValue(Value):
    """Runtime representation of a KEM private key.

    Holds the raw private-key bytes tagged with their ``suite`` name. For the
    ``"x25519"`` suite these are raw X25519 private bytes from the
    ``cryptography`` library.

    ``__repr__`` is defined explicitly to redact ``key_bytes``. The
    dataclass-generated repr would otherwise print the secret into logs and
    tracebacks; ``@dataclass`` keeps a ``__repr__`` the class defines itself.
    """

    _serde_kind: ClassVar[str] = "crypto_impl.PrivateKeyValue"

    suite: str
    key_bytes: bytes

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(suite={self.suite!r}, "
            f"key_bytes=<redacted {len(self.key_bytes)} bytes>)"
        )

    def to_json(self) -> dict[str, Any]:
        # NOTE: this serializes the secret key in (base64) cleartext.
        return {
            "suite": self.suite,
            "key_bytes": base64.b64encode(self.key_bytes).decode("ascii"),
        }

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> PrivateKeyValue:
        return cls(
            suite=data["suite"],
            key_bytes=base64.b64decode(data["key_bytes"]),
        )
494
+
495
+
496
@serde.register_class
@dataclass
class PublicKeyValue(Value):
    """Runtime representation of a KEM public key.

    Raw public-key bytes tagged with the ``suite`` they belong to.
    """

    _serde_kind: ClassVar[str] = "crypto_impl.PublicKeyValue"

    suite: str
    key_bytes: bytes

    def to_json(self) -> dict[str, Any]:
        encoded_key = base64.b64encode(self.key_bytes).decode("ascii")
        return {"suite": self.suite, "key_bytes": encoded_key}

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> PublicKeyValue:
        suite = data["suite"]
        raw_key = base64.b64decode(data["key_bytes"])
        return cls(suite=suite, key_bytes=raw_key)
521
+
522
+
523
@serde.register_class
@dataclass
class SymmetricKeyValue(Value):
    """Runtime representation of a symmetric encryption key.

    Produced in this module by ``kem_derive`` (raw ECDH shared secret) and
    ``hkdf`` (derived 32-byte key). ``sym_encrypt`` / ``sym_decrypt`` use
    ``key_bytes`` directly as the AES-GCM key.

    ``__repr__`` is defined explicitly to redact ``key_bytes``. The
    dataclass-generated repr would otherwise print the secret into logs and
    tracebacks; ``@dataclass`` keeps a ``__repr__`` the class defines itself.
    """

    _serde_kind: ClassVar[str] = "crypto_impl.SymmetricKeyValue"

    suite: str
    key_bytes: bytes

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(suite={self.suite!r}, "
            f"key_bytes=<redacted {len(self.key_bytes)} bytes>)"
        )

    def to_json(self) -> dict[str, Any]:
        # NOTE: this serializes the secret key in (base64) cleartext.
        return {
            "suite": self.suite,
            "key_bytes": base64.b64encode(self.key_bytes).decode("ascii"),
        }

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> SymmetricKeyValue:
        return cls(
            suite=data["suite"],
            key_bytes=base64.b64decode(data["key_bytes"]),
        )
549
+
550
+
551
@crypto.kem_keygen_p.def_impl
def kem_keygen_impl(
    interpreter: Interpreter, op: Operation, suite: str = "x25519"
) -> tuple[PrivateKeyValue, PublicKeyValue]:
    """Generate a KEM key pair for ``suite``.

    ``"x25519"`` uses the ``cryptography`` X25519 implementation and returns
    raw 32-byte private/public keys.

    Any other suite falls back to two *independent* 32-byte random strings:
    the public half is not derived from the private half, so this is not a
    real key pair and provides no security. A ``RuntimeWarning`` is emitted
    so the insecure path is never taken silently.
    """
    if suite != "x25519":
        import warnings

        warnings.warn(
            f"kem_keygen: unknown suite {suite!r}; falling back to random "
            "placeholder keys that provide NO cryptographic security",
            RuntimeWarning,
            stacklevel=2,
        )
        sk_bytes = os.urandom(32)
        pk_bytes = os.urandom(32)
        return (
            PrivateKeyValue(suite=suite, key_bytes=sk_bytes),
            PublicKeyValue(suite=suite, key_bytes=pk_bytes),
        )

    from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
    from cryptography.hazmat.primitives.serialization import (
        Encoding,
        NoEncryption,
        PrivateFormat,
        PublicFormat,
    )

    private_key = X25519PrivateKey.generate()
    public_key = private_key.public_key()

    sk_bytes = private_key.private_bytes(
        Encoding.Raw, PrivateFormat.Raw, NoEncryption()
    )
    pk_bytes = public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)

    return (
        PrivateKeyValue(suite=suite, key_bytes=sk_bytes),
        PublicKeyValue(suite=suite, key_bytes=pk_bytes),
    )
586
+
587
+
588
@crypto.kem_derive_p.def_impl
def kem_derive_impl(
    interpreter: Interpreter,
    op: Operation,
    private_key: PrivateKeyValue,
    public_key: PublicKeyValue,
) -> SymmetricKeyValue:
    """Derive a shared symmetric key from our private key and a peer public key.

    For ``"x25519"`` this is an X25519 ECDH exchange. The raw 32-byte shared
    secret is returned directly (no KDF is applied here; ``hkdf`` can be used
    for that).

    Any other suite XORs the key bytes. This fallback is NOT cryptographically
    secure and emits a ``RuntimeWarning``.

    Raises:
        ValueError: If the private and public keys belong to different suites,
            or (fallback) if their byte lengths differ.
    """
    suite = getattr(private_key, "suite", "x25519")
    peer_suite = getattr(public_key, "suite", suite)
    # Mixing suites would silently produce a meaningless "shared" key.
    if peer_suite != suite:
        raise ValueError(
            f"kem_derive suite mismatch: private key is {suite!r}, "
            f"public key is {peer_suite!r}"
        )

    if suite == "x25519":
        from cryptography.hazmat.primitives.asymmetric.x25519 import (
            X25519PrivateKey,
            X25519PublicKey,
        )

        sk = X25519PrivateKey.from_private_bytes(private_key.key_bytes)
        pk = X25519PublicKey.from_public_bytes(public_key.key_bytes)
        shared_secret = sk.exchange(pk)

        return SymmetricKeyValue(suite=suite, key_bytes=shared_secret)

    import warnings

    warnings.warn(
        f"kem_derive: unknown suite {suite!r}; using an XOR placeholder that "
        "provides NO cryptographic security",
        RuntimeWarning,
        stacklevel=2,
    )
    # strict=True raises ValueError when the key lengths differ.
    secret = bytes(
        a ^ b for a, b in zip(private_key.key_bytes, public_key.key_bytes, strict=True)
    )
    return SymmetricKeyValue(suite=suite, key_bytes=secret)
615
+
616
+
617
@crypto.hkdf_p.def_impl
def hkdf_impl(
    interpreter: Interpreter,
    op: Operation,
    secret: SymmetricKeyValue | BytesValue | TensorValue,
) -> SymmetricKeyValue:
    """HKDF key derivation (RFC 5869) with HMAC-SHA256.

    Derives a 32-byte key from input key material (IKM), e.g. to turn an
    ECDH shared secret into an AES-256 key. Only SHA-256 is implemented.

    Security notes:
        - ``salt=None``: per RFC 5869 an absent salt is treated as HashLen
          (here 32) zero bytes. That is acceptable only when the IKM is
          already high-entropy, such as a 256-bit ECDH shared secret. It is
          not appropriate for passwords or other low-entropy secrets.
        - ``info`` provides domain separation. Deriving several keys from the
          same IKM is fine as long as each derivation uses a distinct ``info``
          (e.g. one per session or purpose).

    Args:
        interpreter: Runtime interpreter (unused).
        op: Operation whose ``attrs`` supply ``info`` (required, non-empty
            string) and ``hash_algo`` (default ``"sha256"``; case, ``-`` and
            ``_`` are ignored).
        secret: IKM as a SymmetricKeyValue, BytesValue, or TensorValue
            (tensor contents taken via ``tobytes()``).

    Returns:
        SymmetricKeyValue with ``suite="hkdf-sha256"`` and a 32-byte key.

    Raises:
        ValueError: If ``info`` is empty.
        TypeError: If ``secret`` has an unsupported type.
        NotImplementedError: If ``hash_algo`` is not SHA-256.
    """
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives.kdf.hkdf import HKDF

    info_str = op.attrs.get("info", "")
    # Normalize spellings such as "SHA-256" / "sha_256" to "sha256".
    hash_algo = (
        op.attrs.get("hash_algo", "sha256").lower().replace("-", "").replace("_", "")
    )

    if not info_str:
        raise ValueError(
            "HKDF requires non-empty 'info' parameter for domain separation. "
            "The info string binds the derived key to a specific protocol/context. "
            "Recommended format: 'namespace/component/purpose/version'"
        )
    info_bytes = info_str.encode("utf-8")

    if isinstance(secret, SymmetricKeyValue):
        ikm = secret.key_bytes
    elif isinstance(secret, BytesValue):
        ikm = secret.unwrap()
    elif isinstance(secret, TensorValue):
        ikm = secret.unwrap().tobytes()
    else:
        raise TypeError(
            f"hkdf secret must be SymmetricKeyValue, BytesValue, or TensorValue, "
            f"got {type(secret).__name__}"
        )

    if hash_algo != "sha256":
        raise NotImplementedError(
            f"HKDF with hash algorithm '{hash_algo}' is not yet implemented. "
            "Currently only 'sha256' is supported. "
            "Planned future support: sha512, sha3256, blake2b"
        )

    hkdf = HKDF(
        algorithm=hashes.SHA256(),
        length=32,  # AES-256 key size
        salt=None,  # RFC 5869: treated as 32 zero bytes; OK for high-entropy IKM
        info=info_bytes,  # domain separation / context binding
    )
    derived_key = hkdf.derive(ikm)

    # Suite records derivation method and hash, e.g. "hkdf-sha256".
    return SymmetricKeyValue(suite=f"hkdf-{hash_algo}", key_bytes=derived_key)
710
+
711
+
712
@crypto.random_bytes_p.def_impl
def random_bytes_impl(interpreter: Interpreter, op: Operation) -> TensorValue:
    """Return ``op.attrs["length"]`` bytes from ``os.urandom`` as a uint8 tensor.

    ``length`` may be a Python int or a numpy integer.

    Raises:
        TypeError: If ``length`` is not an integer. ``bool`` is rejected even
            though it subclasses ``int``.
        ValueError: If ``length`` is negative.
    """
    length = op.attrs["length"]

    # bool is an int subclass; True/False as a byte count is almost surely a bug.
    if isinstance(length, bool) or not isinstance(length, (int, np.integer)):
        raise TypeError(f"random_bytes length must be int, got {type(length)}")
    length = int(length)
    if length < 0:
        raise ValueError(f"random_bytes length must be non-negative, got {length}")

    b = os.urandom(length)
    # copy() detaches from the immutable bytes so the array is writable.
    return TensorValue(np.frombuffer(b, dtype=np.uint8).copy())