transcrypto 1.1.2__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
transcrypto/dsa.py CHANGED
@@ -12,30 +12,49 @@ In the future we will design a proper DSA+Hash implementation.
12
12
 
13
13
  from __future__ import annotations
14
14
 
15
+ import concurrent.futures
15
16
  import dataclasses
16
17
  import logging
18
+ import multiprocessing
19
+ import os
17
20
  # import pdb
18
21
  from typing import Self
19
22
 
20
- from . import base
21
- from . import modmath
23
+ import gmpy2 # type:ignore
24
+
25
+ from . import base, modmath
22
26
 
__author__ = 'balparda@github.com'
__version__: str = base.__version__  # version comes from base!
__version_tuple__: tuple[int, ...] = base.__version_tuple__


# Upper bound on failed key-generation attempts tolerated before giving up
# (presumably checked by the key-generation retry loops in this module).
_MAX_KEY_GENERATION_FAILURES = 15

# Fixed domain-separation prefix mixed into every DSA signature hash.
# Do NOT ever change it: doing so breaks verification of every existing signature.
_DSA_SIGNATURE_HASH_PREFIX = b'transcrypto.DSA.Signature.1.0\x00'
31
37
 
32
- def NBitRandomDSAPrimes(p_bits: int, q_bits: int, /) -> tuple[int, int, int]:
38
+ def NBitRandomDSAPrimes(
39
+ p_bits: int, q_bits: int, /, *, serial: bool = True) -> tuple[int, int, int]:
33
40
  """Generates 2 random DSA primes p & q with `x_bits` size and (p-1)%q==0.
34
41
 
42
+ Uses an aggressive small-prime wheel sieve:
43
+ Before any Miller-Rabin we reject p = m·q + 1 if it is divisible by a small prime.
44
+ We precompute forbidden residues for m:
45
+ • For each small prime r (all primes up to, say, 100 000), we compute
46
+ m_forbidden ≡ -q⁻¹ (mod r) (because (m·q + 1) % r == 0 ⇔ m ≡ -q⁻¹ (mod r))
47
+ • When we iterate m, we skip values that hit any forbidden residue class.
48
+
49
+ Method will decide if executes on one thread or many.
50
+
35
51
  Args:
36
52
  p_bits (int): Number of guaranteed bits in `p` prime representation,
37
53
  p_bits ≥ q_bits + 11
38
54
  q_bits (int): Number of guaranteed bits in `q` prime representation, ≥ 11
55
+ serial (bool, optional): True (default) will force one thread; False will allow parallelism;
56
+ we have temporarily disabled parallelism with a default of True because it is not making
57
+ things faster...
39
58
 
40
59
  Returns:
41
60
  random primes tuple (p, q, m), with p-1 a random multiple m of q, such
@@ -50,36 +69,89 @@ def NBitRandomDSAPrimes(p_bits: int, q_bits: int, /) -> tuple[int, int, int]:
50
69
  if p_bits < q_bits + 11:
51
70
  raise base.InputError(f'invalid p_bits length: {p_bits=}')
52
71
  # make q
53
- q = modmath.NBitRandomPrime(q_bits)
72
+ q: int = modmath.NBitRandomPrimes(q_bits).pop()
73
+ # get number of CPUs and decide if we do parallel or not
74
+ n_workers: int = min(4, os.cpu_count() or 1)
75
+ pr: int | None = None
76
+ m: int | None = None
77
+ if serial or n_workers <= 1 or p_bits < 200:
78
+ # do one worker
79
+ while pr is None or m is None or pr.bit_length() != p_bits:
80
+ pr, m = _PrimePSearchShard(q, p_bits)
81
+ return (pr, q, m)
82
+ # parallel: keep a small pool of bounded shards; stop on first hit
83
+ multiprocessing.set_start_method('fork', force=True)
84
+ with concurrent.futures.ProcessPoolExecutor(max_workers=n_workers) as pool:
85
+ workers: set[concurrent.futures.Future[tuple[int | None, int | None]]] = {
86
+ pool.submit(_PrimePSearchShard, q, p_bits) for _ in range(n_workers)}
87
+ while workers:
88
+ done: set[concurrent.futures.Future[tuple[int | None, int | None]]] = concurrent.futures.wait(
89
+ workers, return_when=concurrent.futures.FIRST_COMPLETED)[0]
90
+ for worker in done:
91
+ workers.remove(worker)
92
+ pr, m = worker.result()
93
+ if pr is not None and m is not None and pr.bit_length() == p_bits:
94
+ return (pr, q, m)
95
+ # no hit in that shard: keep the pool full with a fresh shard
96
+ workers.add(pool.submit(_PrimePSearchShard, q, p_bits)) # pragma: no cover
97
+ # can never reach this point, but leave this here; remove line from coverage
98
+ raise base.Error(f'could not find prime with {p_bits=}/{q_bits=} bits') # pragma: no cover
99
+
100
+
def _PrimePSearchShard(q: int, p_bits: int) -> tuple[int | None, int | None]:
  """Search one bounded shard of candidates p = m·q + 1 for a `p_bits`-bit prime.

  Starts at a random even multiplier `m` inside the valid range and walks upward in steps
  of 2 (so p = m·q + 1 stays odd), examining at most `shard_len` candidates. Each candidate
  is first screened with a cheap small-prime sieve on `m`; only survivors get the full
  (expensive) `modmath.IsPrime()` test.

  Args:
    q (int): Prime `q` for DSA
    p_bits (int): Number of guaranteed bits in prime `p` representation

  Returns:
    tuple[int | None, int | None]: (p, m) with p = m·q + 1 prime and p ≤ 2**p_bits - 1,
        or (None, None) if this shard found no prime (caller should try another shard)
  """
  q_bits: int = q.bit_length()
  # the average prime gap near 2**k is ln(2**k) = k·ln(2) ≈ 0.693·k; allow ~6× that
  shard_len: int = max(2000, 6 * int(0.693 * p_bits))
  # find range of multiples to use
  min_p: int = 2 ** (p_bits - 1)
  max_p: int = 2 ** p_bits - 1
  min_m: int = min_p // q + 2
  max_m: int = max_p // q - 2
  assert max_m - min_m > 1000  # make sure we'll have options!
  # Sieve residues: (q·m + 1) % r == 0  ⇔  m ≡ -q⁻¹ (mod r), so for each small prime r we
  # store the one forbidden residue of m. r == 2 is skipped because m is kept even, which
  # already makes p odd. The number of sieve primes is capped at 2**(q_bits//2) (and 5000);
  # NOTE(review): presumably this keeps every r far below q so q % r != 0 and the inverse
  # exists - confirm intent. A tuple is used because the table is only ever iterated.
  approx_q_root: int = 1 << (q_bits // 2)
  forbidden: tuple[tuple[int, int], ...] = tuple(
      (r, (-modmath.ModInv(q % r, r)) % r)
      for r in modmath.FIRST_5K_PRIMES_SORTED[1:min(5000, approx_q_root)])
  # try searching starting here: a random multiplier, made even
  m: int = base.RandInt(min_m, max_m)
  if m % 2:
    m += 1  # make even
  for _ in range(shard_len):
    p: int = q * m + 1
    if p > max_p:
      break  # walked off the top of the range: give up on this shard
    # quick sieve test first; only survivors get the full primality test
    if all(m % r != f for r, f in forbidden) and modmath.IsPrime(p):
      return (p, m)  # found a suitable prime set!
    m += 2  # next even multiplier keeps p odd
  return (None, None)
75
149
 
76
150
 
77
151
  @dataclasses.dataclass(kw_only=True, slots=True, frozen=True, repr=False)
78
152
  class DSASharedPublicKey(base.CryptoKey):
79
153
  """DSA shared public key. This key can be shared by a group.
80
154
 
81
- BEWARE: This is raw DSA, no ECDSA/EdDSA padding, no hash, no validation!
82
- These are pedagogical/raw primitives; do not use for new protocols.
83
155
  No measures are taken here to prevent timing attacks.
84
156
 
85
157
  Attributes:
@@ -116,10 +188,43 @@ class DSASharedPublicKey(base.CryptoKey):
116
188
  string representation of DSASharedPublicKey
117
189
  """
118
190
  return ('DSASharedPublicKey('
191
+ f'bits=[{self.prime_modulus.bit_length()}, {self.prime_seed.bit_length()}], '
119
192
  f'prime_modulus={base.IntToEncoded(self.prime_modulus)}, '
120
193
  f'prime_seed={base.IntToEncoded(self.prime_seed)}, '
121
194
  f'group_base={base.IntToEncoded(self.group_base)})')
122
195
 
@property
def modulus_size(self) -> tuple[int, int]:
  """Byte sizes of the two DSA primes: (size of `prime_modulus`, size of `prime_seed`).

  Each entry is ceil(bit_length / 8), i.e. the minimum number of bytes needed to hold the
  value. Sign/Verify use the second entry (the `prime_seed` size) as the fixed width of
  each signature integer.

  Returns:
    tuple[int, int]: (bytes needed for `prime_modulus`, bytes needed for `prime_seed`)
  """
  return ((self.prime_modulus.bit_length() + 7) // 8,
          (self.prime_seed.bit_length() + 7) // 8)
201
+
def _DomainSeparatedHash(
    self, message: bytes, associated_data: bytes | None, salt: bytes, /) -> int:
  """Hash `message` (with optional AAD and a salt) into an integer usable by raw DSA.

  The hashed byte string is laid out as
    _DSA_SIGNATURE_HASH_PREFIX || len(aad) as 8 bytes || aad || message || salt
  where aad is `associated_data`, or b'' when it is None. The AAD is length-prefixed and
  the trailing salt has a fixed 64-byte length, so the layout is unambiguous.

  Args:
    message (bytes): message to sign/verify
    associated_data (bytes | None): optional associated data; None is treated as b''
    salt (bytes): salt appended last; must be exactly 64 bytes

  Returns:
    int: integer value of the Hash512 digest of the layout above, guaranteed to lie in
        the open interval (1, prime_seed - 1)

  Raises:
    CryptoError: hash output is out of range
  """
  if associated_data is None:
    associated_data = b''
  aad_len_field: bytes = base.IntToFixedBytes(len(associated_data), 8)
  assert len(salt) == 64, 'should never happen: salt should be exactly 64 bytes'
  digest: bytes = base.Hash512(b''.join(
      (_DSA_SIGNATURE_HASH_PREFIX, aad_len_field, associated_data, message, salt)))
  hashed: int = base.BytesToInt(digest)
  if hashed <= 1 or hashed >= self.prime_seed - 1:
    # will only reasonably happen if prime seed is small
    raise base.CryptoError(f'hash output {hashed} is out of range/invalid {self.prime_seed}')
  return hashed
227
+
123
228
  @classmethod
124
229
  def NewShared(cls, p_bits: int, q_bits: int, /) -> Self:
125
230
  """Make a new shared public key of `bit_length` bits.
@@ -141,16 +246,14 @@ class DSASharedPublicKey(base.CryptoKey):
141
246
  g: int = 0
142
247
  while g < 2:
143
248
  h: int = base.RandBits(p_bits - 1)
144
- g = modmath.ModExp(h, m, p)
249
+ g = int(gmpy2.powmod(h, m, p)) # type:ignore # pylint:disable=no-member
145
250
  return cls(prime_modulus=p, prime_seed=q, group_base=g)
146
251
 
147
252
 
148
253
  @dataclasses.dataclass(kw_only=True, slots=True, frozen=True, repr=False)
149
- class DSAPublicKey(DSASharedPublicKey):
254
+ class DSAPublicKey(DSASharedPublicKey, base.Verifier):
150
255
  """DSA public key. This is an individual public key.
151
256
 
152
- BEWARE: This is raw DSA, no ECDSA/EdDSA padding, no hash, no validation!
153
- These are pedagogical/raw primitives; do not use for new protocols.
154
257
  No measures are taken here to prevent timing attacks.
155
258
 
156
259
  Attributes:
@@ -176,11 +279,12 @@ class DSAPublicKey(DSASharedPublicKey):
176
279
  Returns:
177
280
  string representation of DSAPublicKey
178
281
  """
179
- return (f'DSAPublicKey({super(DSAPublicKey, self).__str__()}, ' # pylint: disable=super-with-arguments
282
+ return ('DSAPublicKey('
283
+ f'{super(DSAPublicKey, self).__str__()}, ' # pylint: disable=super-with-arguments
180
284
  f'individual_base={base.IntToEncoded(self.individual_base)})')
181
285
 
182
286
  def _MakeEphemeralKey(self) -> tuple[int, int]:
183
- """Make an ephemeral key adequate to be used with El-Gamal.
287
+ """Make an ephemeral key adequate to be used with DSA.
184
288
 
185
289
  Returns:
186
290
  (key, key_inverse), where 3 ≤ k < p_seed and (k*i) % p_seed == 1
@@ -192,9 +296,11 @@ class DSAPublicKey(DSASharedPublicKey):
192
296
  ephemeral_key = base.RandBits(bit_length - 1)
193
297
  return (ephemeral_key, modmath.ModInv(ephemeral_key, self.prime_seed))
194
298
 
195
- def VerifySignature(self, message: int, signature: tuple[int, int], /) -> bool:
299
+ def RawVerify(self, message: int, signature: tuple[int, int], /) -> bool:
196
300
  """Verify a signature. True if OK; False if failed verification.
197
301
 
302
+ BEWARE: This is raw DSA, no ECDSA/EdDSA padding, no hash, no validation!
303
+ These are pedagogical/raw primitives; do not use for new protocols.
198
304
  We explicitly disallow `message` to be zero.
199
305
 
200
306
  Args:
@@ -215,12 +321,48 @@ class DSAPublicKey(DSASharedPublicKey):
215
321
  raise base.InputError(f'invalid signature: {signature=}')
216
322
  # verify
217
323
  inv: int = modmath.ModInv(signature[1], self.prime_seed)
218
- a: int = modmath.ModExp(
219
- self.group_base, (message * inv) % self.prime_seed, self.prime_modulus)
220
- b: int = modmath.ModExp(
221
- self.individual_base, (signature[0] * inv) % self.prime_seed, self.prime_modulus)
324
+ a: int = int(gmpy2.powmod( # type:ignore # pylint:disable=no-member
325
+ self.group_base, (message * inv) % self.prime_seed, self.prime_modulus))
326
+ b: int = int(gmpy2.powmod( # type:ignore # pylint:disable=no-member
327
+ self.individual_base, (signature[0] * inv) % self.prime_seed, self.prime_modulus))
222
328
  return ((a * b) % self.prime_modulus) % self.prime_seed == signature[0]
223
329
 
def Verify(
    self, message: bytes, signature: bytes, /, *, associated_data: bytes | None = None) -> bool:
  """Verify a `signature` for `message`. True if OK; False if failed verification.

  Let k be the byte size of `prime_seed` (`self.modulus_size[1]`). The expected layout is
  salt (64 bytes) || Padded(s1, k) || Padded(s2, k), and verification proceeds as follows:
    • y = Hash512("prefix" || len(aad) || aad || message || salt), as an integer
    • return RawVerify(y, (s1, s2))
    • return False for any malformed signature: wrong length, s1/s2 rejected by
      RawVerify, or a salt that makes the hash fall out of range (Sign() can never
      produce such a signature)

  Args:
    message (bytes): Data that was signed
    signature (bytes): Signature data to verify
    associated_data (bytes, optional): Optional AAD (must match what was used during signing)

  Returns:
    True if signature is valid, False otherwise

  Raises:
    InputError: this key's `prime_seed` is too small (≤ 64 bytes) for signing operations
    CryptoError: internal crypto failures raised by the raw verification
  """
  k: int = self.modulus_size[1]  # use prime_seed size
  if k <= 64:
    raise base.InputError(f'modulus/seed too small for signing operations: {k} bytes')
  expected_len: int = 64 + k + k
  if len(signature) != expected_len:
    logging.info('invalid signature length: %d ; expected %d', len(signature), expected_len)
    return False
  # the hash depends on the salt, which is taken from the (untrusted) signature: if it is
  # out of range, the signature is simply invalid, so report False instead of raising
  try:
    y: int = self._DomainSeparatedHash(message, associated_data, signature[:64])
  except (base.InputError, base.CryptoError) as err:
    logging.info(err)
    return False
  try:
    return self.RawVerify(
        y, (base.BytesToInt(signature[64:64 + k]), base.BytesToInt(signature[64 + k:])))
  except base.InputError as err:
    logging.info(err)
    return False
365
+
224
366
  @classmethod
225
367
  def Copy(cls, other: DSAPublicKey, /) -> Self:
226
368
  """Initialize a public key by taking the public parts of a public/private key."""
@@ -232,11 +374,9 @@ class DSAPublicKey(DSASharedPublicKey):
232
374
 
233
375
 
234
376
  @dataclasses.dataclass(kw_only=True, slots=True, frozen=True, repr=False)
235
- class DSAPrivateKey(DSAPublicKey):
377
+ class DSAPrivateKey(DSAPublicKey, base.Signer): # pylint: disable=too-many-ancestors
236
378
  """DSA private key.
237
379
 
238
- BEWARE: This is raw DSA, no ECDSA/EdDSA padding, no hash, no validation!
239
- These are pedagogical/raw primitives; do not use for new protocols.
240
380
  No measures are taken here to prevent timing attacks.
241
381
 
242
382
  Attributes:
@@ -256,8 +396,7 @@ class DSAPrivateKey(DSAPublicKey):
256
396
  if (not 2 < self.decrypt_exp < self.prime_seed or
257
397
  self.decrypt_exp in (self.group_base, self.individual_base)):
258
398
  raise base.InputError(f'invalid decrypt_exp: {self}')
259
- if modmath.ModExp(
260
- self.group_base, self.decrypt_exp, self.prime_modulus) != self.individual_base:
399
+ if gmpy2.powmod(self.group_base, self.decrypt_exp, self.prime_modulus) != self.individual_base: # type:ignore # pylint:disable=no-member
261
400
  raise base.CryptoError(f'inconsistent g**d % p == i: {self}')
262
401
 
263
402
  def __str__(self) -> str:
@@ -266,12 +405,15 @@ class DSAPrivateKey(DSAPublicKey):
266
405
  Returns:
267
406
  string representation of DSAPrivateKey without leaking secrets
268
407
  """
269
- return (f'DSAPrivateKey({super(DSAPrivateKey, self).__str__()}, ' # pylint: disable=super-with-arguments
408
+ return ('DSAPrivateKey('
409
+ f'{super(DSAPrivateKey, self).__str__()}, ' # pylint: disable=super-with-arguments
270
410
  f'decrypt_exp={base.ObfuscateSecret(self.decrypt_exp)})')
271
411
 
272
- def Sign(self, message: int, /) -> tuple[int, int]:
412
+ def RawSign(self, message: int, /) -> tuple[int, int]:
273
413
  """Sign `message` with this private key.
274
414
 
415
+ BEWARE: This is raw DSA, no ECDSA/EdDSA padding, no hash, no validation!
416
+ These are pedagogical/raw primitives; do not use for new protocols.
275
417
  We explicitly disallow `message` to be zero.
276
418
 
277
419
  Args:
@@ -287,13 +429,48 @@ class DSAPrivateKey(DSAPublicKey):
287
429
  if not 0 < message < self.prime_seed:
288
430
  raise base.InputError(f'invalid message: {message=}')
289
431
  # sign
290
- a, b = 0, 0
432
+ a: int = 0
433
+ b: int = 0
291
434
  while a < 2 or b < 2:
292
435
  ephemeral_key, ephemeral_inv = self._MakeEphemeralKey()
293
- a = modmath.ModExp(self.group_base, ephemeral_key, self.prime_modulus) % self.prime_seed
436
+ a = int(gmpy2.powmod(self.group_base, ephemeral_key, self.prime_modulus) % self.prime_seed) # type:ignore # pylint:disable=no-member
294
437
  b = (ephemeral_inv * ((message + a * self.decrypt_exp) % self.prime_seed)) % self.prime_seed
295
438
  return (a, b)
296
439
 
def Sign(self, message: bytes, /, *, associated_data: bytes | None = None) -> bytes:
  """Sign `message` and return the `signature`.

  Let k be the byte size of `prime_seed` (`self.modulus_size[1]`). Signing then works like this:
    • Pick a random salt of 64 bytes
    • y = Hash512("prefix" || len(aad) || aad || message || salt), as an integer
    • s1, s2 = RawSign(y)
    • return salt || Padded(s1, k) || Padded(s2, k)

  This is DSA over a 512-bit full-domain hash with a fresh per-signature salt. The
  domain-separation prefix and the explicit AAD length prefix make the hashed input
  unambiguous, which removes composition/ambiguity pitfalls. No padding semantics are
  exposed, so Bleichenbacher-style padding issues do not apply.
  NOTE(review): the design goal is EUF-CMA security in the random-oracle model; this is a
  design claim, not something proven by this code.

  Args:
    message (bytes): Data to sign.
    associated_data (bytes, optional): Optional associated data bound into the signature;
        the exact same value must be provided again to `Verify()`

  Returns:
    bytes: Signature of 64 + 2·k bytes: salt || Padded(s1, k) || Padded(s2, k)

  Raises:
    InputError: invalid inputs, including a `prime_seed` of 64 bytes or less
    CryptoError: internal crypto failures
  """
  k: int = self.modulus_size[1]  # use prime_seed size
  if k <= 64:
    # presumably so the 64-byte Hash512 output always fits below prime_seed
    raise base.InputError(f'modulus/seed too small for signing operations: {k} bytes')
  salt: bytes = base.RandBytes(64)
  s1, s2 = self.RawSign(self._DomainSeparatedHash(message, associated_data, salt))
  s_bytes: bytes = base.IntToFixedBytes(s1, k) + base.IntToFixedBytes(s2, k)
  assert len(s_bytes) == 2 * k, 'should never happen: s_bytes should be exactly 2k bytes'
  return salt + s_bytes
473
+
297
474
  @classmethod
298
475
  def New(cls, shared_key: DSASharedPublicKey, /) -> Self:
299
476
  """Make a new private key based on an existing shared public key.
@@ -326,8 +503,8 @@ class DSAPrivateKey(DSAPublicKey):
326
503
  prime_modulus=shared_key.prime_modulus,
327
504
  prime_seed=shared_key.prime_seed,
328
505
  group_base=shared_key.group_base,
329
- individual_base=modmath.ModExp(
330
- shared_key.group_base, decrypt_exp, shared_key.prime_modulus),
506
+ individual_base=int(gmpy2.powmod( # type:ignore # pylint:disable=no-member
507
+ shared_key.group_base, decrypt_exp, shared_key.prime_modulus)),
331
508
  decrypt_exp=decrypt_exp)
332
509
  except base.InputError as err:
333
510
  failures += 1