transcrypto 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
transcrypto/dsa.py CHANGED
@@ -12,26 +12,31 @@ In the future we will design a proper DSA+Hash implementation.
12
12
 
13
13
  from __future__ import annotations
14
14
 
15
+ import concurrent.futures
15
16
  import dataclasses
16
17
  import logging
18
+ import multiprocessing
19
+ import os
17
20
  # import pdb
18
21
  from typing import Self
19
22
 
20
- from . import base, modmath
23
+ import gmpy2 # type:ignore
24
+
25
+ from . import base, constants, modmath
21
26
 
22
27
  __author__ = 'balparda@github.com'
23
28
  __version__: str = base.__version__ # version comes from base!
24
29
  __version_tuple__: tuple[int, ...] = base.__version_tuple__
25
30
 
26
31
 
27
- _PRIME_MULTIPLE_SEARCH = 4096 # how many multiples of q to try before restarting
28
32
  _MAX_KEY_GENERATION_FAILURES = 15
29
33
 
30
34
  # fixed prefixes: do NOT ever change! will break all encryption and signature schemes
31
35
  _DSA_SIGNATURE_HASH_PREFIX = b'transcrypto.DSA.Signature.1.0\x00'
32
36
 
33
37
 
34
- def NBitRandomDSAPrimes(p_bits: int, q_bits: int, /) -> tuple[int, int, int]:
38
+ def NBitRandomDSAPrimes(
39
+ p_bits: int, q_bits: int, /, *, serial: bool = True) -> tuple[int, int, int]:
35
40
  """Generates 2 random DSA primes p & q with `x_bits` size and (p-1)%q==0.
36
41
 
37
42
  Uses an aggressive small-prime wheel sieve:
@@ -41,10 +46,29 @@ def NBitRandomDSAPrimes(p_bits: int, q_bits: int, /) -> tuple[int, int, int]:
41
46
  m_forbidden ≡ -q⁻¹ (mod r) (because (m·q + 1) % r == 0 ⇔ m ≡ -q⁻¹ (mod r))
42
47
  • When we iterate m, we skip values that hit any forbidden residue class.
43
48
 
49
+ Method will decide if executes on one thread or many.
50
+
51
+ $ poetry run profiler -s -n 100 -b 1000,11000,1000 -c 98 dsa # single-thread, Mac M2 Max, 2025
52
+ 1000 → 101.069 ms ± 19.714 ms [81.354 ms … 120.783 ms]98%CI@100
53
+ 2000 → 471.038 ms ± 98.810 ms [372.229 ms … 569.848 ms]98%CI@100
54
+ 3000 → 1.45 s ± 253.462 ms [1.20 s … 1.70 s]98%CI@100
55
+ 4000 → 3.09 s ± 592.267 ms [2.50 s … 3.69 s]98%CI@100
56
+ 5000 → 5.52 s ± 1.22 s [4.30 s … 6.74 s]98%CI@100
57
+ 6000 → 8.33 s ± 2.02 s [6.31 s … 10.35 s]98%CI@100
58
+ 7000 → 15.76 s ± 3.55 s [12.21 s … 19.31 s]98%CI@100
59
+ 8000 → 25.66 s ± 6.66 s [18.99 s … 32.32 s]98%CI@100
60
+ 9000 → 35.02 s ± 8.68 s [26.34 s … 43.70 s]98%CI@100
61
+ 10000 → 1.01 min ± 13.64 s [47.13 s … 1.24 min]98%CI@100
62
+
63
+ Rule of thumb: double the bits requires ~10x execution time
64
+
44
65
  Args:
45
66
  p_bits (int): Number of guaranteed bits in `p` prime representation,
46
67
  p_bits ≥ q_bits + 11
47
68
  q_bits (int): Number of guaranteed bits in `q` prime representation, ≥ 11
69
+ serial (bool, optional): True (default) will force one thread; False will allow parallelism;
70
+ we have temporarily disabled parallelism with a default of True because it is not making
71
+ things faster...
48
72
 
49
73
  Returns:
50
74
  random primes tuple (p, q, m), with p-1 a random multiple m of q, such
@@ -59,14 +83,59 @@ def NBitRandomDSAPrimes(p_bits: int, q_bits: int, /) -> tuple[int, int, int]:
59
83
  if p_bits < q_bits + 11:
60
84
  raise base.InputError(f'invalid p_bits length: {p_bits=}')
61
85
  # make q
62
- q: int = modmath.NBitRandomPrime(q_bits)
86
+ q: int = modmath.NBitRandomPrimes(q_bits).pop()
87
+ # get number of CPUs and decide if we do parallel or not
88
+ n_workers: int = min(4, os.cpu_count() or 1)
89
+ pr: int | None = None
90
+ m: int | None = None
91
+ if serial or n_workers <= 1 or p_bits < 200:
92
+ # do one worker
93
+ while pr is None or m is None or pr.bit_length() != p_bits:
94
+ pr, m = _PrimePSearchShard(q, p_bits)
95
+ return (pr, q, m)
96
+ # parallel: keep a small pool of bounded shards; stop on first hit
97
+ multiprocessing.set_start_method('fork', force=True)
98
+ with concurrent.futures.ProcessPoolExecutor(max_workers=n_workers) as pool:
99
+ workers: set[concurrent.futures.Future[tuple[int | None, int | None]]] = {
100
+ pool.submit(_PrimePSearchShard, q, p_bits) for _ in range(n_workers)}
101
+ while workers:
102
+ done: set[concurrent.futures.Future[tuple[int | None, int | None]]] = concurrent.futures.wait(
103
+ workers, return_when=concurrent.futures.FIRST_COMPLETED)[0]
104
+ for worker in done:
105
+ workers.remove(worker)
106
+ pr, m = worker.result()
107
+ if pr is not None and m is not None and pr.bit_length() == p_bits:
108
+ return (pr, q, m)
109
+ # no hit in that shard: keep the pool full with a fresh shard
110
+ workers.add(pool.submit(_PrimePSearchShard, q, p_bits)) # pragma: no cover
111
+ # can never reach this point, but leave this here; remove line from coverage
112
+ raise base.Error(f'could not find prime with {p_bits=}/{q_bits=} bits') # pragma: no cover
113
+
114
+
115
+ def _PrimePSearchShard(q: int, p_bits: int) -> tuple[int | None, int | None]:
116
+ """Search for a `p_bits` random prime, starting from a random point, for ~6× expected prime gap.
117
+
118
+ Args:
119
+ q (int): Prime `q` for DSA
120
+ p_bits (int): Number of guaranteed bits in prime `p` representation
121
+
122
+ Returns:
123
+ tuple[int | None, int | None]: either the prime `p` and multiple `m` or None if no prime found
124
+ """
125
+ q_bits: int = q.bit_length()
126
+ shard_len: int = max(2000, 6 * int(0.693 * p_bits)) # ~6× expected prime gap ~2^k (≈ 0.693*k)
127
+ # find range of multiples to use
128
+ min_p: int = 2 ** (p_bits - 1)
129
+ max_p: int = 2 ** p_bits - 1
130
+ min_m: int = min_p // q + 2
131
+ max_m: int = max_p // q - 2
132
+ assert max_m - min_m > 1000 # make sure we'll have options!
63
133
  # make list of small primes to use for sieving
64
134
  approx_q_root: int = 1 << (q_bits // 2)
65
- forbidden: dict[int, int] = {} # (modulus: forbidden residue)
66
- for r in modmath.PrimeGenerator(3):
67
- forbidden[r] = (-modmath.ModInv(q % r, r)) % r
68
- if r > 100000 or r > approx_q_root:
69
- break
135
+ pr: int
136
+ forbidden: dict[int, int] = { # (modulus: forbidden residue)
137
+ pr: ((-modmath.ModInv(q % pr, pr)) % pr)
138
+ for pr in constants.FIRST_5K_PRIMES_SORTED[1:min(1000, approx_q_root)]} # skip pr==2
70
139
 
71
140
  def _PassesSieve(m: int) -> bool:
72
141
  for r, f in forbidden.items():
@@ -74,35 +143,23 @@ def NBitRandomDSAPrimes(p_bits: int, q_bits: int, /) -> tuple[int, int, int]:
74
143
  return False
75
144
  return True
76
145
 
77
- # find range of multiples to use
78
- min_p, max_p = 2 ** (p_bits - 1), 2 ** p_bits - 1
79
- min_m, max_m = min_p // q + 2, max_p // q - 2
80
- assert max_m - min_m > 1000 # make sure we'll have options!
81
- # start searching from a random multiple
82
- failures: int = 0
83
- window: int = max(_PRIME_MULTIPLE_SEARCH, 2 * p_bits)
84
- while True:
85
- # try searching starting here
86
- m: int = base.RandInt(min_m, max_m)
87
- if m % 2:
88
- m += 1 # make even
89
- for _ in range(window):
90
- p: int = q * m + 1
91
- if p >= max_p:
92
- break
93
- # first do a quick sieve test
94
- if not _PassesSieve(m):
95
- m += 2
96
- continue
97
- # passed sieve, do full test
98
- if modmath.IsPrime(p):
99
- return (p, q, m) # found a suitable prime set!
100
- m += 2 # next multiple
101
- # after _PRIME_MULTIPLE_SEARCH we declare this range failed
102
- failures += 1
103
- if failures >= _MAX_KEY_GENERATION_FAILURES:
104
- raise base.CryptoError(f'failed primes generation {failures} times')
105
- logging.warning(f'failed primes search: {failures}')
146
+ # try searching starting here
147
+ m: int = base.RandInt(min_m, max_m)
148
+ if m % 2:
149
+ m += 1 # make even
150
+ count: int = 0
151
+ pr = 0
152
+ while count < shard_len:
153
+ pr = q * m + 1
154
+ if pr > max_p:
155
+ break
156
+ # first do a quick sieve test
157
+ if _PassesSieve(m):
158
+ if modmath.IsPrime(pr): # passed sieve, do full test
159
+ return (pr, m) # found a suitable prime set!
160
+ count += 1
161
+ m += 2 # next even number
162
+ return (None, None)
106
163
 
107
164
 
108
165
  @dataclasses.dataclass(kw_only=True, slots=True, frozen=True, repr=False)
@@ -203,7 +260,7 @@ class DSASharedPublicKey(base.CryptoKey):
203
260
  g: int = 0
204
261
  while g < 2:
205
262
  h: int = base.RandBits(p_bits - 1)
206
- g = modmath.ModExp(h, m, p)
263
+ g = int(gmpy2.powmod(h, m, p)) # type:ignore # pylint:disable=no-member
207
264
  return cls(prime_modulus=p, prime_seed=q, group_base=g)
208
265
 
209
266
 
@@ -278,10 +335,10 @@ class DSAPublicKey(DSASharedPublicKey, base.Verifier):
278
335
  raise base.InputError(f'invalid signature: {signature=}')
279
336
  # verify
280
337
  inv: int = modmath.ModInv(signature[1], self.prime_seed)
281
- a: int = modmath.ModExp(
282
- self.group_base, (message * inv) % self.prime_seed, self.prime_modulus)
283
- b: int = modmath.ModExp(
284
- self.individual_base, (signature[0] * inv) % self.prime_seed, self.prime_modulus)
338
+ a: int = int(gmpy2.powmod( # type:ignore # pylint:disable=no-member
339
+ self.group_base, (message * inv) % self.prime_seed, self.prime_modulus))
340
+ b: int = int(gmpy2.powmod( # type:ignore # pylint:disable=no-member
341
+ self.individual_base, (signature[0] * inv) % self.prime_seed, self.prime_modulus))
285
342
  return ((a * b) % self.prime_modulus) % self.prime_seed == signature[0]
286
343
 
287
344
  def Verify(
@@ -353,8 +410,7 @@ class DSAPrivateKey(DSAPublicKey, base.Signer): # pylint: disable=too-many-ance
353
410
  if (not 2 < self.decrypt_exp < self.prime_seed or
354
411
  self.decrypt_exp in (self.group_base, self.individual_base)):
355
412
  raise base.InputError(f'invalid decrypt_exp: {self}')
356
- if modmath.ModExp(
357
- self.group_base, self.decrypt_exp, self.prime_modulus) != self.individual_base:
413
+ if gmpy2.powmod(self.group_base, self.decrypt_exp, self.prime_modulus) != self.individual_base: # type:ignore # pylint:disable=no-member
358
414
  raise base.CryptoError(f'inconsistent g**d % p == i: {self}')
359
415
 
360
416
  def __str__(self) -> str:
@@ -387,10 +443,11 @@ class DSAPrivateKey(DSAPublicKey, base.Signer): # pylint: disable=too-many-ance
387
443
  if not 0 < message < self.prime_seed:
388
444
  raise base.InputError(f'invalid message: {message=}')
389
445
  # sign
390
- a, b = 0, 0
446
+ a: int = 0
447
+ b: int = 0
391
448
  while a < 2 or b < 2:
392
449
  ephemeral_key, ephemeral_inv = self._MakeEphemeralKey()
393
- a = modmath.ModExp(self.group_base, ephemeral_key, self.prime_modulus) % self.prime_seed
450
+ a = int(gmpy2.powmod(self.group_base, ephemeral_key, self.prime_modulus) % self.prime_seed) # type:ignore # pylint:disable=no-member
394
451
  b = (ephemeral_inv * ((message + a * self.decrypt_exp) % self.prime_seed)) % self.prime_seed
395
452
  return (a, b)
396
453
 
@@ -460,8 +517,8 @@ class DSAPrivateKey(DSAPublicKey, base.Signer): # pylint: disable=too-many-ance
460
517
  prime_modulus=shared_key.prime_modulus,
461
518
  prime_seed=shared_key.prime_seed,
462
519
  group_base=shared_key.group_base,
463
- individual_base=modmath.ModExp(
464
- shared_key.group_base, decrypt_exp, shared_key.prime_modulus),
520
+ individual_base=int(gmpy2.powmod( # type:ignore # pylint:disable=no-member
521
+ shared_key.group_base, decrypt_exp, shared_key.prime_modulus)),
465
522
  decrypt_exp=decrypt_exp)
466
523
  except base.InputError as err:
467
524
  failures += 1
transcrypto/elgamal.py CHANGED
@@ -23,6 +23,8 @@ import logging
23
23
  # import pdb
24
24
  from typing import Self
25
25
 
26
+ import gmpy2 # type:ignore
27
+
26
28
  from . import base, modmath, aes
27
29
 
28
30
  __author__ = 'balparda@github.com'
@@ -122,7 +124,7 @@ class ElGamalSharedPublicKey(base.CryptoKey):
122
124
  if bit_length < 11:
123
125
  raise base.InputError(f'invalid bit length: {bit_length=}')
124
126
  # generate random prime and number, create object (should never fail)
125
- p: int = modmath.NBitRandomPrime(bit_length)
127
+ p: int = modmath.NBitRandomPrimes(bit_length).pop()
126
128
  g: int = 0
127
129
  while not 2 < g < p - 1:
128
130
  g = base.RandBits(bit_length)
@@ -203,8 +205,8 @@ class ElGamalPublicKey(ElGamalSharedPublicKey, base.Encryptor, base.Verifier):
203
205
  b: int = 0
204
206
  while a < 2 or b < 2:
205
207
  ephemeral_key: int = self._MakeEphemeralKey()[0]
206
- a = modmath.ModExp(self.group_base, ephemeral_key, self.prime_modulus)
207
- s: int = modmath.ModExp(self.individual_base, ephemeral_key, self.prime_modulus)
208
+ a = int(gmpy2.powmod(self.group_base, ephemeral_key, self.prime_modulus)) # type:ignore # pylint:disable=no-member
209
+ s: int = int(gmpy2.powmod(self.individual_base, ephemeral_key, self.prime_modulus)) # type:ignore # pylint:disable=no-member
208
210
  b = (message * s) % self.prime_modulus
209
211
  return (a, b)
210
212
 
@@ -278,9 +280,9 @@ class ElGamalPublicKey(ElGamalSharedPublicKey, base.Encryptor, base.Verifier):
278
280
  not 2 <= signature[1] < self.prime_modulus - 1):
279
281
  raise base.InputError(f'invalid signature: {signature=}')
280
282
  # verify
281
- a: int = modmath.ModExp(self.group_base, message, self.prime_modulus)
282
- b: int = modmath.ModExp(signature[0], signature[1], self.prime_modulus)
283
- c: int = modmath.ModExp(self.individual_base, signature[0], self.prime_modulus)
283
+ a: int = int(gmpy2.powmod(self.group_base, message, self.prime_modulus)) # type:ignore # pylint:disable=no-member
284
+ b: int = int(gmpy2.powmod(signature[0], signature[1], self.prime_modulus)) # type:ignore # pylint:disable=no-member
285
+ c: int = int(gmpy2.powmod(self.individual_base, signature[0], self.prime_modulus)) # type:ignore # pylint:disable=no-member
284
286
  return a == (b * c) % self.prime_modulus
285
287
 
286
288
  def Verify(
@@ -351,8 +353,7 @@ class ElGamalPrivateKey(ElGamalPublicKey, base.Decryptor, base.Signer): # pylin
351
353
  if (not 2 < self.decrypt_exp < self.prime_modulus - 1 or
352
354
  self.decrypt_exp in (self.group_base, self.individual_base)):
353
355
  raise base.InputError(f'invalid decrypt_exp: {self}')
354
- if modmath.ModExp(
355
- self.group_base, self.decrypt_exp, self.prime_modulus) != self.individual_base:
356
+ if gmpy2.powmod(self.group_base, self.decrypt_exp, self.prime_modulus) != self.individual_base: # type:ignore # pylint:disable=no-member
356
357
  raise base.CryptoError(f'inconsistent g**e % p == i: {self}')
357
358
 
358
359
  def __str__(self) -> str:
@@ -385,8 +386,8 @@ class ElGamalPrivateKey(ElGamalPublicKey, base.Decryptor, base.Signer): # pylin
385
386
  not 2 <= ciphertext[1] < self.prime_modulus):
386
387
  raise base.InputError(f'invalid message: {ciphertext=}')
387
388
  # decrypt
388
- csi: int = modmath.ModExp(
389
- ciphertext[0], self.prime_modulus - 1 - self.decrypt_exp, self.prime_modulus)
389
+ csi: int = int(
390
+ gmpy2.powmod(ciphertext[0], self.prime_modulus - 1 - self.decrypt_exp, self.prime_modulus)) # type:ignore # pylint:disable=no-member
390
391
  return (ciphertext[1] * csi) % self.prime_modulus
391
392
 
392
393
  def Decrypt(self, ciphertext: bytes, /, *, associated_data: bytes | None = None) -> bytes:
@@ -450,7 +451,7 @@ class ElGamalPrivateKey(ElGamalPublicKey, base.Decryptor, base.Signer): # pylin
450
451
  p_1: int = self.prime_modulus - 1
451
452
  while a < 2 or b < 2:
452
453
  ephemeral_key, ephemeral_inv = self._MakeEphemeralKey()
453
- a = modmath.ModExp(self.group_base, ephemeral_key, self.prime_modulus)
454
+ a = int(gmpy2.powmod(self.group_base, ephemeral_key, self.prime_modulus)) # type:ignore # pylint:disable=no-member
454
455
  b = (ephemeral_inv * ((message - a * self.decrypt_exp) % p_1)) % p_1
455
456
  return (a, b)
456
457
 
@@ -519,8 +520,8 @@ class ElGamalPrivateKey(ElGamalPublicKey, base.Decryptor, base.Signer): # pylin
519
520
  return cls(
520
521
  prime_modulus=shared_key.prime_modulus,
521
522
  group_base=shared_key.group_base,
522
- individual_base=modmath.ModExp(
523
- shared_key.group_base, decrypt_exp, shared_key.prime_modulus),
523
+ individual_base=int(gmpy2.powmod( # type:ignore # pylint:disable=no-member
524
+ shared_key.group_base, decrypt_exp, shared_key.prime_modulus)),
524
525
  decrypt_exp=decrypt_exp)
525
526
  except base.InputError as err:
526
527
  failures += 1
transcrypto/modmath.py CHANGED
@@ -6,39 +6,22 @@
6
6
 
7
7
  from __future__ import annotations
8
8
 
9
+ import concurrent.futures
9
10
  import math
11
+ import multiprocessing
12
+ import os
10
13
  # import pdb
11
14
  from typing import Generator, Reversible
12
15
 
13
- from . import base
16
+ import gmpy2 # type:ignore
17
+
18
+ from . import base, constants
14
19
 
15
20
  __author__ = 'balparda@github.com'
16
21
  __version__: str = base.__version__ # version comes from base!
17
22
  __version_tuple__: tuple[int, ...] = base.__version_tuple__
18
23
 
19
24
 
20
- _FIRST_60_PRIMES: set[int] = {
21
- 2, 3, 5, 7, 11, 13, 17, 19, 23, 29,
22
- 31, 37, 41, 43, 47, 53, 59, 61, 67, 71,
23
- 73, 79, 83, 89, 97, 101, 103, 107, 109, 113,
24
- 127, 131, 137, 139, 149, 151, 157, 163, 167, 173,
25
- 179, 181, 191, 193, 197, 199, 211, 223, 227, 229,
26
- 233, 239, 241, 251, 257, 263, 269, 271, 277, 281,
27
- }
28
- _FIRST_60_PRIMES_SORTED: list[int] = sorted(_FIRST_60_PRIMES)
29
- _COMPOSITE_60: int = math.prod(_FIRST_60_PRIMES_SORTED)
30
- _PRIME_60: int = _FIRST_60_PRIMES_SORTED[-1]
31
- assert len(_FIRST_60_PRIMES) == 60 and _PRIME_60 == 281, f'should never happen: {_PRIME_60=}'
32
- _FIRST_49_MERSENNE: set[int] = { # <https://oeis.org/A000043>
33
- 2, 3, 5, 7, 13, 17, 19, 31, 61, 89,
34
- 107, 127, 521, 607, 1279, 2203, 2281, 3217, 4253, 4423,
35
- 9689, 9941, 11213, 19937, 21701, 23209, 44497, 86243, 110503, 132049,
36
- 216091, 756839, 859433, 1257787, 1398269, 2976221, 3021377, 6972593, 13466917, 20996011,
37
- 24036583, 25964951, 30402457, 32582657, 37156667, 42643801, 43112609, 57885161, 74207281,
38
- }
39
- _FIRST_49_MERSENNE_SORTED: list[int] = sorted(_FIRST_49_MERSENNE)
40
- assert len(_FIRST_49_MERSENNE) == 49 and _FIRST_49_MERSENNE_SORTED[-1] == 74207281, f'should never happen: {_FIRST_49_MERSENNE_SORTED[-1]}'
41
-
42
25
  _MAX_PRIMALITY_SAFETY = 100 # this is an absurd number, just to have a max
43
26
 
44
27
 
@@ -176,6 +159,7 @@ def ModExp(x: int, y: int, m: int, /) -> int:
176
159
  return x
177
160
  # now both x > 1 and y > 1
178
161
  z: int = 1
162
+ odd: int
179
163
  while y:
180
164
  y, odd = divmod(y, 2)
181
165
  if odd:
@@ -316,7 +300,7 @@ def FermatIsPrime(n: int, /, *, safety: int = 10, witnesses: set[int] | None = N
316
300
  for w in sorted(witnesses):
317
301
  if not 2 <= w <= (n - 2):
318
302
  raise base.InputError(f'out of bounds witness: 2 ≤ {w=} ≤ {n - 2}')
319
- if ModExp(w, n - 1, n) != 1:
303
+ if gmpy2.powmod(w, n - 1, n) != 1: # type:ignore # pylint:disable=no-member
320
304
  # number is proved to be composite
321
305
  return False
322
306
  # we declare the number PROBABLY a prime to the limits of this test
@@ -353,19 +337,19 @@ def _MillerRabinWitnesses(n: int, /) -> set[int]: # pylint: disable=too-many-re
353
337
  if n < 4759123141:
354
338
  return {2, 7, 61} # "safety" 3, but 100% coverage
355
339
  if n < 2152302898747:
356
- return set(_FIRST_60_PRIMES_SORTED[:5]) # "safety" 5, but 100% coverage
340
+ return set(constants.FIRST_5K_PRIMES_SORTED[:5]) # "safety" 5, but 100% coverage
357
341
  if n < 341550071728321:
358
- return set(_FIRST_60_PRIMES_SORTED[:7]) # "safety" 7, but 100% coverage
359
- if n < 18446744073709551616: # 2 ** 64
360
- return set(_FIRST_60_PRIMES_SORTED[:12]) # "safety" 12, but 100% coverage
361
- if n < 3317044064679887385961981: # > 2 ** 81
362
- return set(_FIRST_60_PRIMES_SORTED[:13]) # "safety" 13, but 100% coverage
342
+ return set(constants.FIRST_5K_PRIMES_SORTED[:7]) # "safety" 7, but 100% coverage
343
+ if n < 18446744073709551616: # 2 ** 64
344
+ return set(constants.FIRST_5K_PRIMES_SORTED[:12]) # "safety" 12, but 100% coverage
345
+ if n < 3317044064679887385961981: # > 2 ** 81
346
+ return set(constants.FIRST_5K_PRIMES_SORTED[:13]) # "safety" 13, but 100% coverage
363
347
  # here n should be greater than 2 ** 81, so safety should be 34 or less
364
348
  n_bits: int = n.bit_length()
365
349
  assert n_bits >= 82, f'should never happen: {n=} -> {n_bits=}'
366
350
  safety: int = int(math.ceil(0.375 + 1.59 / (0.000590 * n_bits))) if n_bits <= 1700 else 2
367
351
  assert 1 < safety <= 34, f'should never happen: {n=} -> {n_bits=} ; {safety=}'
368
- return set(_FIRST_60_PRIMES_SORTED[:safety])
352
+ return set(constants.FIRST_5K_PRIMES_SORTED[:safety])
369
353
 
370
354
 
371
355
  def _MillerRabinSR(n: int, /) -> tuple[int, int]:
@@ -427,7 +411,7 @@ def MillerRabinIsPrime(n: int, /, *, witnesses: set[int] | None = None) -> bool:
427
411
  for w in sorted(witnesses if witnesses else _MillerRabinWitnesses(n)):
428
412
  if not 2 <= w <= (n - 2):
429
413
  raise base.InputError(f'out of bounds witness: 2 ≤ {w=} ≤ {n - 2}')
430
- x: int = ModExp(w, r, n)
414
+ x: int = int(gmpy2.powmod(w, r, n)) # type:ignore # pylint:disable=no-member
431
415
  if x not in n_limits:
432
416
  for _ in range(s): # s >= 1 so will execute at least once
433
417
  y = (x * x) % n
@@ -452,10 +436,13 @@ def IsPrime(n: int, /) -> bool:
452
436
  Raises:
453
437
  InputError: invalid inputs
454
438
  """
455
- # is number divisible by (one of the) first 60 primes? test should eliminate 80%+ of candidates
456
- if n > _PRIME_60 and base.GCD(n, _COMPOSITE_60) != 1:
457
- return False
458
- # do the (more expensive) Miller-Rabin primality test
439
+ # is number divisible by (one of the) first 20000 primes? test should eliminate 90%+ of candidates
440
+ if n in constants.FIRST_20K_PRIMES:
441
+ return True
442
+ for r in constants.FIRST_20K_PRIMES_SORTED:
443
+ if not n % r:
444
+ return False # we already checked: it is not one of the 20k first primes, so not prime
445
+ # do the (much much more expensive) Miller-Rabin primality test
459
446
  return MillerRabinIsPrime(n)
460
447
 
461
448
 
@@ -486,18 +473,38 @@ def PrimeGenerator(start: int, /) -> Generator[int, None, None]:
486
473
  yield n # found a prime
487
474
 
488
475
 
489
- def NBitRandomPrime(n_bits: int, /) -> int:
476
+ def NBitRandomPrimes(n_bits: int, /, *, serial: bool = True, n_primes: int = 1) -> set[int]:
490
477
  """Generates a random prime with (guaranteed) `n_bits` size (i.e., first bit == 1).
491
478
 
492
479
  The fact that the first bit will be 1 means the entropy is ~ (n_bits-1) and
493
480
  because of this we only allow for a byte or more prime bits generated. This drawback
494
481
  is negligible for the large primes a crypto library will work with, in practice.
495
482
 
483
+ Method will decide if executes on one thread or many.
484
+
485
+ $ poetry run profiler -s -n 100 -b 1000,11000,1000 -c 98 primes # single-thread, Mac M2 Max, 2025
486
+ 1000 → 84.233 ms ± 18.853 ms [65.380 ms … 103.085 ms]98%CI@100
487
+ 2000 → 406.900 ms ± 91.575 ms [315.325 ms … 498.475 ms]98%CI@100
488
+ 3000 → 1.20 s ± 291.105 ms [907.331 ms … 1.49 s]98%CI@100
489
+ 4000 → 2.42 s ± 490.241 ms [1.93 s … 2.91 s]98%CI@100
490
+ 5000 → 4.78 s ± 1.02 s [3.76 s … 5.80 s]98%CI@100
491
+ 6000 → 7.63 s ± 1.57 s [6.06 s … 9.20 s]98%CI@100
492
+ 7000 → 13.66 s ± 3.00 s [10.66 s … 16.66 s]98%CI@100
493
+ 8000 → 20.71 s ± 5.05 s [15.67 s … 25.76 s]98%CI@100
494
+ 9000 → 33.12 s ± 7.61 s [25.51 s … 40.73 s]98%CI@100
495
+ 10000 → 52.91 s ± 11.73 s [41.18 s … 1.08 min]98%CI@100
496
+
497
+ Rule of thumb: double the bits requires ~10x execution time
498
+
496
499
  Args:
497
500
  n_bits (int): Number of guaranteed bits in prime representation, n ≥ 8
501
+ serial (bool, optional): True (default) will force one thread; False will allow parallelism;
502
+ we have temporarily disabled parallelism with a default of True because it is not making
503
+ things faster...
504
+ n_primes (int, optional): Number of required primes in the return set[int], default is 1
498
505
 
499
506
  Returns:
500
- random prime with `n_bits` bits
507
+ set[int]: `n_primes` random primes with `n_bits` bits
501
508
 
502
509
  Raises:
503
510
  InputError: invalid inputs
@@ -505,11 +512,69 @@ def NBitRandomPrime(n_bits: int, /) -> int:
505
512
  # test inputs
506
513
  if n_bits < 8:
507
514
  raise base.InputError(f'invalid n: {n_bits=}')
508
- # get a random number with guaranteed bit size
509
- prime: int = 0
510
- while prime.bit_length() != n_bits:
511
- prime = next(PrimeGenerator(base.RandBits(n_bits)))
512
- return prime
515
+ n_primes = 1 if n_primes < 1 else n_primes
516
+ # get number of CPUs and decide if we do parallel or not
517
+ n_workers: int = min(4, os.cpu_count() or 1)
518
+ pr_set: set[int] = set()
519
+ pr: int | None = None
520
+ if serial or n_workers <= 1 or n_bits < 200:
521
+ # do one worker
522
+ while len(pr_set) < n_primes:
523
+ while pr is None or pr.bit_length() != n_bits:
524
+ pr = _PrimeSearchShard(n_bits)
525
+ pr_set.add(pr)
526
+ pr = None
527
+ return pr_set
528
+ # parallel: keep a small pool of bounded shards; stop on first hit
529
+ multiprocessing.set_start_method('fork', force=True)
530
+ with concurrent.futures.ProcessPoolExecutor(max_workers=n_workers) as pool:
531
+ workers: set[concurrent.futures.Future[int | None]] = {
532
+ pool.submit(_PrimeSearchShard, n_bits) for _ in range(n_workers)}
533
+ while workers:
534
+ done: set[concurrent.futures.Future[int | None]] = concurrent.futures.wait(
535
+ workers, return_when=concurrent.futures.FIRST_COMPLETED)[0]
536
+ for worker in done:
537
+ workers.remove(worker)
538
+ pr = worker.result()
539
+ if pr is not None and pr.bit_length() == n_bits:
540
+ pr_set.add(pr)
541
+ pr = None
542
+ if len(pr_set) >= n_primes:
543
+ return pr_set
544
+ # no hit in that shard: keep the pool full with a fresh shard
545
+ workers.add(pool.submit(_PrimeSearchShard, n_bits))
546
+ # can never reach this point, but leave this here; remove line from coverage
547
+ raise base.Error(f'could not find prime with {n_bits=} bits') # pragma: no cover
548
+
549
+
550
+ def _PrimeSearchShard(n_bits: int) -> int | None:
551
+ """Search for a `n_bits` random prime, starting from a random point, for ~6× expected prime gap.
552
+
553
+ Args:
554
+ n_bits (int): Number of guaranteed bits in prime representation
555
+
556
+ Returns:
557
+ int | None: either the prime int or None if no prime found in this shard
558
+ """
559
+ shard_len: int = max(2000, 6 * int(0.693 * n_bits)) # ~6× expected prime gap ~2^k (≈ 0.693*k)
560
+ pr: int = base.RandBits(n_bits) | 1 # random position; make ODD
561
+ count: int = 0
562
+ while count < shard_len and pr.bit_length() == n_bits:
563
+ if IsPrime(pr):
564
+ return pr
565
+ count += 1
566
+ pr += 2
567
+ return None
568
+
569
+
570
+ def FirstNPrimesSorted(n: int) -> list[int]:
571
+ """Returns list of `n` first primes in a sorted list."""
572
+ primes: list[int] = []
573
+ for i, pr in enumerate(PrimeGenerator(0)):
574
+ if i >= n:
575
+ break
576
+ primes.append(pr)
577
+ return primes
513
578
 
514
579
 
515
580
  def MersennePrimesGenerator(start: int, /) -> Generator[tuple[int, int, int], None, None]: