fips-collection 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fips/FIPS204/__init__.py +10 -0
- fips/FIPS204/auxilary.py +369 -0
- fips/FIPS204/encode.py +300 -0
- fips/FIPS204/hash.py +39 -0
- fips/FIPS204/hint.py +200 -0
- fips/FIPS204/main.py +439 -0
- fips/FIPS204/ntt.py +278 -0
- fips/FIPS204/pack.py +276 -0
- fips/FIPS204/parameter.py +50 -0
- fips/FIPS204/sample.py +291 -0
- fips/__init__.py +9 -0
- fips_collection-1.0.0.dist-info/METADATA +114 -0
- fips_collection-1.0.0.dist-info/RECORD +16 -0
- fips_collection-1.0.0.dist-info/WHEEL +5 -0
- fips_collection-1.0.0.dist-info/licenses/LICENSE +21 -0
- fips_collection-1.0.0.dist-info/top_level.txt +1 -0
fips/FIPS204/__init__.py
ADDED
fips/FIPS204/auxilary.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
1
|
+
class AUXILARY:
|
|
2
|
+
"""
|
|
3
|
+
This class provides subroutines utilized by MLDSA, including function for data-type converstions and arithmetic.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
def __init__(self, parameter: dict[str, int]):
|
|
7
|
+
self.q = parameter["q"]
|
|
8
|
+
self.N = parameter["N"]
|
|
9
|
+
self.eta = parameter["eta"]
|
|
10
|
+
|
|
11
|
+
def IntegerToBits(self, x: int, alpha: int):
|
|
12
|
+
"""
|
|
13
|
+
Algorithm 9
|
|
14
|
+
|
|
15
|
+
Computes the base-2 representation of x mod 2^alpha in ``little-endian`` order.
|
|
16
|
+
|
|
17
|
+
Args:
|
|
18
|
+
x (``int``): A ``non-negative`` integer.
|
|
19
|
+
alpha (``int``): Number of ``bits`` to represent.
|
|
20
|
+
|
|
21
|
+
Returns:
|
|
22
|
+
str (``bits``): ``Bitstring`` of length ``alpha`` in ``little-endian`` order.
|
|
23
|
+
|
|
24
|
+
Raises:
|
|
25
|
+
ValueError: If x is ``negative`` or alpha is not a positive integer.
|
|
26
|
+
ValueError: If x is ``too big`` to be represented in alpha bits.
|
|
27
|
+
TypeError: If x or alpha is ``not an integer``.
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
if x < 0:
|
|
31
|
+
raise ValueError("x must be non-negative.")
|
|
32
|
+
if alpha <= 0:
|
|
33
|
+
raise ValueError("alpha must be a positive integer.")
|
|
34
|
+
if x >= 2 ** alpha:
|
|
35
|
+
raise ValueError(f"x = {x} cannot be represented in {alpha} bits.")
|
|
36
|
+
|
|
37
|
+
x_mod = x
|
|
38
|
+
bits:list[str] = []
|
|
39
|
+
for _ in range(alpha):
|
|
40
|
+
bits.append(str(x_mod % 2))
|
|
41
|
+
x_mod //= 2
|
|
42
|
+
|
|
43
|
+
return ''.join(bits)
|
|
44
|
+
|
|
45
|
+
def BitsToInteger(self, y: str, alpha: int) -> int:
|
|
46
|
+
"""
|
|
47
|
+
Algorithm 10
|
|
48
|
+
|
|
49
|
+
Computes the ``integer`` value expressed by a bit string using ``little-endian`` order.
|
|
50
|
+
|
|
51
|
+
Args:
|
|
52
|
+
y (``bitstring``): ``Bitstring`` to convert to an integer.
|
|
53
|
+
alpha (``int``): Number of ``bits`` to consider.
|
|
54
|
+
|
|
55
|
+
Returns:
|
|
56
|
+
integer (``int``): The ``integer`` value represented by the ``bitstring``.
|
|
57
|
+
|
|
58
|
+
Raises:
|
|
59
|
+
ValueError: If the length of ``y`` does not match alpha or if ``y`` contains invalid characters.
|
|
60
|
+
TypeError: If ``y`` is not a ``string`` or alpha is not an ``integer``.
|
|
61
|
+
"""
|
|
62
|
+
if alpha <= 0:
|
|
63
|
+
raise ValueError("alpha must be a positive integer.")
|
|
64
|
+
if len(y) != alpha:
|
|
65
|
+
raise ValueError(f"Bit string y must have exactly {alpha} bits.")
|
|
66
|
+
if any(bit not in "01" for bit in y):
|
|
67
|
+
raise ValueError("Bit string y must contain only '0' and '1' characters.")
|
|
68
|
+
|
|
69
|
+
x = 0
|
|
70
|
+
for i in range(1, alpha + 1):
|
|
71
|
+
bit = int(y[alpha - i])
|
|
72
|
+
x = 2 * x + bit
|
|
73
|
+
return x
|
|
74
|
+
|
|
75
|
+
def IntegerToBytes(self, x: int, alpha: int) -> bytes:
|
|
76
|
+
"""
|
|
77
|
+
Algorithm 11
|
|
78
|
+
|
|
79
|
+
Computes a base-256 representation of x mod 256^alpha in ``little-endian`` order.
|
|
80
|
+
|
|
81
|
+
Args:
|
|
82
|
+
x (``int``): A ``non-negative`` integer.
|
|
83
|
+
alpha (``int``): Number of ``bytes`` in the output.
|
|
84
|
+
|
|
85
|
+
Returns:
|
|
86
|
+
bytes (``bytes``): ``Bytestring`` of length alpha in ``little-endian`` order.
|
|
87
|
+
|
|
88
|
+
Raises:
|
|
89
|
+
ValueError: x = _ cannot be represented in _ bytes.
|
|
90
|
+
"""
|
|
91
|
+
try:
|
|
92
|
+
return x.to_bytes(alpha, byteorder = "little")
|
|
93
|
+
except OverflowError:
|
|
94
|
+
raise ValueError(f"x = {x} cannot be represented in {alpha} bytes.")
|
|
95
|
+
|
|
96
|
+
def BitsToBytes(self, y: str) -> bytes:
|
|
97
|
+
"""
|
|
98
|
+
Algorithm 12
|
|
99
|
+
|
|
100
|
+
Converts a ``bitstring`` y into a ``bytestring`` using ``little-endian`` order.
|
|
101
|
+
|
|
102
|
+
Args:
|
|
103
|
+
y (``str``): ``Bitstring`` consisting of ``0`` and ``1``.
|
|
104
|
+
|
|
105
|
+
Returns:
|
|
106
|
+
bytes (``bytes``): ``Bytestring`` of length ceil ``(len(y) / 8)``.
|
|
107
|
+
|
|
108
|
+
Raises:
|
|
109
|
+
TypeError: If y is ``not a string``.
|
|
110
|
+
ValueError: If y contains characters other than ``0`` or ``1``.
|
|
111
|
+
"""
|
|
112
|
+
if any(bit not in "01" for bit in y):
|
|
113
|
+
raise ValueError("Bit string y must contain only '0' and '1'.")
|
|
114
|
+
|
|
115
|
+
alpha = len(y)
|
|
116
|
+
byte_len = (alpha + 7) // 8 # Equivalent to ceil(alpha / 8)
|
|
117
|
+
z = [0] * byte_len
|
|
118
|
+
|
|
119
|
+
for i in range(alpha):
|
|
120
|
+
byte_index = i // 8
|
|
121
|
+
bit_index = i % 8
|
|
122
|
+
z[byte_index] |= int(y[i]) << bit_index
|
|
123
|
+
|
|
124
|
+
return bytes(z)
|
|
125
|
+
|
|
126
|
+
def BytesToBits(self, z: bytes) -> str:
|
|
127
|
+
"""
|
|
128
|
+
Algorithm 13
|
|
129
|
+
|
|
130
|
+
Converts a ``bytestring`` ``z`` into a ``bitstring`` in ``little-endian`` order.
|
|
131
|
+
|
|
132
|
+
Args:
|
|
133
|
+
z (``bytes``): A ``bytestring``.
|
|
134
|
+
|
|
135
|
+
Returns:
|
|
136
|
+
str (bit string): A ``bitstring`` of length ``8 * len(z)``, in ``little-endian`` order.
|
|
137
|
+
|
|
138
|
+
Raises:
|
|
139
|
+
TypeError: If ``z`` is not a ``bytes`` object.
|
|
140
|
+
"""
|
|
141
|
+
if not isinstance(z, (bytes, bytearray)):
|
|
142
|
+
raise TypeError("Input z must be a bytes or bytearray object.")
|
|
143
|
+
|
|
144
|
+
bits: list[str] = []
|
|
145
|
+
for byte in z:
|
|
146
|
+
for i in range(8):
|
|
147
|
+
bits.append(str((byte >> i) & 1)) # Little-endian bit order
|
|
148
|
+
|
|
149
|
+
return ''.join(bits)
|
|
150
|
+
|
|
151
|
+
def CoeffFromThreeBytes(self, b0: int, b1: int, b2:int) -> int | None:
|
|
152
|
+
"""
|
|
153
|
+
Algorithm 14
|
|
154
|
+
|
|
155
|
+
Generates an element of {``0``, ``1``, ``2``, ... , ``q - 1``} U { ``None`` }
|
|
156
|
+
|
|
157
|
+
Args:
|
|
158
|
+
b0 (``int``): first byte
|
|
159
|
+
b1 (``int``): second byte
|
|
160
|
+
b2 (``int``): third byte
|
|
161
|
+
|
|
162
|
+
Returns:
|
|
163
|
+
z (``int``): sampled coefficient or ``None`` if rejected.
|
|
164
|
+
|
|
165
|
+
Raises:
|
|
166
|
+
TypeError: if any of ``b0``, ``b1``, ``b2`` is ``not an integer``.
|
|
167
|
+
"""
|
|
168
|
+
# checks for validity of inputs.
|
|
169
|
+
for i, b in enumerate((b0, b1, b2), start = 0):
|
|
170
|
+
if not (0 <= b <= 255):
|
|
171
|
+
raise ValueError (f"b{i} must be in the range 0 - 255.")
|
|
172
|
+
|
|
173
|
+
# line 1: make a copy of b2.
|
|
174
|
+
b2_prime = b2
|
|
175
|
+
|
|
176
|
+
# line 2 to 4: making sure b2_prime is 7 bits, not 8.
|
|
177
|
+
if b2_prime > 127:
|
|
178
|
+
b2_prime = b2_prime - 128
|
|
179
|
+
|
|
180
|
+
# line 5: evaluate z for sampling.
|
|
181
|
+
z = (b2_prime << 16) + (b1 << 8) + b0
|
|
182
|
+
|
|
183
|
+
# line 6 to 8: reject the sample z if it's greater than q.
|
|
184
|
+
if z < self.q:
|
|
185
|
+
return z # accept sample
|
|
186
|
+
else:
|
|
187
|
+
return None # reject sample
|
|
188
|
+
|
|
189
|
+
def CoeffFromHalfByte(self, b: int) -> int | None:
|
|
190
|
+
"""
|
|
191
|
+
Algorithm 15
|
|
192
|
+
|
|
193
|
+
Let ``eta`` ∈ {2, 4}.
|
|
194
|
+
|
|
195
|
+
Generates an element of {``-eta``, ``-eta + 1``, ... , ``eta``} U { ``None`` }
|
|
196
|
+
|
|
197
|
+
Args:
|
|
198
|
+
b (``int``): an integer in the range ``0 - 15``.
|
|
199
|
+
|
|
200
|
+
Returns:
|
|
201
|
+
z (``int``): sampled coefficient or ``None`` if rejected.
|
|
202
|
+
|
|
203
|
+
Raises:
|
|
204
|
+
TypeError: if b is ``not an integer``.
|
|
205
|
+
ValueError: if b is ``not`` in the range ``0 - 15``.
|
|
206
|
+
"""
|
|
207
|
+
if not (0 <= b <= 15):
|
|
208
|
+
raise ValueError (f"{b} must be in the range 0 - 15.")
|
|
209
|
+
|
|
210
|
+
# line 1 and 2: rejection sampline from {-2, ... , 2 }
|
|
211
|
+
if self.eta == 2 and b < 15:
|
|
212
|
+
return 2 - (b % 5)
|
|
213
|
+
|
|
214
|
+
# line 3: rejection sampline from {-4, ... , 4 }
|
|
215
|
+
elif self.eta == 4 and b < 9:
|
|
216
|
+
return 4 - b
|
|
217
|
+
|
|
218
|
+
# line 4: sample is just rejected.
|
|
219
|
+
else:
|
|
220
|
+
return None
|
|
221
|
+
|
|
222
|
+
def CenteredModulus(self, z: int) -> int:
|
|
223
|
+
"""
|
|
224
|
+
Additional Helper Function 1
|
|
225
|
+
|
|
226
|
+
Computes the centered modulus ``z mod± q``.
|
|
227
|
+
|
|
228
|
+
Maps each integer ``x`` to the unique ``r`` in ::
|
|
229
|
+
|
|
230
|
+
[-(q-1)/2, (q-1)/2]
|
|
231
|
+
|
|
232
|
+
such that ::
|
|
233
|
+
|
|
234
|
+
x ≡ r (mod q).
|
|
235
|
+
|
|
236
|
+
Args:
|
|
237
|
+
z(``int``): An integer
|
|
238
|
+
|
|
239
|
+
Returns:
|
|
240
|
+
CenteredModulus(``int``)
|
|
241
|
+
|
|
242
|
+
Raises:
|
|
243
|
+
TypeError: If ``z`` is ``not an integer``.
|
|
244
|
+
"""
|
|
245
|
+
half_q = (self.q - 1) // 2
|
|
246
|
+
|
|
247
|
+
return (z + half_q) % self.q - half_q
|
|
248
|
+
|
|
249
|
+
def CenteredModulusList(self, z: list[int]) -> list[int]:
|
|
250
|
+
"""
|
|
251
|
+
Additional Helper Function 2
|
|
252
|
+
|
|
253
|
+
Computes the centered modulus ``z mod± q`` for a ``list`` of ``Integers``.
|
|
254
|
+
|
|
255
|
+
Maps each integer ``x`` to the unique ``r`` in ::
|
|
256
|
+
|
|
257
|
+
[-(q-1)/2, (q-1)/2]
|
|
258
|
+
|
|
259
|
+
such that ::
|
|
260
|
+
|
|
261
|
+
x ≡ r (mod q).
|
|
262
|
+
|
|
263
|
+
Args:
|
|
264
|
+
z(``list[int]``): A list of ``integers``.
|
|
265
|
+
|
|
266
|
+
Returns:
|
|
267
|
+
CenteredModulus(``list[int]``)
|
|
268
|
+
|
|
269
|
+
Raises:
|
|
270
|
+
TypeError: If ``z`` is not a ``list[int]``.
|
|
271
|
+
"""
|
|
272
|
+
half_q = (self.q - 1) // 2
|
|
273
|
+
|
|
274
|
+
return [(x + half_q) % self.q - half_q for x in z]
|
|
275
|
+
|
|
276
|
+
def CenteredModulusMatrix(self, z: list[list[int]]) -> list[list[int]]:
|
|
277
|
+
"""
|
|
278
|
+
Additional Helper Function 3
|
|
279
|
+
|
|
280
|
+
Computes the centered modulus ``z mod± q`` for a ``matrix`` of ``Integers``.
|
|
281
|
+
|
|
282
|
+
Maps each integer ``x`` to the unique ``r`` in ::
|
|
283
|
+
|
|
284
|
+
[-(q-1)/2, (q-1)/2]
|
|
285
|
+
|
|
286
|
+
such that ::
|
|
287
|
+
|
|
288
|
+
x ≡ r (mod q).
|
|
289
|
+
|
|
290
|
+
Args:
|
|
291
|
+
z(``list[list[int]]``): A matrix of ``integers``.
|
|
292
|
+
|
|
293
|
+
Returns:
|
|
294
|
+
CenteredModulus(``list[list[int]]``)
|
|
295
|
+
|
|
296
|
+
Raises:
|
|
297
|
+
TypeError: If ``z`` is not a ``list[list[int]]``.
|
|
298
|
+
"""
|
|
299
|
+
|
|
300
|
+
half_q = (self.q - 1) // 2
|
|
301
|
+
|
|
302
|
+
return [[(x + half_q) % self.q - half_q for x in z[k]] for k in range(len(z))]
|
|
303
|
+
|
|
304
|
+
def abs_for_list (self, z: list[int]) -> list[int]:
|
|
305
|
+
"""
|
|
306
|
+
Additional Helper Function 4
|
|
307
|
+
|
|
308
|
+
Computes the absolute values of a ``list`` of integers.
|
|
309
|
+
|
|
310
|
+
Args:
|
|
311
|
+
z (``list[int]``): A ``list`` of integers.
|
|
312
|
+
|
|
313
|
+
Returns:
|
|
314
|
+
|z| (``list[int]``): A ``list`` containing the ``absolute`` values of the input ``integers``.
|
|
315
|
+
"""
|
|
316
|
+
for p in range (len(z)):
|
|
317
|
+
z[p] = abs(z[p])
|
|
318
|
+
|
|
319
|
+
return z
|
|
320
|
+
|
|
321
|
+
def InfinityNorm(self, z: list[list[int]]) -> int:
|
|
322
|
+
"""
|
|
323
|
+
Additional Helper Function 5
|
|
324
|
+
|
|
325
|
+
Compute the ``L-Infinity Norm`` of a ``Matrix``.
|
|
326
|
+
|
|
327
|
+
Args:
|
|
328
|
+
z (``list[list[int]]``): A ``Matrix`` of integers.
|
|
329
|
+
|
|
330
|
+
Returns:
|
|
331
|
+
infinity_norm (``int``): The infinity norm (``maximum`` absolute value) among all lists.
|
|
332
|
+
|
|
333
|
+
Raises:
|
|
334
|
+
TypeError: If ``z`` is not a list of lists of ``integers``.
|
|
335
|
+
TypeError: If any element in the sublists is ``not an integer``.
|
|
336
|
+
TypeError: If elements of ``z[x]`` are not lists.
|
|
337
|
+
"""
|
|
338
|
+
max_value = 0
|
|
339
|
+
|
|
340
|
+
for i in range (len(z)):
|
|
341
|
+
if max_value < max(self.abs_for_list(self.CenteredModulusList(z[i]))):
|
|
342
|
+
max_value = max(self.abs_for_list(self.CenteredModulusList(z[i])))
|
|
343
|
+
|
|
344
|
+
return max_value # returns the max value among all lists.
|
|
345
|
+
|
|
346
|
+
def CalcOnes(self, h: list[list[int]]) -> int:
|
|
347
|
+
"""
|
|
348
|
+
Additional Helper Function 6
|
|
349
|
+
|
|
350
|
+
Compute the number of ``1`` s inside a ``list[list[int]]`` .
|
|
351
|
+
|
|
352
|
+
Args:
|
|
353
|
+
h (``list[list[int]]``): A ``matrix`` containing ``0`` s and ``1`` s.
|
|
354
|
+
|
|
355
|
+
Returns:
|
|
356
|
+
Count(``int``): The count of ``1`` s in the ``matrix``.
|
|
357
|
+
|
|
358
|
+
Raises:
|
|
359
|
+
TypeError: If ``h`` is not a ``matrix`` of integers.
|
|
360
|
+
TypeError: If any element in the sublists is not an integer .
|
|
361
|
+
TypeError: If elements of ``h[x]`` are not ``0`` or ``1``.
|
|
362
|
+
"""
|
|
363
|
+
count = 0
|
|
364
|
+
for i in range(len(h)):
|
|
365
|
+
for j in range(len(h[i])):
|
|
366
|
+
if h[i][j] == 1:
|
|
367
|
+
count = count + 1
|
|
368
|
+
|
|
369
|
+
return count
|
fips/FIPS204/encode.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
from .pack import PACK
|
|
2
|
+
|
|
3
|
+
class ENCODE:
|
|
4
|
+
"""
|
|
5
|
+
Translate keys and signature for MLDSA into byte strings.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
def __init__(self, parameter: dict[str, int]):
|
|
9
|
+
self.pack = PACK(parameter)
|
|
10
|
+
|
|
11
|
+
self.q = parameter["q"]
|
|
12
|
+
self.d = parameter["d"]
|
|
13
|
+
self.N = parameter["N"]
|
|
14
|
+
|
|
15
|
+
self.k = parameter["k"]
|
|
16
|
+
self.l = parameter["l"]
|
|
17
|
+
self.eta = parameter["eta"]
|
|
18
|
+
self.gamma1 = parameter["gamma1"]
|
|
19
|
+
self.gamma2 = parameter["gamma2"]
|
|
20
|
+
self._lambda = parameter["_lambda"] # using _lambda because lambda is not allowed.
|
|
21
|
+
|
|
22
|
+
self.omega = parameter["omega"]
|
|
23
|
+
|
|
24
|
+
def pkEncode(self, rho: bytes, t1_vec: list[list[int]]) -> bytes:
|
|
25
|
+
"""
|
|
26
|
+
Algorithm 22
|
|
27
|
+
|
|
28
|
+
Encodes a ``public key`` for MLDSA into a ``bytestring``.
|
|
29
|
+
|
|
30
|
+
Args:
|
|
31
|
+
rho (``bytes``): The ``32-byte`` seed.
|
|
32
|
+
t1_vec (``list``): The vector of ``k`` polynomials for ``t₁``.
|
|
33
|
+
|
|
34
|
+
Returns:
|
|
35
|
+
public_key (``bytes``): The final encoded ``public key`` as a ``bytestring``.
|
|
36
|
+
|
|
37
|
+
Raises:
|
|
38
|
+
ValueError: if ``rho`` or ``t1_vec`` are of incorrect types or lengths.
|
|
39
|
+
"""
|
|
40
|
+
if not isinstance(rho, (bytes, bytearray)) or len(rho) != 32:
|
|
41
|
+
raise ValueError("rho must be a 32-byte bytestring.")
|
|
42
|
+
# if not isinstance(t1_vec, list) or len(t1_vec) != self.k:
|
|
43
|
+
# raise ValueError(f"t1 must be a list of {self.k} polynomials.")
|
|
44
|
+
|
|
45
|
+
# 1: pk ← ρ
|
|
46
|
+
# Using a bytearray for efficient concatenation
|
|
47
|
+
pk = bytearray(rho)
|
|
48
|
+
|
|
49
|
+
# Calculate the number of bits needed to store each t₁ coefficient
|
|
50
|
+
max_bit = pow(2, (self.q - 1).bit_length() - self.d) - 1 # This is 23 - 13 = 10
|
|
51
|
+
|
|
52
|
+
# 2: for i from 0 to k - 1 do
|
|
53
|
+
for poly_t1 in t1_vec:
|
|
54
|
+
# 3: pk ← pk || SimpleBitPack(...)
|
|
55
|
+
packed_poly = self.pack.SimpleBitPack(poly_t1, max_bit)
|
|
56
|
+
pk.extend(packed_poly)
|
|
57
|
+
|
|
58
|
+
# 5: return pk
|
|
59
|
+
return bytes(pk)
|
|
60
|
+
|
|
61
|
+
def pkDecode(self, pk: bytes) -> tuple[bytes, list[list[int]]]:
|
|
62
|
+
"""
|
|
63
|
+
Algorithm 23
|
|
64
|
+
|
|
65
|
+
Reverses the procedure ``pkEncode``.
|
|
66
|
+
|
|
67
|
+
Args:
|
|
68
|
+
pk (bytes): public key bytestring::
|
|
69
|
+
|
|
70
|
+
Byte string of length 32 + k * length_p, where length_p = ceil(256 * bitlen(q-1)-d / 8)
|
|
71
|
+
|
|
72
|
+
Returns:
|
|
73
|
+
rho(``bytes``) : 32-byte seed rho.
|
|
74
|
+
|
|
75
|
+
t1(``list[int]``): List of ``k`` polynomials, each with ::
|
|
76
|
+
|
|
77
|
+
256 coefficients in the range [0, 2^(bitlen(q-1)-d) - 1].
|
|
78
|
+
|
|
79
|
+
Raises:
|
|
80
|
+
ValueError: if ``pk`` is of incorrect type.
|
|
81
|
+
"""
|
|
82
|
+
if not isinstance(pk, (bytes, bytearray)):
|
|
83
|
+
raise ValueError("pk must be a bytestring.")
|
|
84
|
+
|
|
85
|
+
bitlen = (self.q - 1).bit_length() - self.d
|
|
86
|
+
# expected_len = 32 + 32 * self.k * bitlen
|
|
87
|
+
|
|
88
|
+
# if not isinstance(pk, (bytes, bytearray)) or len(pk) != expected_len:
|
|
89
|
+
# raise ValueError(f"pk must be a byte string of length {expected_len} bytes.")
|
|
90
|
+
|
|
91
|
+
rho = pk[:32] # assign the first 32 bits
|
|
92
|
+
t1: list[list[int]] = [] # initialize empty list
|
|
93
|
+
|
|
94
|
+
offset = 32
|
|
95
|
+
for i in range(self.k):
|
|
96
|
+
start = i * 32 * bitlen + offset
|
|
97
|
+
end = start + 32 * bitlen
|
|
98
|
+
segment = pk[start:end]
|
|
99
|
+
coeffs = self.pack.SimpleBitUnpack(segment, pow(2, bitlen) - 1)
|
|
100
|
+
if len(coeffs) != 256:
|
|
101
|
+
raise ValueError("Each unpacked polynomial must have 256 coefficients.")
|
|
102
|
+
t1.append(coeffs)
|
|
103
|
+
|
|
104
|
+
return rho, t1
|
|
105
|
+
|
|
106
|
+
def skEncode(self, rho: bytes, K_seed: bytes, tr: bytes, s1_vec: list[list[int]], s2_vec: list[list[int]], t0_vec: list[list[int]]) -> bytes:
|
|
107
|
+
"""
|
|
108
|
+
Algorithm 24
|
|
109
|
+
|
|
110
|
+
Encodes a ``secret key`` for MLDSA into a ``bytestring``.
|
|
111
|
+
|
|
112
|
+
Args:
|
|
113
|
+
rho (``bytes``): The 32-byte seed.
|
|
114
|
+
K_seed (``bytes``): The 32-byte ``K`` seed.
|
|
115
|
+
tr (``bytes``): The 64-byte ``tr`` value.
|
|
116
|
+
s1_vec (``list[int]``): The vector of ``l`` polynomials for ``s₁``.
|
|
117
|
+
s2_vec (``list[int]``): The vector of ``k`` polynomials for ``s₂``.
|
|
118
|
+
t0_vec (``list[int]``): The vector of ``k`` polynomials for ``t₀``.
|
|
119
|
+
|
|
120
|
+
Returns:
|
|
121
|
+
private_key(``bytes``): The final encoded private key as a ``bytestring``.
|
|
122
|
+
|
|
123
|
+
Raises:
|
|
124
|
+
ValueError: if ``rho``, ``K_seed``, ``tr``, ``s₁_vec``, ``s2_vec``, or ``t0_vec`` are of incorrect types or lengths.
|
|
125
|
+
"""
|
|
126
|
+
if not isinstance(rho, (bytes, bytearray)) or len(rho) != 32:
|
|
127
|
+
raise ValueError("rho must be a 32-byte bytestring.")
|
|
128
|
+
if not isinstance(K_seed, (bytes, bytearray)) or len(K_seed) != 32:
|
|
129
|
+
raise ValueError("K_seed must be a 32-byte bytestring.")
|
|
130
|
+
if not isinstance(tr, (bytes, bytearray)) or len(tr) != 64:
|
|
131
|
+
raise ValueError("tr must be a 64-byte bytestring.")
|
|
132
|
+
|
|
133
|
+
# 1: sk ← ρ || K || tr
|
|
134
|
+
sk = bytearray(rho + K_seed + tr)
|
|
135
|
+
|
|
136
|
+
# 2-4: Pack and append s₁
|
|
137
|
+
for poly in s1_vec:
|
|
138
|
+
sk.extend(self.pack.BitPack(poly, self.eta, self.eta))
|
|
139
|
+
|
|
140
|
+
# 5-7: Pack and append s₂
|
|
141
|
+
for poly in s2_vec:
|
|
142
|
+
sk.extend(self.pack.BitPack(poly, self.eta, self.eta))
|
|
143
|
+
|
|
144
|
+
# 8-10: Pack and append t₀
|
|
145
|
+
for poly in t0_vec:
|
|
146
|
+
sk.extend(self.pack.BitPack(poly, (1 << (self.d - 1)) - 1, (1 << (self.d - 1))))
|
|
147
|
+
|
|
148
|
+
# 11: return sk
|
|
149
|
+
return bytes(sk)
|
|
150
|
+
|
|
151
|
+
def skDecode(self, private_key: bytes):
|
|
152
|
+
"""
|
|
153
|
+
Algorithm 25
|
|
154
|
+
|
|
155
|
+
Reverses the procedure performed by skEncode.
|
|
156
|
+
|
|
157
|
+
Args:
|
|
158
|
+
private_key (``bytes``): The byte string representing the ``private key``.
|
|
159
|
+
|
|
160
|
+
Returns:
|
|
161
|
+
(Tuple[ ``bytes``, ``bytes``, ``bytes``, ``list``, ``list`` , ``list``]): A tuple containing (``rho``, ``K_seed``, ``tr``, ``s₁``, ``s2``, ``t0_vec``).
|
|
162
|
+
|
|
163
|
+
Raises:
|
|
164
|
+
ValueError: if private_key is of incorrect type.
|
|
165
|
+
"""
|
|
166
|
+
if not isinstance(private_key, (bytes, bytearray)):
|
|
167
|
+
raise ValueError("private_key must be a bytestring.")
|
|
168
|
+
|
|
169
|
+
if self.eta == 2:
|
|
170
|
+
s_bytes = 96
|
|
171
|
+
else:
|
|
172
|
+
s_bytes = 128
|
|
173
|
+
|
|
174
|
+
# find length of all the vectors
|
|
175
|
+
s1_len = s_bytes * self.l
|
|
176
|
+
s2_len = s_bytes * self.k
|
|
177
|
+
t0_len = 416 * self.k
|
|
178
|
+
|
|
179
|
+
# check length of private_key
|
|
180
|
+
if len(private_key) != 2 * 32 + 64 + s1_len + s2_len + t0_len:
|
|
181
|
+
raise ValueError("SK packed bytes is of the wrong length")
|
|
182
|
+
|
|
183
|
+
# Split bytes between seeds and vectors
|
|
184
|
+
sk_seed_bytes, sk_vec_bytes = private_key[:128], private_key[128:]
|
|
185
|
+
|
|
186
|
+
# Unpack seed bytes
|
|
187
|
+
rho, K_seed, tr = (
|
|
188
|
+
sk_seed_bytes[:32],
|
|
189
|
+
sk_seed_bytes[32:64],
|
|
190
|
+
sk_seed_bytes[64:128],
|
|
191
|
+
)
|
|
192
|
+
|
|
193
|
+
# Unpack vector bytes
|
|
194
|
+
s1_bytes = sk_vec_bytes[:s1_len]
|
|
195
|
+
s2_bytes = sk_vec_bytes[s1_len : s1_len + s2_len]
|
|
196
|
+
t0_bytes = sk_vec_bytes[-t0_len:]
|
|
197
|
+
# print(s1_len, " ", s2_len, " ", t0_len)
|
|
198
|
+
|
|
199
|
+
# Unpack bytes to vectors
|
|
200
|
+
s1 = [[0] for _ in range(int(s1_len / s_bytes))]
|
|
201
|
+
for i in range(int(s1_len / s_bytes)):
|
|
202
|
+
s1[i] = self.pack.BitUnpack(s1_bytes[i * s_bytes: (i + 1) * s_bytes], self.eta, self.eta)
|
|
203
|
+
|
|
204
|
+
s2 = [[0] for _ in range(int(s2_len / s_bytes))]
|
|
205
|
+
for i in range(int(s2_len / s_bytes)):
|
|
206
|
+
s2[i] = self.pack.BitUnpack(s2_bytes[i * s_bytes: (i + 1) * s_bytes], self.eta, self.eta)
|
|
207
|
+
|
|
208
|
+
t0_vec = [[0] for _ in range(int(t0_len / 416))]
|
|
209
|
+
for i in range(int(t0_len / 416)):
|
|
210
|
+
t0_vec[i] = self.pack.BitUnpack(t0_bytes[i * 416 : (i + 1) * 416], pow(2, self.d - 1) - 1, pow(2, self.d - 1))
|
|
211
|
+
|
|
212
|
+
return rho, K_seed, tr, s1, s2, t0_vec
|
|
213
|
+
|
|
214
|
+
def sigEncode(self, c_tilda: bytes, z: list[list[int]], h: list[list[int]]):
|
|
215
|
+
"""
|
|
216
|
+
Algorithm 26
|
|
217
|
+
|
|
218
|
+
Encodes a ``signature`` into a ``bytestring``.
|
|
219
|
+
|
|
220
|
+
Args:
|
|
221
|
+
c_tilda (``bytes``): The byte string representing ``c_tilda``.
|
|
222
|
+
z (``bytes``): The list of ``l`` polynomials representing ``z``.
|
|
223
|
+
h (``bytes``): The list representing the hint vector ``h``.
|
|
224
|
+
|
|
225
|
+
Returns:
|
|
226
|
+
signature (``bytes``): The final encoded signature as a ``bytestring``.
|
|
227
|
+
|
|
228
|
+
Raises:
|
|
229
|
+
ValueError: if ``c_tilda``, ``z``, or ``h`` are of incorrect types or lengths.
|
|
230
|
+
"""
|
|
231
|
+
if not isinstance(c_tilda, (bytes, bytearray)) or len(c_tilda) != int(self._lambda / 4):
|
|
232
|
+
raise ValueError(f"c_tilda must be a bytestring of length {int(self._lambda / 4)}.")
|
|
233
|
+
|
|
234
|
+
sigma = b""
|
|
235
|
+
sigma = sigma + c_tilda
|
|
236
|
+
for i in range(self.l):
|
|
237
|
+
sigma = sigma + self.pack.BitPack(z[i], self.gamma1 -1, self.gamma1)
|
|
238
|
+
|
|
239
|
+
sigma = sigma + self.pack.HintBitPack(h)
|
|
240
|
+
|
|
241
|
+
return sigma
|
|
242
|
+
|
|
243
|
+
def sigDecode(self, signature: bytes) -> tuple[bytes | bytearray, list[list[int]], list[list[int]] | None]:
|
|
244
|
+
"""
|
|
245
|
+
Algorithm 27
|
|
246
|
+
|
|
247
|
+
Reverses the procedue ``sigEncode``.
|
|
248
|
+
|
|
249
|
+
Args:
|
|
250
|
+
signature (``bytes``): The ``bytestring`` representing the signature.
|
|
251
|
+
|
|
252
|
+
Returns:
|
|
253
|
+
tuple (``c_tilda``, ``z``, ``h``): A tuple containing the decoded components.
|
|
254
|
+
|
|
255
|
+
Raises:
|
|
256
|
+
ValueError: if signature is of incorrect type.
|
|
257
|
+
"""
|
|
258
|
+
if not isinstance(signature, (bytes, bytearray)):
|
|
259
|
+
raise ValueError("signature must be a bytestring.")
|
|
260
|
+
|
|
261
|
+
c_tilda = signature[:int(self._lambda / 4)]
|
|
262
|
+
x_list = signature[int(self._lambda / 4) : int(self._lambda / 4) + (32 * (1 + int((self.gamma1 - 1).bit_length()))) * self.l]
|
|
263
|
+
y = signature[-(self.omega + self.k):] # last remaining elements.
|
|
264
|
+
|
|
265
|
+
z: list[list[int]] = []
|
|
266
|
+
size_x = 32 * (1 + (self.gamma1 -1).bit_length())
|
|
267
|
+
|
|
268
|
+
for i in range(self.l):
|
|
269
|
+
start = i * size_x
|
|
270
|
+
end = start + size_x
|
|
271
|
+
segment = x_list[start : end]
|
|
272
|
+
coefficients = self.pack.BitUnpack(segment, self.gamma1 -1, self.gamma1)
|
|
273
|
+
z.append(coefficients)
|
|
274
|
+
|
|
275
|
+
h = self.pack.HintBitUnpack(y)
|
|
276
|
+
|
|
277
|
+
return (c_tilda, z, h)
|
|
278
|
+
|
|
279
|
+
def w1Encode(self, w:list[list[int]]):
|
|
280
|
+
"""
|
|
281
|
+
Algorithm 28
|
|
282
|
+
|
|
283
|
+
Encodes a polynomial vector ``w₁`` into a bytestring.
|
|
284
|
+
|
|
285
|
+
Args:
|
|
286
|
+
w (``list``): The list of ``k`` polynomials representing ``w``.
|
|
287
|
+
|
|
288
|
+
Returns:
|
|
289
|
+
w₁ (``bytes``): The encoded byte string ``w₁``.
|
|
290
|
+
|
|
291
|
+
Raises:
|
|
292
|
+
ValueError: if ``w`` is of incorrect type or length.
|
|
293
|
+
"""
|
|
294
|
+
# if not isinstance(w, list) or len(w) != self.k:
|
|
295
|
+
# raise ValueError(f"w must be a list of {self.k} polynomials.")
|
|
296
|
+
|
|
297
|
+
w1 = b''
|
|
298
|
+
for i in range(self.k):
|
|
299
|
+
w1 = w1 + self.pack.SimpleBitPack(w[i], int((self.q - 1)/(2 * self.gamma2)) - 1)
|
|
300
|
+
return w1
|