mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,723 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Crypto backend implementation using cryptography and coincurve."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import base64
|
|
20
|
+
import hashlib
|
|
21
|
+
import os
|
|
22
|
+
from dataclasses import dataclass
|
|
23
|
+
from typing import Any, ClassVar
|
|
24
|
+
|
|
25
|
+
import coincurve
|
|
26
|
+
import numpy as np
|
|
27
|
+
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
|
28
|
+
|
|
29
|
+
from mplang.v2.backends.tensor_impl import TensorValue
|
|
30
|
+
from mplang.v2.dialects import crypto
|
|
31
|
+
from mplang.v2.edsl import serde
|
|
32
|
+
from mplang.v2.edsl.graph import Operation
|
|
33
|
+
from mplang.v2.runtime.interpreter import Interpreter
|
|
34
|
+
from mplang.v2.runtime.value import Value, WrapValue
|
|
35
|
+
|
|
36
|
+
# =============================================================================
|
|
37
|
+
# BytesValue - Wrapper for raw bytes (keys, hashes, ciphertexts)
|
|
38
|
+
# =============================================================================
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@serde.register_class
|
|
42
|
+
class BytesValue(WrapValue[bytes]):
|
|
43
|
+
"""Runtime value wrapping raw bytes.
|
|
44
|
+
|
|
45
|
+
Used for cryptographic data like:
|
|
46
|
+
- Hash outputs (32 bytes for SHA-256)
|
|
47
|
+
- Symmetric keys (32 bytes for AES-256)
|
|
48
|
+
- Ciphertexts (variable length)
|
|
49
|
+
- EC point serializations
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
_serde_kind: ClassVar[str] = "crypto_impl.BytesValue"
|
|
53
|
+
|
|
54
|
+
def _convert(self, data: Any) -> bytes:
|
|
55
|
+
if isinstance(data, BytesValue):
|
|
56
|
+
return data.unwrap()
|
|
57
|
+
if isinstance(data, bytes):
|
|
58
|
+
return data
|
|
59
|
+
if isinstance(data, (bytearray, memoryview)):
|
|
60
|
+
return bytes(data)
|
|
61
|
+
# Handle numpy arrays
|
|
62
|
+
if hasattr(data, "tobytes"):
|
|
63
|
+
return bytes(data.tobytes()) # type: ignore[union-attr]
|
|
64
|
+
raise TypeError(f"Cannot convert {type(data).__name__} to bytes")
|
|
65
|
+
|
|
66
|
+
def to_json(self) -> dict[str, Any]:
|
|
67
|
+
return {"data": base64.b64encode(self._data).decode("ascii")}
|
|
68
|
+
|
|
69
|
+
@classmethod
|
|
70
|
+
def from_json(cls, data: dict[str, Any]) -> BytesValue:
|
|
71
|
+
return cls(base64.b64decode(data["data"]))
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# =============================================================================
|
|
75
|
+
# ECC Point Wrapper (secp256k1)
|
|
76
|
+
# =============================================================================
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@serde.register_class
|
|
80
|
+
class ECPointValue(WrapValue[bytes]):
|
|
81
|
+
"""Wrapper for coincurve.PublicKey representing an elliptic curve point.
|
|
82
|
+
|
|
83
|
+
This wraps the external coincurve library's PublicKey type to provide
|
|
84
|
+
proper serialization support via the Value base class.
|
|
85
|
+
"""
|
|
86
|
+
|
|
87
|
+
_serde_kind: ClassVar[str] = "crypto_impl.ECPointValue"
|
|
88
|
+
|
|
89
|
+
def _convert(self, data: Any) -> bytes:
|
|
90
|
+
if isinstance(data, ECPointValue):
|
|
91
|
+
return data.unwrap()
|
|
92
|
+
if isinstance(data, bytes):
|
|
93
|
+
return data
|
|
94
|
+
if isinstance(data, coincurve.PublicKey):
|
|
95
|
+
return data.format(compressed=True)
|
|
96
|
+
raise TypeError(f"Expected bytes or coincurve.PublicKey, got {type(data)}")
|
|
97
|
+
|
|
98
|
+
@property
|
|
99
|
+
def key_bytes(self) -> bytes:
|
|
100
|
+
return self._data
|
|
101
|
+
|
|
102
|
+
def to_json(self) -> dict[str, Any]:
|
|
103
|
+
return {"data": base64.b64encode(self._data).decode("ascii")}
|
|
104
|
+
|
|
105
|
+
@classmethod
|
|
106
|
+
def from_json(cls, data: dict[str, Any]) -> ECPointValue:
|
|
107
|
+
return cls(base64.b64decode(data["data"]))
|
|
108
|
+
|
|
109
|
+
@property
|
|
110
|
+
def coincurve_key(self) -> coincurve.PublicKey:
|
|
111
|
+
"""Get the underlying coincurve.PublicKey object."""
|
|
112
|
+
return coincurve.PublicKey(self._data)
|
|
113
|
+
|
|
114
|
+
@classmethod
|
|
115
|
+
def from_coincurve(cls, pk: coincurve.PublicKey) -> ECPointValue:
|
|
116
|
+
"""Create ECPointValue from a coincurve.PublicKey."""
|
|
117
|
+
return cls(pk)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
# --- ECC Impl (Coincurve) ---
|
|
121
|
+
|
|
122
|
+
# Order of the secp256k1 group generated by G. Scalars are reduced modulo N.
N: int = int("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141", 16)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@crypto.generator_p.def_impl
def generator_impl(interpreter: Interpreter, op: Operation) -> ECPointValue:
    """Return the secp256k1 base point G.

    G is encoded in compressed SEC1 form: the 0x02 prefix byte followed by
    the 32-byte big-endian x-coordinate.
    """
    compressed_prefix = b"\x02"
    g_x = bytes.fromhex(
        "79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798"
    )
    return ECPointValue(compressed_prefix + g_x)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
@crypto.mul_p.def_impl
def mul_impl(
    interpreter: Interpreter,
    op: Operation,
    point: ECPointValue | None,
    scalar: int | TensorValue,
) -> ECPointValue | None:
    """Scalar multiplication ``scalar * point`` on secp256k1.

    ``None`` represents the point at infinity. The scalar may be a Python int,
    a numpy integer, or a TensorValue holding a single value. It is reduced
    modulo ``N`` before use.

    Returns:
        The product point, or ``None`` when the reduced scalar is zero or
        ``point`` is ``None``.

    Raises:
        TypeError: If ``scalar`` is not int-like or a TensorValue.
    """
    # Coerce the scalar to a Python int before anything else so that bad
    # types are reported even when the point is the identity.
    if isinstance(scalar, TensorValue):
        raw = scalar.unwrap()
        s_val = int(raw.item()) if hasattr(raw, "item") else int(raw)
    elif isinstance(scalar, (int, np.integer)):
        s_val = int(scalar)
    else:
        raise TypeError(
            f"mul_impl scalar must be int or TensorValue, got {type(scalar).__name__}"
        )

    reduced = s_val % N
    if reduced == 0 or point is None:
        return None

    # coincurve expects the scalar as 32 big-endian bytes.
    product = point.coincurve_key.multiply(reduced.to_bytes(32, "big"))
    return ECPointValue.from_coincurve(product)
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
@crypto.add_p.def_impl
def add_impl(
    interpreter: Interpreter,
    op: Operation,
    p1: ECPointValue | None,
    p2: ECPointValue | None,
) -> ECPointValue | None:
    """Add two secp256k1 points, using ``None`` as the point at infinity.

    ``None`` is the additive identity, so ``None + P == P``. When the operands
    are negations of each other, the sum is the point at infinity and ``None``
    is returned. Negations share an x-coordinate and have opposite y parity.
    coincurve's ``combine`` cannot return an infinity point (libsecp256k1
    treats such a sum as invalid), so this case is detected up front.

    Returns:
        The sum point, or ``None`` for the point at infinity.
    """
    if p1 is None:
        return p2
    if p2 is None:
        return p1

    # Parse each point once and compare in compressed SEC1 form
    # (parity byte + x-coordinate). This lines up values stored in either
    # compressed or uncompressed encoding.
    k1 = p1.coincurve_key
    k2 = p2.coincurve_key
    c1 = k1.format(compressed=True)
    c2 = k2.format(compressed=True)
    if c1[1:] == c2[1:] and c1[0] != c2[0]:
        # P + (-P) = O (point at infinity)
        return None

    result = k1.combine([k2])
    return ECPointValue.from_coincurve(result)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
@crypto.sub_p.def_impl
def sub_impl(
    interpreter: Interpreter,
    op: Operation,
    p1: ECPointValue | None,
    p2: ECPointValue | None,
) -> ECPointValue | None:
    """Compute ``p1 - p2`` on secp256k1, using ``None`` as the point at infinity.

    Results by case:

    - ``P - None == P``
    - ``None - P == -P``
    - ``P - P == None``

    The last case is handled explicitly because coincurve's ``combine``
    cannot produce the point at infinity.

    Returns:
        The difference point, or ``None`` for the point at infinity.
    """
    if p2 is None:
        return p1

    # Negating a point keeps x and flips the y parity. In compressed SEC1 form
    # that is just toggling the prefix byte between 0x02 and 0x03. This is
    # equivalent to, and cheaper than, multiplying by (N - 1).
    c2 = p2.coincurve_key.format(compressed=True)
    neg_p2 = ECPointValue(bytes([c2[0] ^ 0x01]) + c2[1:])

    if p1 is None:
        return neg_p2

    k1 = p1.coincurve_key
    if k1.format(compressed=True) == c2:
        # P - P = O (point at infinity)
        return None

    result = k1.combine([neg_p2.coincurve_key])
    return ECPointValue.from_coincurve(result)
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
@crypto.random_scalar_p.def_impl
def random_scalar_impl(interpreter: Interpreter, op: Operation) -> int:
    """Sample a uniformly random non-zero secp256k1 scalar in ``[1, N - 1]``.

    The previous ``urandom(32) % N`` reduction was very slightly biased. It
    could also return 0, which is not a valid secret scalar: ``mul`` maps it
    to the point at infinity. ``secrets.randbelow`` draws from a CSPRNG
    without modulo bias.
    """
    import secrets

    return secrets.randbelow(N - 1) + 1
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
@crypto.scalar_from_int_p.def_impl
|
|
217
|
+
def scalar_from_int_impl(
|
|
218
|
+
interpreter: Interpreter, op: Operation, val: TensorValue | int
|
|
219
|
+
) -> int:
|
|
220
|
+
"""Convert a tensor/scalar value to an EC scalar (int).
|
|
221
|
+
|
|
222
|
+
val can be:
|
|
223
|
+
- TensorValue: wrapping a scalar numpy array
|
|
224
|
+
- int/bool: direct Python integer or boolean
|
|
225
|
+
- numpy scalar (np.integer, np.bool_): from inside elementwise operations
|
|
226
|
+
"""
|
|
227
|
+
if isinstance(val, TensorValue):
|
|
228
|
+
raw = val.unwrap()
|
|
229
|
+
if hasattr(raw, "item"):
|
|
230
|
+
return int(raw.item())
|
|
231
|
+
return int(raw)
|
|
232
|
+
elif isinstance(val, (int, bool, np.integer, np.bool_)):
|
|
233
|
+
return int(val)
|
|
234
|
+
else:
|
|
235
|
+
raise TypeError(
|
|
236
|
+
f"scalar_from_int val must be TensorValue or int-like, "
|
|
237
|
+
f"got {type(val).__name__}"
|
|
238
|
+
)
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
@crypto.point_to_bytes_p.def_impl
def point_to_bytes_impl(
    interpreter: Interpreter, op: Operation, point: ECPointValue | None
) -> TensorValue:
    """Serialize a point as a 65-element ``uint8`` tensor.

    Real points use the uncompressed SEC1 encoding. The point at infinity
    (``None``) is encoded as 65 zero bytes so that the output length is
    always the same.
    """
    if point is None:
        return TensorValue(np.zeros(65, dtype=np.uint8))

    encoded = point.coincurve_key.format(compressed=False)
    # Build a fresh, writable array from the immutable bytes.
    return TensorValue(np.array(list(encoded), dtype=np.uint8))
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
@crypto.bytes_to_point_p.def_impl
def bytes_to_point_impl(
    interpreter: Interpreter, op: Operation, b: TensorValue | BytesValue
) -> ECPointValue | None:
    """Deserialize a point from a byte tensor or BytesValue.

    This is the inverse of ``point_to_bytes_impl``, which encodes the point at
    infinity (``None``) as all-zero bytes. No valid SEC1 point encoding is all
    zeros, so an all-zero, non-empty input maps back to ``None``. That keeps
    the round trip lossless. Any other input is wrapped as-is; it is parsed
    lazily by ``ECPointValue.coincurve_key``.

    Returns:
        The decoded ECPointValue, or ``None`` for the point at infinity.

    Raises:
        TypeError: If ``b`` is neither a TensorValue nor a BytesValue.
    """
    if isinstance(b, TensorValue):
        raw = b.unwrap().tobytes()
    elif isinstance(b, BytesValue):
        raw = b.unwrap()
    else:
        raise TypeError(
            f"bytes_to_point expects TensorValue or BytesValue, got {type(b)}"
        )

    if raw and not any(raw):
        # Identity sentinel produced by point_to_bytes_impl.
        return None

    return ECPointValue(raw)
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
# --- Sym / Hash Impl ---
|
|
273
|
+
|
|
274
|
+
# Symmetric encryption algorithm names accepted for the "algo" attribute.
# Membership is checked by _validate_algo.
_SUPPORTED_ALGOS: set[str] = {"aes-gcm"}
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def _validate_algo(algo: str, operation: str) -> None:
    """Reject symmetric-cipher names that are not in ``_SUPPORTED_ALGOS``.

    Args:
        algo: Algorithm name to check.
        operation: Human-readable operation name used in the error message,
            e.g. ``"encryption"`` or ``"decryption"``.

    Raises:
        ValueError: If ``algo`` is unsupported. The message lists the
            supported names in sorted order.
    """
    if algo in _SUPPORTED_ALGOS:
        return
    supported = ", ".join(sorted(_SUPPORTED_ALGOS))
    raise ValueError(
        f"Unsupported {operation} algorithm: {algo!r}. Supported algorithms: {supported}"
    )
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
@crypto.hash_p.def_impl
def hash_impl(interpreter: Interpreter, op: Operation, data: Value) -> Value:
    """Hash input data using SHA-256 (strict single blob).

    A BytesValue is hashed as-is. A TensorValue is hashed over the bytes
    returned by ``tobytes()`` on its wrapped array, so the whole tensor is
    one message regardless of shape.

    Returns:
        TensorValue wrapping a writable ``uint8`` array of length 32.

    Raises:
        TypeError: If ``data`` is neither a BytesValue nor a TensorValue.
    """
    if isinstance(data, BytesValue):
        payload = data.unwrap()
    elif isinstance(data, TensorValue):
        payload = data.unwrap().tobytes()
    else:
        raise TypeError(
            f"hash expects BytesValue or TensorValue, got {type(data).__name__}"
        )

    digest = hashlib.sha256(payload).digest()
    # np.frombuffer over immutable bytes yields a read-only view. Copy so the
    # result is writable, consistent with point_to_bytes_impl.
    return TensorValue(np.frombuffer(digest, dtype=np.uint8).copy())
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
@crypto.hash_batch_p.def_impl
def hash_batch_impl(interpreter: Interpreter, op: Operation, data: Value) -> Value:
    """Hash data treating last dimension as bytes (explicit batching).

    Input handling:

    - 0-D or 1-D input: the whole array is hashed once, giving shape ``(32,)``.
    - Input of shape ``(B1, ..., Bk, D)``: each length-``D`` row along the
      last axis is hashed separately, giving shape ``(B1, ..., Bk, 32)``.

    Rows of length 0 are valid and hash the empty message.

    Returns:
        TensorValue wrapping a writable ``uint8`` array.

    Raises:
        TypeError: If ``data`` is not a TensorValue.
    """
    if not isinstance(data, TensorValue):
        raise TypeError(f"hash_batch requires TensorValue, got {type(data)}")

    arr_in = data.unwrap()

    if arr_in.ndim <= 1:
        digest = hashlib.sha256(arr_in.tobytes()).digest()
        return TensorValue(np.frombuffer(digest, dtype=np.uint8).copy())

    batch_shape = arr_in.shape[:-1]
    row_len = arr_in.shape[-1]
    num_items = int(np.prod(batch_shape))

    # Use an explicit row count: reshape(-1, 0) is ambiguous and raises when
    # the trailing dimension is empty.
    rows = arr_in.reshape(num_items, row_len)

    # Pre-allocate the output so the result is writable and the empty-batch
    # case needs no special handling.
    out = np.empty((num_items, 32), dtype=np.uint8)
    for i, row in enumerate(rows):
        out[i] = np.frombuffer(hashlib.sha256(row.tobytes()).digest(), dtype=np.uint8)

    return TensorValue(out.reshape(*batch_shape, 32))
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
@crypto.sym_encrypt_p.def_impl
|
|
348
|
+
def sym_encrypt_impl(
|
|
349
|
+
interpreter: Interpreter,
|
|
350
|
+
op: Operation,
|
|
351
|
+
key: SymmetricKeyValue | BytesValue,
|
|
352
|
+
plaintext: Any,
|
|
353
|
+
) -> BytesValue:
|
|
354
|
+
"""Encrypt plaintext using AES-GCM with the given symmetric key.
|
|
355
|
+
|
|
356
|
+
The plaintext can be any JSON-serializable value (Value subclasses,
|
|
357
|
+
numpy arrays, scalars, etc.). This supports both high-level API usage
|
|
358
|
+
(with TensorValue) and elementwise operations (with raw scalars).
|
|
359
|
+
"""
|
|
360
|
+
# Read and validate algo parameter (must be provided by frontend)
|
|
361
|
+
algo = op.attrs["algo"]
|
|
362
|
+
_validate_algo(algo, "encryption")
|
|
363
|
+
|
|
364
|
+
# Get raw key bytes - strict type checking
|
|
365
|
+
if isinstance(key, SymmetricKeyValue):
|
|
366
|
+
k = key.key_bytes
|
|
367
|
+
elif isinstance(key, BytesValue):
|
|
368
|
+
k = key.unwrap()
|
|
369
|
+
elif isinstance(key, TensorValue):
|
|
370
|
+
k = key.unwrap().tobytes()
|
|
371
|
+
else:
|
|
372
|
+
raise TypeError(
|
|
373
|
+
f"sym_encrypt key must be SymmetricKeyValue, BytesValue, or TensorValue, "
|
|
374
|
+
f"got {type(key).__name__}"
|
|
375
|
+
)
|
|
376
|
+
|
|
377
|
+
# Serialize the plaintext using secure JSON serde
|
|
378
|
+
# serde.dumps handles Value subclasses, numpy arrays, scalars, etc.
|
|
379
|
+
pt_bytes = serde.dumps(plaintext)
|
|
380
|
+
|
|
381
|
+
# AES-GCM encryption
|
|
382
|
+
aesgcm = AESGCM(k)
|
|
383
|
+
nonce = os.urandom(12)
|
|
384
|
+
ct = aesgcm.encrypt(nonce, pt_bytes, None)
|
|
385
|
+
|
|
386
|
+
# Result: nonce + ct
|
|
387
|
+
return BytesValue(nonce + ct)
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
@crypto.sym_decrypt_p.def_impl
def sym_decrypt_impl(
    interpreter: Interpreter,
    op: Operation,
    key: SymmetricKeyValue | BytesValue,
    ciphertext: BytesValue,
    target_type: Any = None,
) -> Any:
    """Decrypt ciphertext using AES-GCM with the given symmetric key.

    The ciphertext layout is ``nonce (12 bytes) || AES-GCM ciphertext+tag``,
    as produced by ``sym_encrypt_impl``. The decrypted payload is decoded with
    ``serde.loads``, so the result has whatever type was encrypted. That may
    be a Value subclass, a numpy array, or a plain scalar.

    Args:
        interpreter: Runtime interpreter (unused).
        op: Operation whose ``attrs["algo"]`` selects the cipher.
        key: SymmetricKeyValue, BytesValue, or TensorValue holding key bytes.
        ciphertext: BytesValue produced by ``sym_encrypt_impl``.
        target_type: Accepted for interface compatibility; not used here.

    Raises:
        ValueError: If the algorithm is unsupported, or the ciphertext is too
            short to contain a nonce and an authentication tag.
        TypeError: If ``key`` or ``ciphertext`` has an unsupported type.
        cryptography.exceptions.InvalidTag: If authentication fails (wrong key
            or tampered data).
    """
    nonce_len = 12  # must match sym_encrypt_impl
    tag_len = 16  # AES-GCM authentication tag length used by AESGCM

    algo = op.attrs["algo"]
    _validate_algo(algo, "decryption")

    if isinstance(key, SymmetricKeyValue):
        k = key.key_bytes
    elif isinstance(key, BytesValue):
        k = key.unwrap()
    elif isinstance(key, TensorValue):
        k = key.unwrap().tobytes()
    else:
        raise TypeError(
            f"sym_decrypt key must be SymmetricKeyValue, BytesValue, or TensorValue, "
            f"got {type(key).__name__}"
        )

    if not isinstance(ciphertext, BytesValue):
        raise TypeError(
            f"sym_decrypt ciphertext must be BytesValue, "
            f"got {type(ciphertext).__name__}"
        )
    ct_full = ciphertext.unwrap()

    # Fail clearly on truncated input instead of passing a short nonce or an
    # empty ciphertext to AESGCM, which reports it as a nonce or tag error.
    if len(ct_full) < nonce_len + tag_len:
        raise ValueError(
            f"sym_decrypt ciphertext too short: expected at least "
            f"{nonce_len + tag_len} bytes (nonce + tag), got {len(ct_full)}"
        )

    nonce = ct_full[:nonce_len]
    ct = ct_full[nonce_len:]

    pt_bytes = AESGCM(k).decrypt(nonce, ct, None)

    # Returns the original type that was encrypted.
    return serde.loads(pt_bytes)
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
@crypto.select_p.def_impl
def select_impl(
    interpreter: Interpreter,
    op: Operation,
    cond: TensorValue | int,
    true_val: Value,
    false_val: Value,
) -> Value:
    """Return ``true_val`` if ``cond`` is non-zero, otherwise ``false_val``.

    ``cond`` may be a TensorValue holding a single value (read via ``.item()``
    when available) or a raw scalar passed in from elementwise evaluation.
    The chosen value is returned as-is, without copying.
    """
    if isinstance(cond, TensorValue):
        raw = cond.unwrap()
        flag = int(raw.item()) if hasattr(raw, "item") else int(raw)
    else:
        flag = int(cond)

    if flag:
        return true_val
    return false_val
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
# ==============================================================================
|
|
463
|
+
# --- KEM (Key Encapsulation Mechanism) Implementations
|
|
464
|
+
# ==============================================================================
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
@serde.register_class
@dataclass
class PrivateKeyValue(Value):
    """Runtime representation of a KEM private key.

    Holds the suite name and the raw secret key bytes. For the ``"x25519"``
    suite these are the raw X25519 private-key bytes produced by the
    ``cryptography`` library.
    """

    _serde_kind: ClassVar[str] = "crypto_impl.PrivateKeyValue"

    suite: str
    key_bytes: bytes

    def __repr__(self) -> str:
        # The dataclass-generated __repr__ would print key_bytes, leaking the
        # secret key into logs and tracebacks. @dataclass keeps a __repr__
        # defined in the class body instead of generating its own.
        return (
            f"{type(self).__name__}(suite={self.suite!r}, "
            f"key_bytes=<redacted {len(self.key_bytes)} bytes>)"
        )

    def to_json(self) -> dict[str, Any]:
        """Serialize as ``{"suite": ..., "key_bytes": <base64 text>}``."""
        return {
            "suite": self.suite,
            "key_bytes": base64.b64encode(self.key_bytes).decode("ascii"),
        }

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> PrivateKeyValue:
        """Rebuild a PrivateKeyValue from the dict produced by ``to_json``."""
        return cls(
            suite=data["suite"],
            key_bytes=base64.b64decode(data["key_bytes"]),
        )
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
@serde.register_class
@dataclass
class PublicKeyValue(Value):
    """Runtime value for a KEM public key: a suite name plus raw key bytes."""

    _serde_kind: ClassVar[str] = "crypto_impl.PublicKeyValue"

    suite: str
    key_bytes: bytes

    def to_json(self) -> dict[str, Any]:
        """Serialize as ``{"suite": ..., "key_bytes": <base64 text>}``."""
        encoded_key = base64.b64encode(self.key_bytes).decode("ascii")
        return {"suite": self.suite, "key_bytes": encoded_key}

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> PublicKeyValue:
        """Rebuild a PublicKeyValue from the dict produced by ``to_json``."""
        suite = data["suite"]
        decoded_key = base64.b64decode(data["key_bytes"])
        return cls(suite=suite, key_bytes=decoded_key)
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
@serde.register_class
@dataclass
class SymmetricKeyValue(Value):
    """Runtime representation of a symmetric encryption key.

    Produced by KEM derivation (``kem_derive_impl``) and consumed as the
    AES-GCM key by ``sym_encrypt_impl`` / ``sym_decrypt_impl``.
    """

    _serde_kind: ClassVar[str] = "crypto_impl.SymmetricKeyValue"

    suite: str
    key_bytes: bytes

    def __repr__(self) -> str:
        # The dataclass-generated __repr__ would print key_bytes, leaking the
        # secret key into logs and tracebacks. @dataclass keeps a __repr__
        # defined in the class body instead of generating its own.
        return (
            f"{type(self).__name__}(suite={self.suite!r}, "
            f"key_bytes=<redacted {len(self.key_bytes)} bytes>)"
        )

    def to_json(self) -> dict[str, Any]:
        """Serialize as ``{"suite": ..., "key_bytes": <base64 text>}``."""
        return {
            "suite": self.suite,
            "key_bytes": base64.b64encode(self.key_bytes).decode("ascii"),
        }

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> SymmetricKeyValue:
        """Rebuild a SymmetricKeyValue from the dict produced by ``to_json``."""
        return cls(
            suite=data["suite"],
            key_bytes=base64.b64decode(data["key_bytes"]),
        )
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
@crypto.kem_keygen_p.def_impl
def kem_keygen_impl(
    interpreter: Interpreter, op: Operation, suite: str = "x25519"
) -> tuple[PrivateKeyValue, PublicKeyValue]:
    """Generate a ``(private, public)`` KEM key pair for ``suite``.

    For ``"x25519"`` a real X25519 key pair is generated and exported as raw
    bytes. Any other suite name falls back to two independent 32-byte random
    strings.

    NOTE(review): in the fallback, the public key is not derived from the
    private key, so it offers no cryptographic security. Confirm that
    unknown suites are only used in tests.

    NOTE(review): ``suite`` arrives as a parameter with a default. If the
    dialect passes it via ``op.attrs`` rather than as an operand, this always
    stays ``"x25519"``. Confirm against the dialect definition.
    """
    if suite != "x25519":
        sk_bytes = os.urandom(32)
        pk_bytes = os.urandom(32)
    else:
        from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
        from cryptography.hazmat.primitives.serialization import (
            Encoding,
            NoEncryption,
            PrivateFormat,
            PublicFormat,
        )

        secret_key = X25519PrivateKey.generate()
        sk_bytes = secret_key.private_bytes(
            Encoding.Raw, PrivateFormat.Raw, NoEncryption()
        )
        pk_bytes = secret_key.public_key().public_bytes(Encoding.Raw, PublicFormat.Raw)

    return (
        PrivateKeyValue(suite=suite, key_bytes=sk_bytes),
        PublicKeyValue(suite=suite, key_bytes=pk_bytes),
    )
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
@crypto.kem_derive_p.def_impl
def kem_derive_impl(
    interpreter: Interpreter,
    op: Operation,
    private_key: PrivateKeyValue,
    public_key: PublicKeyValue,
) -> SymmetricKeyValue:
    """Derive a shared symmetric key from our private key and a peer public key.

    The suite is taken from ``private_key`` (default ``"x25519"``). For
    ``"x25519"`` this performs an X25519 exchange and returns the raw shared
    secret as the key, without applying a KDF here. Callers presumably run it
    through ``hkdf_p`` (TODO confirm).

    NOTE(review): other suites fall back to XOR-ing the two key byte strings,
    which is not cryptographically secure. It is also not symmetric between
    parties: ``sk_A ^ pk_B`` differs from ``sk_B ^ pk_A`` in general, so two
    parties will not agree on a key. ``zip(..., strict=True)`` raises
    ``ValueError`` if the lengths differ.
    """
    suite = getattr(private_key, "suite", "x25519")

    if suite != "x25519":
        mixed = bytes(
            s ^ p
            for s, p in zip(private_key.key_bytes, public_key.key_bytes, strict=True)
        )
        return SymmetricKeyValue(suite=suite, key_bytes=mixed)

    from cryptography.hazmat.primitives.asymmetric.x25519 import (
        X25519PrivateKey,
        X25519PublicKey,
    )

    own_secret = X25519PrivateKey.from_private_bytes(private_key.key_bytes)
    peer_public = X25519PublicKey.from_public_bytes(public_key.key_bytes)
    return SymmetricKeyValue(suite=suite, key_bytes=own_secret.exchange(peer_public))
|
|
615
|
+
|
|
616
|
+
|
|
617
|
+
@crypto.hkdf_p.def_impl
def hkdf_impl(
    interpreter: Interpreter,
    op: Operation,
    secret: SymmetricKeyValue | TensorValue,
) -> SymmetricKeyValue:
    """HKDF key derivation implementation using SHA-256.

    Implements RFC 5869 HKDF (extract-then-expand) with HMAC-SHA256. This
    matches the two-step (extraction-then-expansion) key derivation described
    in NIST SP 800-56C Rev. 2 for turning a shared secret, such as an ECDH
    output, into a symmetric key.

    Only SHA-256 is currently implemented; other hash algorithms raise
    NotImplementedError. SHA-512, SHA3-256 and BLAKE2b are planned.

    Security Notes:
        - No salt is supplied (``salt=None``). Per RFC 5869 §2.2, HKDF-Extract
          then uses a string of HashLen zero bytes (32 bytes for SHA-256).
        - Without a salt, security rests on the entropy of the input key
          material (IKM). This is acceptable for high-entropy IKM such as a
          256-bit ECDH shared secret. It is NOT suitable for passwords or other
          low-entropy secrets.
        - Derivation is deterministic: the same IKM and ``info`` always yield
          the same key. To derive distinct keys from one ECDH key pair (e.g. per
          session), use a unique ``info`` value for each derivation.
        - RFC 5869 §3.1 notes that a random salt strengthens HKDF even for
          high-entropy IKM; this implementation does not use one.

    Args:
        interpreter: Runtime interpreter context (not used by this kernel).
        op: Operation node. ``op.attrs["info"]`` (required, non-empty str) is
            the context string for domain separation. ``op.attrs["hash_algo"]``
            (default ``"sha256"``) is matched case-insensitively, with ``-`` and
            ``_`` ignored.
        secret: Input key material (IKM) as SymmetricKeyValue (its
            ``key_bytes``) or TensorValue (its unwrapped array's raw bytes).
            Should be high-entropy (e.g. a 256-bit ECDH shared secret), since
            no salt is used.

    Returns:
        SymmetricKeyValue with suite ``"hkdf-{hash_algo}"`` and 32-byte
        ``key_bytes``.

    Raises:
        ValueError: If the ``info`` attribute is missing or empty (it is
            required for domain separation).
        TypeError: If ``secret`` is not a SymmetricKeyValue or TensorValue.
        NotImplementedError: If the normalized ``hash_algo`` is not ``"sha256"``.
    """
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives.kdf.hkdf import HKDF

    # Extract operation attributes. hash_algo is normalized so that e.g.
    # "SHA-256" and "sha_256" both become "sha256".
    info_str = op.attrs.get("info", "")
    hash_algo = (
        op.attrs.get("hash_algo", "sha256").lower().replace("-", "").replace("_", "")
    )

    # The info string binds the derived key to a protocol/context, so an
    # empty one is rejected.
    if not info_str:
        raise ValueError(
            "HKDF requires non-empty 'info' parameter for domain separation. "
            "The info string binds the derived key to a specific protocol/context. "
            "Recommended format: 'namespace/component/purpose/version'"
        )

    info_bytes = info_str.encode("utf-8")

    # Extract input key material (IKM) bytes
    if isinstance(secret, SymmetricKeyValue):
        ikm = secret.key_bytes
    elif isinstance(secret, TensorValue):
        ikm = secret.unwrap().tobytes()
    else:
        raise TypeError(
            "hkdf secret must be SymmetricKeyValue or TensorValue, "
            f"got {type(secret).__name__}"
        )

    # Validate hash algorithm (currently only SHA-256 implemented)
    if hash_algo != "sha256":
        raise NotImplementedError(
            f"HKDF with hash algorithm '{hash_algo}' is not yet implemented. "
            "Currently only 'sha256' is supported. "
            "Planned future support: sha512, sha3256, blake2b"
        )

    # salt=None makes HKDF-Extract use HashLen zero bytes (RFC 5869 §2.2), not
    # a random salt. That relies on the IKM itself being high-entropy. Distinct
    # keys from the same IKM come from distinct `info` values.
    hkdf = HKDF(
        algorithm=hashes.SHA256(),
        length=32,  # Output length in bytes (AES-256 key = 32 bytes)
        salt=None,
        info=info_bytes,  # Context-specific binding for domain separation
    )

    derived_key = hkdf.derive(ikm)

    # Composite suite name records the derivation method and hash function.
    suite = f"hkdf-{hash_algo}"
    return SymmetricKeyValue(suite=suite, key_bytes=derived_key)
|
|
710
|
+
|
|
711
|
+
|
|
712
|
+
@crypto.random_bytes_p.def_impl
def random_bytes_impl(interpreter: Interpreter, op: Operation) -> TensorValue:
    """Generate random bytes using os.urandom.

    Reads the byte count from ``op.attrs["length"]`` and returns that many
    bytes from ``os.urandom`` as a uint8 tensor.

    Args:
        interpreter: Runtime interpreter context (not used by this kernel).
        op: Operation whose ``attrs["length"]`` is the number of bytes to draw.
            A Python ``int`` or a NumPy integer scalar is accepted.

    Returns:
        TensorValue wrapping a 1-D, writable ``np.uint8`` array of ``length``
        elements.

    Raises:
        KeyError: If the ``length`` attribute is missing.
        TypeError: If ``length`` is not an integer. ``bool`` is rejected even
            though it subclasses ``int``.
        ValueError: If ``length`` is negative.
    """
    # Length is passed as attribute
    length = op.attrs["length"]

    # True/False as a byte count is almost certainly a caller bug, so bool is
    # rejected explicitly. NumPy integer scalars are valid lengths.
    if isinstance(length, bool) or not isinstance(length, (int, np.integer)):
        raise TypeError(f"random_bytes length must be int, got {type(length)}")
    length = int(length)
    if length < 0:
        raise ValueError(f"random_bytes length must be non-negative, got {length}")

    b = os.urandom(length)
    # np.frombuffer over immutable bytes is read-only; copy so the tensor owns
    # writable memory.
    arr = np.frombuffer(b, dtype=np.uint8).copy()
    return TensorValue(arr)
|