mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,689 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Crypto dialect for the EDSL.
|
|
16
|
+
|
|
17
|
+
Provides cryptographic primitives including ECC, Hashing, and Symmetric Encryption.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
from typing import Any, ClassVar
|
|
23
|
+
|
|
24
|
+
import mplang.v2.edsl as el
|
|
25
|
+
import mplang.v2.edsl.typing as elt
|
|
26
|
+
from mplang.v2.edsl import serde
|
|
27
|
+
|
|
28
|
+
# ==============================================================================
|
|
29
|
+
# --- Type Definitions
|
|
30
|
+
# ==============================================================================
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@serde.register_class
class PointType(elt.BaseType):
    """Type of an elliptic-curve point, parameterized by the curve name."""

    # Kind tag under which ``serde`` (de)serializes this type.
    _serde_kind: ClassVar[str] = "crypto.PointType"

    def __init__(self, curve: str = "secp256k1"):
        self.curve = curve

    def __str__(self) -> str:
        return "Point[{}]".format(self.curve)

    def __eq__(self, other: object) -> bool:
        # Equal only to another PointType on the same curve.
        return isinstance(other, PointType) and self.curve == other.curve

    def __hash__(self) -> int:
        # Tag the tuple with the class name so types sharing a curve string
        # (e.g. ScalarType) do not collide by construction.
        key = ("PointType", self.curve)
        return hash(key)

    def to_json(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict holding only the curve name."""
        return dict(curve=self.curve)

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> PointType:
        """Rebuild a PointType from the dict produced by :meth:`to_json`."""
        curve = data["curve"]
        return cls(curve=curve)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@serde.register_class
class ScalarType(elt.BaseType):
    """Type of an ECC scalar (an integer modulo the curve order)."""

    # Kind tag under which ``serde`` (de)serializes this type.
    _serde_kind: ClassVar[str] = "crypto.ScalarType"

    def __init__(self, curve: str = "secp256k1"):
        self.curve = curve

    def __str__(self) -> str:
        return "Scalar[{}]".format(self.curve)

    def __eq__(self, other: object) -> bool:
        # Equal only to another ScalarType on the same curve.
        return isinstance(other, ScalarType) and self.curve == other.curve

    def __hash__(self) -> int:
        key = ("ScalarType", self.curve)
        return hash(key)

    def to_json(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict holding only the curve name."""
        return dict(curve=self.curve)

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> ScalarType:
        """Rebuild a ScalarType from the dict produced by :meth:`to_json`."""
        curve = data["curve"]
        return cls(curve=curve)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@serde.register_class
class PrivateKeyType(elt.BaseType):
    """Type of a KEM private key, parameterized by the KEM suite name."""

    # Kind tag under which ``serde`` (de)serializes this type.
    _serde_kind: ClassVar[str] = "crypto.PrivateKeyType"

    def __init__(self, suite: str = "x25519"):
        self.suite = suite

    def __str__(self) -> str:
        return "PrivateKey[{}]".format(self.suite)

    def __eq__(self, other: object) -> bool:
        # Equal only to another PrivateKeyType of the same suite.
        return isinstance(other, PrivateKeyType) and self.suite == other.suite

    def __hash__(self) -> int:
        key = ("PrivateKeyType", self.suite)
        return hash(key)

    def to_json(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict holding only the suite name."""
        return dict(suite=self.suite)

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> PrivateKeyType:
        """Rebuild a PrivateKeyType from the dict produced by :meth:`to_json`."""
        suite = data["suite"]
        return cls(suite=suite)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@serde.register_class
class PublicKeyType(elt.BaseType):
    """Type of a KEM public key, parameterized by the KEM suite name."""

    # Kind tag under which ``serde`` (de)serializes this type.
    _serde_kind: ClassVar[str] = "crypto.PublicKeyType"

    def __init__(self, suite: str = "x25519"):
        self.suite = suite

    def __str__(self) -> str:
        return "PublicKey[{}]".format(self.suite)

    def __eq__(self, other: object) -> bool:
        # Equal only to another PublicKeyType of the same suite.
        return isinstance(other, PublicKeyType) and self.suite == other.suite

    def __hash__(self) -> int:
        key = ("PublicKeyType", self.suite)
        return hash(key)

    def to_json(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict holding only the suite name."""
        return dict(suite=self.suite)

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> PublicKeyType:
        """Rebuild a PublicKeyType from the dict produced by :meth:`to_json`."""
        suite = data["suite"]
        return cls(suite=suite)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
@serde.register_class
class SymmetricKeyType(elt.BaseType):
    """Type of a symmetric encryption key (e.g. the output of KEM derive or HKDF).

    ``suite`` records how the key was obtained, e.g. the KEM suite name or an
    ``"hkdf-<hash>"`` tag.
    """

    # Kind tag under which ``serde`` (de)serializes this type.
    _serde_kind: ClassVar[str] = "crypto.SymmetricKeyType"

    def __init__(self, suite: str = "x25519"):
        self.suite = suite

    def __str__(self) -> str:
        return "SymmetricKey[{}]".format(self.suite)

    def __eq__(self, other: object) -> bool:
        # Equal only to another SymmetricKeyType of the same suite.
        return isinstance(other, SymmetricKeyType) and self.suite == other.suite

    def __hash__(self) -> int:
        key = ("SymmetricKeyType", self.suite)
        return hash(key)

    def to_json(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict holding only the suite name."""
        return dict(suite=self.suite)

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> SymmetricKeyType:
        """Rebuild a SymmetricKeyType from the dict produced by :meth:`to_json`."""
        suite = data["suite"]
        return cls(suite=suite)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
# ==============================================================================
|
|
179
|
+
# --- Primitives
|
|
180
|
+
# ==============================================================================
|
|
181
|
+
|
|
182
|
+
# Primitive handles for the crypto dialect. Each one is constructed exactly
# once at import time under its dialect-qualified name. The declaration order
# is kept stable in case constructing a Primitive registers it globally.

# --- Elliptic-curve arithmetic ---
generator_p = el.Primitive[el.Object]("crypto.ec_generator")
mul_p = el.Primitive[el.Object]("crypto.ec_mul")
add_p = el.Primitive[el.Object]("crypto.ec_add")
sub_p = el.Primitive[el.Object]("crypto.ec_sub")
point_to_bytes_p = el.Primitive[el.Object]("crypto.ec_point_to_bytes")
random_scalar_p = el.Primitive[el.Object]("crypto.ec_random_scalar")
scalar_from_int_p = el.Primitive[el.Object]("crypto.ec_scalar_from_int")

# --- Hashing, symmetric encryption and selection ---
hash_p = el.Primitive[el.Object]("crypto.hash")
hash_batch_p = el.Primitive[el.Object]("crypto.hash_batch")
sym_encrypt_p = el.Primitive[el.Object]("crypto.sym_encrypt")
sym_decrypt_p = el.Primitive[el.Object]("crypto.sym_decrypt")
select_p = el.Primitive[el.Object]("crypto.select")

# --- Key encapsulation: keygen yields a (private, public) pair ---
kem_keygen_p = el.Primitive[tuple[el.Object, el.Object]]("crypto.kem_keygen")
kem_derive_p = el.Primitive[el.Object]("crypto.kem_derive")

# --- HMAC-based key derivation ---
hkdf_p = el.Primitive[el.Object]("crypto.hkdf")

# --- Runtime randomness ---
random_bytes_p = el.Primitive[el.Object]("crypto.random_bytes")
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
# ==============================================================================
|
|
210
|
+
# --- Abstract Evaluation (Type Inference)
|
|
211
|
+
# ==============================================================================
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
@generator_p.def_abstract_eval
def _generator_ae(curve: str = "secp256k1") -> PointType:
    """Infer the type of ``crypto.ec_generator``: a point on ``curve``."""
    generator_type = PointType(curve=curve)
    return generator_type
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
@mul_p.def_abstract_eval
def _mul_ae(point: PointType, scalar: ScalarType) -> PointType:
    """Infer the type of ``crypto.ec_mul``.

    The result lies on the input point's curve. The scalar's type is not
    inspected here.
    """
    result_curve = point.curve
    return PointType(result_curve)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
@add_p.def_abstract_eval
def _add_ae(p1: PointType, p2: PointType) -> PointType:
    """Infer the type of ``crypto.ec_add``: a point on ``p1``'s curve.

    ``p2`` is not checked against ``p1`` at trace time.
    """
    result_curve = p1.curve
    return PointType(result_curve)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
@sub_p.def_abstract_eval
def _sub_ae(p1: PointType, p2: PointType) -> PointType:
    """Infer the type of ``crypto.ec_sub``: a point on ``p1``'s curve.

    ``p2`` is not checked against ``p1`` at trace time.
    """
    result_curve = p1.curve
    return PointType(result_curve)
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
@point_to_bytes_p.def_abstract_eval
def _pt_to_bytes_ae(point: elt.BaseType) -> elt.TensorType:
    """Infer the type of ``crypto.ec_point_to_bytes``.

    A single point becomes ``Tensor[u8, (65,)]``. A tensor of points with
    shape ``S`` becomes ``Tensor[u8, S + (65,)]``.
    """
    # 65 bytes per point is hard-coded. It matches an uncompressed SEC1
    # encoding for 256-bit curves, which is presumably what the backend
    # emits (TODO confirm).
    encoded_width = 65
    if isinstance(point, elt.TensorType):
        batch_dims = tuple(point.shape)
    else:
        batch_dims = ()
    return elt.TensorType(elt.u8, (*batch_dims, encoded_width))
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
@random_scalar_p.def_abstract_eval
def _random_scalar_ae(curve: str = "secp256k1") -> ScalarType:
    """Infer the type of ``crypto.ec_random_scalar``: a scalar for ``curve``."""
    scalar_type = ScalarType(curve=curve)
    return scalar_type
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
@scalar_from_int_p.def_abstract_eval
def _scalar_from_int_ae(
    val: elt.TensorType | elt.IntegerType, curve: str = "secp256k1"
) -> ScalarType:
    """Infer the type of ``crypto.ec_scalar_from_int``.

    The result is a scalar for ``curve``. The input's type (``val``) is not
    inspected here.
    """
    scalar_type = ScalarType(curve=curve)
    return scalar_type
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
@hash_p.def_abstract_eval
def _hash_ae(data: elt.BaseType) -> elt.TensorType:
    """Infer the type of ``crypto.hash``.

    The whole input is treated as one blob and always produces a single
    32-byte digest, regardless of the input type.
    """
    digest_len = 32
    return elt.TensorType(elt.u8, (digest_len,))
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
@hash_batch_p.def_abstract_eval
def _hash_batch_ae(data: elt.BaseType) -> elt.TensorType:
    """Infer the type of ``crypto.hash_batch``.

    The last dimension holds the bytes to hash, and every leading index is
    hashed independently: ``(..., D) -> (..., 32)``.

    Raises:
        TypeError: If ``data`` is not a TensorType.
    """
    if not isinstance(data, elt.TensorType):
        raise TypeError(f"hash_batch requires TensorType, got {data}")

    dims = data.shape  # tuple[int | None, ...]
    # Inputs of rank < 2 (e.g. a plain (D,) blob) are treated as one row and
    # yield a single (32,) digest rather than an error.
    leading = dims[:-1] if len(dims) >= 2 else ()
    return elt.TensorType(elt.u8, (*leading, 32))
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
@sym_encrypt_p.def_abstract_eval
def _sym_encrypt_ae(
    key: elt.BaseType, plaintext: elt.BaseType, *, algo: str = "aes-gcm"
) -> elt.TensorType:
    """Infer the type of ``crypto.sym_encrypt``.

    Args:
        key: Type of the symmetric key.
        plaintext: Type of the data being encrypted.
        algo: Cipher name (keyword-only). It is not checked here; the backend
            implementation validates it.

    Returns:
        ``Tensor[u8, (-1,)]``. The ciphertext length is only known at run
        time, so the single dimension is dynamic (-1).
    """
    dynamic_len = -1
    return elt.TensorType(elt.u8, (dynamic_len,))
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
@sym_decrypt_p.def_abstract_eval
def _sym_decrypt_ae(
    key: elt.BaseType,
    ciphertext: elt.BaseType,
    *,
    target_type: elt.BaseType,
    algo: str = "aes-gcm",
) -> elt.BaseType:
    """Infer the type of ``crypto.sym_decrypt``.

    The ciphertext carries no type information, so the caller states the
    plaintext type explicitly via ``target_type``. That type is returned
    unchanged.

    Args:
        key: Type of the symmetric key.
        ciphertext: Type of the encrypted data.
        target_type: Declared type of the decrypted plaintext (keyword-only).
        algo: Cipher name (keyword-only). It is not checked here; the backend
            implementation validates it.
    """
    declared = target_type
    return declared
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
@select_p.def_abstract_eval
def _select_ae(
    cond: elt.BaseType, true_val: elt.BaseType, false_val: elt.BaseType
) -> elt.BaseType:
    """Infer the type of ``crypto.select``.

    The result takes the type of the ``true_val`` branch. ``false_val`` is
    not compared against it at trace time.
    """
    result_type = true_val
    return result_type
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
@kem_keygen_p.def_abstract_eval
def _kem_keygen_ae(suite: str = "x25519") -> tuple[PrivateKeyType, PublicKeyType]:
    """Infer the types of ``crypto.kem_keygen``: a (private, public) pair for ``suite``."""
    private_type = PrivateKeyType(suite)
    public_type = PublicKeyType(suite)
    return private_type, public_type
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
@kem_derive_p.def_abstract_eval
def _kem_derive_ae(
    private_key: PrivateKeyType, public_key: PublicKeyType
) -> SymmetricKeyType:
    """Infer the type of ``crypto.kem_derive``.

    The derived key inherits the private key's suite. If the private key's
    type has no ``suite`` attribute, the suite falls back to ``"x25519"``.
    The public key's suite is not compared.
    """
    try:
        suite = private_key.suite
    except AttributeError:
        suite = "x25519"
    return SymmetricKeyType(suite)
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
@hkdf_p.def_abstract_eval
def _hkdf_ae(
    secret: elt.BaseType, *, info: str, hash_algo: str = "sha256"
) -> SymmetricKeyType:
    """Abstract evaluation for HKDF key derivation.

    Args:
        secret: Input key material (SymmetricKeyType from kem_derive or TensorType[u8]).
            Its type is not inspected here.
        info: Context string for domain separation (required, non-empty, keyword-only).
        hash_algo: Hash algorithm name (keyword-only). For the resulting type's
            suite it is normalized: lowercased, with "-" and "_" removed.

    Returns:
        SymmetricKeyType with suite="hkdf-{normalized hash_algo}", e.g. "hkdf-sha256".

    Raises:
        TypeError: If ``info`` is not a string, or ``hash_algo`` is not a
            non-empty string.
        ValueError: If ``info`` is empty (required for domain separation per NIST).
    """
    # Validate info and hash_algo at trace time. A wrong type is a TypeError;
    # an empty info string is a ValueError, as documented above.
    if not isinstance(info, str):
        raise TypeError(
            f"HKDF 'info' must be a str, got {type(info).__name__}"
        )
    if not info:
        raise ValueError(
            "HKDF requires non-empty 'info' parameter for domain separation. "
            "The info string binds the derived key to a specific protocol/context. "
            "Recommended format: 'namespace/component/purpose/version'"
        )
    if not isinstance(hash_algo, str) or not hash_algo:
        raise TypeError("hash_algo must be a non-empty string")

    # Normalize: lowercase, drop hyphens and underscores ("SHA-256" -> "sha256").
    hash_algo_normalized = hash_algo.lower().replace("-", "").replace("_", "")

    # Return SymmetricKeyType with composite suite indicating derivation method
    return SymmetricKeyType(suite=f"hkdf-{hash_algo_normalized}")
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
@random_bytes_p.def_abstract_eval
def _random_bytes_ae(length: int) -> elt.TensorType:
    """Infer the type of ``crypto.random_bytes``: a 1-D uint8 tensor of ``length`` bytes."""
    out_shape = (length,)
    return elt.TensorType(elt.u8, out_shape)
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
# ==============================================================================
|
|
383
|
+
# --- Helper Functions (Ops)
|
|
384
|
+
# ==============================================================================
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
def ec_generator(curve: str = "secp256k1") -> el.Object:
    """Return the generator point G of ``curve`` (typed as a PointType)."""
    generator = generator_p.bind(curve=curve)
    return generator
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def ec_mul(point: el.Object, scalar: el.Object) -> el.Object:
    """Multiply ``point`` by ``scalar``; the result is a point on the same curve."""
    product = mul_p.bind(point, scalar)
    return product
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def ec_add(p1: el.Object, p2: el.Object) -> el.Object:
    """Add two curve points, returning ``p1 + p2``."""
    total = add_p.bind(p1, p2)
    return total
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
def ec_sub(p1: el.Object, p2: el.Object) -> el.Object:
    """Subtract curve points, returning ``p1 - p2``."""
    difference = sub_p.bind(p1, p2)
    return difference
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
def ec_point_to_bytes(point: el.Object) -> el.Object:
    """Serialize a point (or tensor of points) into a uint8 tensor.

    Each point is typed as 65 bytes.
    """
    encoded = point_to_bytes_p.bind(point)
    return encoded
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def ec_random_scalar(curve: str = "secp256k1") -> el.Object:
    """Draw a random scalar for ``curve`` at run time."""
    scalar = random_scalar_p.bind(curve=curve)
    return scalar
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
def ec_scalar_from_int(val: el.Object, curve: str = "secp256k1") -> el.Object:
    """Convert an integer tensor ``val`` into a scalar for ``curve``."""
    scalar = scalar_from_int_p.bind(val, curve=curve)
    return scalar
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
def hash_bytes(data: el.Object) -> el.Object:
    """Hash ``data`` as one blob (SHA256) into a 32-byte uint8 tensor."""
    digest = hash_p.bind(data)
    return digest
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def hash_batch(data: el.Object) -> el.Object:
    """Hash every row of a tensor independently.

    The last dimension holds the bytes of each message; all leading
    dimensions are preserved and the last becomes a 32-byte digest.

    Examples of shapes:
        (N, D)    -> (N, 32)
        (B, N, D) -> (B, N, 32)
    """
    digests = hash_batch_p.bind(data)
    return digests
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
def sym_encrypt(
    key: el.Object, plaintext: el.Object, *, algo: str = "aes-gcm"
) -> el.Object:
    """Encrypt ``plaintext`` under a symmetric ``key``.

    Args:
        key: Symmetric key (SymmetricKeyType or bytes).
        plaintext: Data to encrypt (any serializable object).
        algo: Cipher name; only "aes-gcm" is currently supported. It is
            checked by the backend at execution time, not during tracing.

    Returns:
        The ciphertext as a dynamically sized ``Tensor[u8, (-1,)]``.

    Raises:
        ValueError: At run time, if the backend does not support ``algo``.
    """
    ciphertext = sym_encrypt_p.bind(key, plaintext, algo=algo)
    return ciphertext
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
def sym_decrypt(
    key: el.Object,
    ciphertext: el.Object,
    target_type: elt.BaseType,
    *,
    algo: str = "aes-gcm",
) -> el.Object:
    """Decrypt ``ciphertext`` under a symmetric ``key``.

    Args:
        key: Symmetric key (SymmetricKeyType or bytes).
        ciphertext: The encrypted data.
        target_type: Type the plaintext is declared to have. Tracing uses it
            as the result type, because the ciphertext itself is untyped.
        algo: Cipher name; must match the one used to encrypt. Only
            "aes-gcm" is currently supported. It is checked by the backend at
            execution time.

    Returns:
        The decrypted plaintext, typed as ``target_type``.

    Raises:
        ValueError: At run time, if the backend does not support ``algo``.
    """
    plaintext = sym_decrypt_p.bind(
        key, ciphertext, target_type=target_type, algo=algo
    )
    return plaintext
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def select(cond: el.Object, true_val: el.Object, false_val: el.Object) -> el.Object:
    """Return ``true_val`` when ``cond`` holds, else ``false_val``.

    The result is typed like ``true_val``.
    """
    chosen = select_p.bind(cond, true_val, false_val)
    return chosen
|
|
486
|
+
|
|
487
|
+
|
|
488
|
+
def kem_keygen(suite: str = "x25519") -> tuple[el.Object, el.Object]:
    """Generate a fresh KEM key pair.

    Args:
        suite: KEM suite name (e.g. "x25519", "kyber768").

    Returns:
        ``(private_key, public_key)``, typed as PrivateKeyType and
        PublicKeyType for ``suite``.
    """
    keypair = kem_keygen_p.bind(suite=suite)
    return keypair
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
def kem_derive(private_key: el.Object, public_key: el.Object) -> el.Object:
    """Compute a shared symmetric key from our private key and a peer's public key (ECDH).

    Args:
        private_key: This party's private key.
        public_key: The remote party's public key.

    Returns:
        A SymmetricKeyType value usable with sym_encrypt/sym_decrypt (or as
        input to hkdf).
    """
    shared_key = kem_derive_p.bind(private_key, public_key)
    return shared_key
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
def hkdf(secret: el.Object, info: str, *, hash_algo: str = "sha256") -> el.Object:
    """Derive a cryptographic key from input key material using HKDF.

    HKDF (HMAC-based Key Derivation Function) is specified in RFC 5869 and
    required by NIST SP 800-56C Rev.2 for deriving symmetric keys from
    key agreement schemes like ECDH. Per NIST: "The shared secret output
    from a key-agreement scheme SHALL NOT be used directly as a cryptographic
    key. A key-derivation function (KDF) SHALL be used."

    Args:
        secret: Input key material (IKM). Accepts:
            - SymmetricKeyValue: Typically from crypto.kem_derive (ECDH output)
            - TensorType[u8, (N,)]: Raw bytes (N-byte secret)
        info: Application-specific context string for domain separation.
            REQUIRED and must be non-empty. Different info values produce
            cryptographically independent keys even from the same secret.
            Recommended format: "namespace/component/purpose/version"
            Example: "mplang/device/tee/v2"
        hash_algo: Hash function to use. The result's type-level suite name
            is normalized (lowercased, "-" and "_" removed), so "SHA-256"
            and "sha256" both give suite "hkdf-sha256". The raw value is
            forwarded to the backend as given, so prefer the canonical
            lowercase form.
            Currently supported: "sha256" (default)
            Future support planned: "sha512", "sha3256", "blake2b"
            Default "sha256" provides 128-bit security level.

    Returns:
        SymmetricKeyValue with:
            - suite: "hkdf-{hash_algo}" (e.g., "hkdf-sha256")
            - key_bytes: 32-byte derived key suitable for AES-256-GCM

    Security considerations:
        - Output length: Fixed at 32 bytes (256 bits) for AES-256 keys
        - Salt: Uses salt=None (acceptable for ECDH output per NIST guidance)
        - Info: Provides protocol/context binding (domain separation)
        - Deterministic: Same (secret, info, hash_algo) always produces same key

    Raises:
        ValueError: At trace (abstract evaluation) time if ``info`` is empty.
            A non-string ``info`` is also rejected at trace time.
        TypeError: At trace time if ``hash_algo`` is not a non-empty string.
            Whether the algorithm is supported is NOT checked at trace time.
        NotImplementedError: At execution time, from the backend, if
            ``hash_algo`` is not supported (currently anything but "sha256").

    Examples:
        >>> # Standard TEE session establishment
        >>> sk_local, pk_local = crypto.kem_keygen("x25519")
        >>> sk_remote, pk_remote = crypto.kem_keygen("x25519")
        >>> # ECDH on both sides
        >>> shared_local = crypto.kem_derive(sk_local, pk_remote)
        >>> shared_remote = crypto.kem_derive(sk_remote, pk_local)
        >>> # HKDF for domain separation
        >>> sess_local = crypto.hkdf(shared_local, "mplang/device/tee/v2")
        >>> sess_remote = crypto.hkdf(shared_remote, "mplang/device/tee/v2")
        >>> # sess_local and sess_remote have identical key_bytes
        >>> # but suite="hkdf-sha256" (not "x25519")
        >>>
        >>> # Derive multiple independent keys from one master secret
        >>> master_secret = crypto.kem_derive(sk, pk)
        >>> encryption_key = crypto.hkdf(master_secret, "app/encryption/v1")
        >>> mac_key = crypto.hkdf(master_secret, "app/mac/v1")
        >>> # encryption_key differs from mac_key due to different info strings
    """
    return hkdf_p.bind(secret, info=info, hash_algo=hash_algo)
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
def random_bytes(length: int) -> el.Object:
    """Produce ``length`` cryptographically secure random bytes at runtime.

    This function only binds the ``random_bytes_p`` primitive, passing
    ``length`` as a keyword parameter. The bytes themselves are produced when
    the program executes, not when this function is called.

    Args:
        length: How many random bytes to produce.

    Returns:
        (length,) uint8 Tensor.
    """
    result = random_bytes_p.bind(length=length)
    return result
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
def random_tensor(shape: tuple[int, ...], dtype: elt.ScalarType) -> el.Object:
    """Generate a tensor whose raw bits come from a CSPRNG at runtime.

    This helper composes `random_bytes` with `tensor.run_jax`. It draws
    exactly ``prod(shape) * itemsize(dtype)`` random bytes, reinterprets them
    (``.view``, i.e. a bit-cast, not a value conversion) as ``dtype``, and
    reshapes the result to ``shape``.

    Args:
        shape: Output tensor shape (e.g., (100,) or (10, 16)). Every dimension
            must be non-negative; ``()`` produces a single-element scalar.
        dtype: Element type (e.g., elt.u64, elt.i32, elt.f32).

    Returns:
        Tensor[dtype, shape] filled with CSPRNG bits.

    Raises:
        ValueError: If any dimension in ``shape`` is negative.

    Note:
        Because the bytes are bit-cast rather than converted:
        - For integer dtypes, every bit pattern is equally likely, so values
          are uniform over the full range of the type.
        - For floating-point dtypes, the result is NOT uniform on any interval
          and may contain NaN, +/-inf, or subnormal values. If you need uniform
          floats, draw integers and scale them yourself.

    Example:
        >>> # Generate 100 random uint64 values
        >>> x = crypto.random_tensor((100,), elt.u64)
        >>> # Generate 10x16 random int32 matrix
        >>> y = crypto.random_tensor((10, 16), elt.i32)
    """
    import math
    from typing import cast

    from mplang.v2.dialects import dtypes, tensor

    # Reject negative dimensions up front. Otherwise math.prod can produce a
    # negative byte count, or a misleadingly positive one (e.g. (-2, -3) -> 6),
    # and the error would only surface later inside the JAX reshape.
    if any(dim < 0 for dim in shape):
        raise ValueError(
            f"random_tensor: shape dimensions must be non-negative, got {shape}"
        )

    # Byte size per element comes from the numpy equivalent of `dtype`.
    np_dtype = dtypes.to_numpy(dtype)
    total_bytes = math.prod(shape) * np_dtype.itemsize

    raw = random_bytes(total_bytes)

    jax_dtype = dtypes.to_jax(dtype)

    def _view_reshape(b: Any) -> Any:
        # Bit-cast the flat byte buffer to the target dtype, then apply the
        # requested shape; the element count matches by construction.
        return b.view(jax_dtype).reshape(shape)

    return cast(el.Object, tensor.run_jax(_view_reshape, raw))
|
|
626
|
+
|
|
627
|
+
|
|
628
|
+
def random_bits(n: int) -> el.Object:
    """Generate n cryptographically secure random bits at runtime.

    Each bit is stored as a uint8 with value 0 or 1 (unpacked representation).
    The function draws ``ceil(n / 8)`` random bytes, unpacks them
    least-significant-bit first, and truncates the result to ``n`` entries.

    Args:
        n: Number of random bits to generate. Must be non-negative.

    Returns:
        (n,) uint8 Tensor with values 0 or 1.

    Raises:
        ValueError: If ``n`` is negative.

    Example:
        >>> # Generate 1024 random bits for OT selection
        >>> choice_bits = crypto.random_bits(1024)
    """
    from typing import cast

    import jax.numpy as jnp

    from mplang.v2.dialects import tensor

    # Without this check a negative n is silently accepted: e.g. n=-3 requests
    # 0 bytes and bits[:-3] then yields an empty tensor instead of an error.
    if n < 0:
        raise ValueError(f"random_bits: n must be non-negative, got {n}")

    # ceil(n / 8) bytes are enough to cover n bits.
    num_bytes = (n + 7) // 8
    raw = random_bytes(num_bytes)

    def _unpack_and_slice(b: Any, n: int = n) -> Any:
        # Little bit order: element 0 is the least significant bit of byte 0.
        bits = jnp.unpackbits(b, bitorder="little")
        # Drop the (at most 7) padding bits from the last byte.
        return bits[:n]

    return cast(el.Object, tensor.run_jax(_unpack_and_slice, raw))
|
|
658
|
+
|
|
659
|
+
|
|
660
|
+
# --- Bytes <-> Point Conversions ---
|
|
661
|
+
|
|
662
|
+
# Primitive behind `ec_bytes_to_point`: turns a SEC1-encoded byte tensor into
# an EC point. Only its abstract-eval rule is visible here; the runtime
# implementation is presumably registered elsewhere (not shown in this chunk).
bytes_to_point_p = el.Primitive[el.Object]("crypto.ec_bytes_to_point")
|
|
663
|
+
|
|
664
|
+
|
|
665
|
+
@bytes_to_point_p.def_abstract_eval
def _bytes_to_point_ae(b: elt.TensorType) -> PointType:
    """Abstract evaluation rule for ``crypto.ec_bytes_to_point``.

    Args:
        b: Type of the input byte tensor. It is not inspected, so the
            expected (65,) uint8 shape/dtype is not checked at trace time.

    Returns:
        A point type on the fixed curve ``"secp256k1"``. The curve is never
        derived from the input.
    """
    curve_name = "secp256k1"
    return PointType(curve_name)
|
|
668
|
+
|
|
669
|
+
|
|
670
|
+
def ec_bytes_to_point(b: el.Object) -> el.Object:
    """
    Deserialize bytes to an ECC point on secp256k1.

    Args:
        b: A (65,) uint8 Tensor representing an uncompressed point in SEC1 format.
            The first byte must be 0x04, followed by 32 bytes for X and 32 bytes
            for Y (SEC1 encodes each coordinate as a big-endian octet string).

    Returns:
        An ECC point object corresponding to the input bytes. Its abstract
        type is always ``PointType("secp256k1")``; the curve cannot be chosen
        by the caller.

    Raises:
        ValueError: If the input is not a valid 65-byte uncompressed point
            representation. NOTE(review): the abstract evaluation rule performs
            no shape/dtype/on-curve checks, so this is presumably raised by the
            runtime kernel at execution time rather than while tracing; confirm.

    Example:
        >>> # point_bytes: a (65,) uint8 tensor Object holding the SEC1
        >>> # encoding of a point that actually lies on secp256k1, e.g. the
        >>> # generator G = 04 || 79BE667E...16F81798 || 483ADA77...FB10D4B8.
        >>> # Arbitrary bytes such as 0x04 || 0x01*32 || 0x02*32 are generally
        >>> # not a valid curve point.
        >>> point = crypto.ec_bytes_to_point(point_bytes)
    """
    return bytes_to_point_p.bind(b)
|