transcrypto 1.1.2__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
transcrypto/rsa.py CHANGED
@@ -14,8 +14,9 @@ import logging
14
14
  # import pdb
15
15
  from typing import Self
16
16
 
17
- from . import base
18
- from . import modmath
17
+ import gmpy2 # type:ignore
18
+
19
+ from . import base, modmath, aes
19
20
 
20
21
  __author__ = 'balparda@github.com'
21
22
  __version__: str = base.__version__ # version comes from base!
@@ -27,13 +28,15 @@ _BIG_ENCRYPTION_EXPONENT = 2 ** 16 + 1 # 65537
27
28
 
28
29
  _MAX_KEY_GENERATION_FAILURES = 15
29
30
 
31
+ # fixed prefixes: do NOT ever change! will break all encryption and signature schemes
32
+ _RSA_ENCRYPTION_AAD_PREFIX = b'transcrypto.RSA.Encryption.1.0\x00'
33
+ _RSA_SIGNATURE_HASH_PREFIX = b'transcrypto.RSA.Signature.1.0\x00'
34
+
30
35
 
31
36
  @dataclasses.dataclass(kw_only=True, slots=True, frozen=True, repr=False)
32
- class RSAPublicKey(base.CryptoKey):
37
+ class RSAPublicKey(base.CryptoKey, base.Encryptor, base.Verifier):
33
38
  """RSA (Rivest-Shamir-Adleman) key, with the public part of the key.
34
39
 
35
- BEWARE: This is raw RSA, no OAEP or PSS padding or validation!
36
- These are pedagogical/raw primitives; do not use for new protocols.
37
40
  No measures are taken here to prevent timing attacks.
38
41
 
39
42
  By default and deliberate choice the encryption exponent will be either 7 or 65537,
@@ -70,12 +73,20 @@ class RSAPublicKey(base.CryptoKey):
70
73
  string representation of RSAPublicKey
71
74
  """
72
75
  return ('RSAPublicKey('
76
+ f'bits={self.public_modulus.bit_length()}, '
73
77
  f'public_modulus={base.IntToEncoded(self.public_modulus)}, '
74
78
  f'encrypt_exp={base.IntToEncoded(self.encrypt_exp)})')
75
79
 
76
- def Encrypt(self, message: int, /) -> int:
80
+ @property
81
+ def modulus_size(self) -> int:
82
+ """Modulus size in bytes. The number of bytes used in Encrypt/Decrypt/Sign/Verify."""
83
+ return (self.public_modulus.bit_length() + 7) // 8
84
+
85
+ def RawEncrypt(self, message: int, /) -> int:
77
86
  """Encrypt `message` with this public key.
78
87
 
88
+ BEWARE: This is raw RSA, no OAEP or PSS padding or validation!
89
+ These are pedagogical/raw primitives; do not use for new protocols.
79
90
  We explicitly disallow `message` to be zero.
80
91
 
81
92
  Args:
@@ -91,11 +102,54 @@ class RSAPublicKey(base.CryptoKey):
91
102
  if not 0 < message < self.public_modulus:
92
103
  raise base.InputError(f'invalid message: {message=}')
93
104
  # encrypt
94
- return modmath.ModExp(message, self.encrypt_exp, self.public_modulus)
105
+ return int(gmpy2.powmod(message, self.encrypt_exp, self.public_modulus)) # type:ignore # pylint:disable=no-member
106
+
107
+ def Encrypt(self, plaintext: bytes, /, *, associated_data: bytes | None = None) -> bytes:
108
+ """Encrypt `plaintext` and return `ciphertext`.
109
+
110
+ • Let k = ceil(log2(n))/8 be the modulus size in bytes.
111
+ • Pick random r ∈ [2, n-1]
112
+ • ct = r^e mod n
113
+ • return Padded(ct, k) + AES-256-GCM(key=SHA512(r)[32:], plaintext,
114
+ associated_data="prefix" + len(aad) + aad + Padded(ct, k))
115
+
116
+ We pick fresh random r, send ct = r^e mod n, and derive the DEM key from r,
117
+ then use AES-GCM for the payload. This is the classic RSA-KEM construction.
118
+ With AEAD as the DEM, we get strong confidentiality and ciphertext integrity
119
+ (CCA resistance in the ROM under standard assumptions). There are no
120
+ Bleichenbacher-style issue because we do not expose any padding semantics.
95
121
 
96
- def VerifySignature(self, message: int, signature: int, /) -> bool:
122
+ Args:
123
+ plaintext (bytes): Data to encrypt.
124
+ associated_data (bytes, optional): Optional AAD; must be provided again on decrypt
125
+
126
+ Returns:
127
+ bytes: Ciphertext; see above:
128
+ Padded(ct, k) + AES-256-GCM(key=SHA512(r)[32:], plaintext,
129
+ associated_data="prefix" + len(aad) + aad + Padded(ct, k))
130
+
131
+ Raises:
132
+ InputError: invalid inputs
133
+ CryptoError: internal crypto failures
134
+ """
135
+ # generate random r and encrypt it
136
+ r: int = 0
137
+ while not 1 < r < self.public_modulus or base.GCD(r, self.public_modulus) != 1:
138
+ r = base.RandBits(self.public_modulus.bit_length())
139
+ k: int = self.modulus_size
140
+ ct: bytes = base.IntToFixedBytes(self.RawEncrypt(r), k)
141
+ assert len(ct) == k, 'should never happen: c_kem should be exactly k bytes'
142
+ # encrypt plaintext with AES-256-GCM using SHA512(r)[32:] as key; return ct || Encrypt(...)
143
+ ss: bytes = base.Hash512(base.IntToFixedBytes(r, k))
144
+ aad: bytes = b'' if associated_data is None else associated_data
145
+ aad_prime: bytes = _RSA_ENCRYPTION_AAD_PREFIX + base.IntToFixedBytes(len(aad), 8) + aad + ct
146
+ return ct + aes.AESKey(key256=ss[32:]).Encrypt(plaintext, associated_data=aad_prime)
147
+
148
+ def RawVerify(self, message: int, signature: int, /) -> bool:
97
149
  """Verify a signature. True if OK; False if failed verification.
98
150
 
151
+ BEWARE: This is raw RSA, no OAEP or PSS padding or validation!
152
+ These are pedagogical/raw primitives; do not use for new protocols.
99
153
  We explicitly disallow `message` to be zero.
100
154
 
101
155
  Args:
@@ -109,7 +163,68 @@ class RSAPublicKey(base.CryptoKey):
109
163
  Raises:
110
164
  InputError: invalid inputs
111
165
  """
112
- return self.Encrypt(signature) == message
166
+ return self.RawEncrypt(signature) == message
167
+
168
+ def _DomainSeparatedHash(
169
+ self, message: bytes, associated_data: bytes | None, salt: bytes, /) -> int:
170
+ """Compute the domain-separated hash for signing and verifying.
171
+
172
+ Args:
173
+ message (bytes): message to sign/verify
174
+ associated_data (bytes | None): optional associated data
175
+ salt (bytes): salt to use in the hash
176
+
177
+ Returns:
178
+ int: integer representation of the hash output;
179
+ Hash512("prefix" || len(aad) || aad || message || salt)
180
+
181
+ Raises:
182
+ CryptoError: hash output is out of range
183
+ """
184
+ aad: bytes = b'' if associated_data is None else associated_data
185
+ la: bytes = base.IntToFixedBytes(len(aad), 8)
186
+ assert len(salt) == 64, 'should never happen: salt should be exactly 64 bytes'
187
+ y: int = base.BytesToInt(base.Hash512(_RSA_SIGNATURE_HASH_PREFIX + la + aad + message + salt))
188
+ if not 1 < y < self.public_modulus or base.GCD(y, self.public_modulus) != 1:
189
+ # will only reasonably happen if modulus is small
190
+ raise base.CryptoError(f'hash output {y} is out of range/invalid {self.public_modulus}')
191
+ return y
192
+
193
+ def Verify(
194
+ self, message: bytes, signature: bytes, /, *, associated_data: bytes | None = None) -> bool:
195
+ """Verify a `signature` for `message`. True if OK; False if failed verification.
196
+
197
+ • Let k = ceil(log2(n))/8 be the modulus size in bytes.
198
+ • Split signature in two parts: the first 64 bytes is salt, the rest is s
199
+ • y_check = s^e mod n
200
+ • return y_check == Hash512("prefix" || len(aad) || aad || message || salt)
201
+ • return False for any malformed signature
202
+
203
+ Args:
204
+ message (bytes): Data that was signed
205
+ signature (bytes): Signature data to verify
206
+ associated_data (bytes, optional): Optional AAD (must match what was used during signing)
207
+
208
+ Returns:
209
+ True if signature is valid, False otherwise
210
+
211
+ Raises:
212
+ InputError: invalid inputs
213
+ CryptoError: internal crypto failures, authentication failure, key mismatch, etc
214
+ """
215
+ k: int = self.modulus_size
216
+ if k <= 64:
217
+ raise base.InputError(f'modulus too small for signing operations: {k} bytes')
218
+ if len(signature) != (64 + k):
219
+ logging.info(f'invalid signature length: {len(signature)} ; expected {64 + k}')
220
+ return False
221
+ try:
222
+ return self.RawVerify(
223
+ self._DomainSeparatedHash(message, associated_data, signature[:64]),
224
+ base.BytesToInt(signature[64:]))
225
+ except base.InputError as err:
226
+ logging.info(err)
227
+ return False
113
228
 
114
229
  @classmethod
115
230
  def Copy(cls, other: RSAPublicKey, /) -> Self:
@@ -154,7 +269,8 @@ class RSAObfuscationPair(RSAPublicKey):
154
269
  Returns:
155
270
  string representation of RSAObfuscationPair without leaking secrets
156
271
  """
157
- return (f'RSAObfuscationPair({super(RSAObfuscationPair, self).__str__()}, ' # pylint: disable=super-with-arguments
272
+ return ('RSAObfuscationPair('
273
+ f'{super(RSAObfuscationPair, self).__str__()}, ' # pylint: disable=super-with-arguments
158
274
  f'random_key={base.ObfuscateSecret(self.random_key)}, '
159
275
  f'key_inverse={base.ObfuscateSecret(self.key_inverse)})')
160
276
 
@@ -176,8 +292,8 @@ class RSAObfuscationPair(RSAPublicKey):
176
292
  if not 0 < message < self.public_modulus:
177
293
  raise base.InputError(f'invalid message: {message=}')
178
294
  # encrypt
179
- return (message * modmath.ModExp(
180
- self.random_key, self.encrypt_exp, self.public_modulus)) % self.public_modulus
295
+ return (message * int(gmpy2.powmod( # type:ignore # pylint:disable=no-member
296
+ self.random_key, self.encrypt_exp, self.public_modulus))) % self.public_modulus
181
297
 
182
298
  def RevealOriginalSignature(self, message: int, signature: int, /) -> int:
183
299
  """Recover original signature for `message` from obfuscated `signature`.
@@ -198,11 +314,11 @@ class RSAObfuscationPair(RSAPublicKey):
198
314
  """
199
315
  # verify that obfuscated signature is valid
200
316
  obfuscated: int = self.ObfuscateMessage(message)
201
- if not self.VerifySignature(obfuscated, signature):
317
+ if not self.RawVerify(obfuscated, signature):
202
318
  raise base.CryptoError(f'obfuscated message was not signed: {message=} ; {signature=}')
203
319
  # compute signature for original message and check it
204
320
  original: int = (signature * self.key_inverse) % self.public_modulus
205
- if not self.VerifySignature(message, original):
321
+ if not self.RawVerify(message, original):
206
322
  raise base.CryptoError(f'failed signature recovery: {message=} ; {signature=}')
207
323
  return original
208
324
 
@@ -243,11 +359,9 @@ class RSAObfuscationPair(RSAPublicKey):
243
359
 
244
360
 
245
361
  @dataclasses.dataclass(kw_only=True, slots=True, frozen=True, repr=False)
246
- class RSAPrivateKey(RSAPublicKey):
362
+ class RSAPrivateKey(RSAPublicKey, base.Decryptor, base.Signer): # pylint: disable=too-many-ancestors
247
363
  """RSA (Rivest-Shamir-Adleman) private key.
248
364
 
249
- BEWARE: This is raw RSA, no OAEP or PSS padding or validation!
250
- These are pedagogical/raw primitives; do not use for new protocols.
251
365
  No measures are taken here to prevent timing attacks.
252
366
 
253
367
  The attributes modulus_p (p), modulus_q (q) and decrypt_exp (d) are "enough" for a working key,
@@ -280,7 +394,7 @@ class RSAPrivateKey(RSAPublicKey):
280
394
  """
281
395
  super(RSAPrivateKey, self).__post_init__() # pylint: disable=super-with-arguments # needed here b/c: dataclass
282
396
  phi: int = (self.modulus_p - 1) * (self.modulus_q - 1)
283
- min_prime_distance: int = 2 ** (self.public_modulus.bit_length() // 3 + 1)
397
+ min_prime_distance: int = 1 << (self.public_modulus.bit_length() // 4) # n**(1/4)
284
398
  if (self.modulus_p < 2 or not modmath.IsPrime(self.modulus_p) or # pylint: disable=too-many-boolean-expressions
285
399
  self.modulus_q < 3 or not modmath.IsPrime(self.modulus_q) or
286
400
  self.modulus_q <= self.modulus_p or
@@ -313,14 +427,17 @@ class RSAPrivateKey(RSAPublicKey):
313
427
  Returns:
314
428
  string representation of RSAPrivateKey without leaking secrets
315
429
  """
316
- return (f'RSAPrivateKey({super(RSAPrivateKey, self).__str__()}, ' # pylint: disable=super-with-arguments
430
+ return ('RSAPrivateKey('
431
+ f'{super(RSAPrivateKey, self).__str__()}, ' # pylint: disable=super-with-arguments
317
432
  f'modulus_p={base.ObfuscateSecret(self.modulus_p)}, '
318
433
  f'modulus_q={base.ObfuscateSecret(self.modulus_q)}, '
319
434
  f'decrypt_exp={base.ObfuscateSecret(self.decrypt_exp)})')
320
435
 
321
- def Decrypt(self, ciphertext: int, /) -> int:
436
+ def RawDecrypt(self, ciphertext: int, /) -> int:
322
437
  """Decrypt `ciphertext` with this private key.
323
438
 
439
+ BEWARE: This is raw RSA, no OAEP or PSS padding or validation!
440
+ These are pedagogical/raw primitives; do not use for new protocols.
324
441
  We explicitly allow `ciphertext` to be zero for completeness, but it shouldn't be in practice.
325
442
 
326
443
  Args:
@@ -336,15 +453,50 @@ class RSAPrivateKey(RSAPublicKey):
336
453
  if not 0 <= ciphertext < self.public_modulus:
337
454
  raise base.InputError(f'invalid message: {ciphertext=}')
338
455
  # decrypt using CRT (Chinese Remainder Theorem); 4x speedup; all the below is equivalent
339
- # of doing: return modmath.ModExp(ciphertext, self.decrypt_exp, self.public_modulus)
340
- m_p: int = modmath.ModExp(ciphertext % self.modulus_p, self.remainder_p, self.modulus_p)
341
- m_q: int = modmath.ModExp(ciphertext % self.modulus_q, self.remainder_q, self.modulus_q)
456
+ # of doing: return pow(ciphertext, self.decrypt_exp, self.public_modulus)
457
+ m_p: int = int(gmpy2.powmod(ciphertext % self.modulus_p, self.remainder_p, self.modulus_p)) # type:ignore # pylint:disable=no-member
458
+ m_q: int = int(gmpy2.powmod(ciphertext % self.modulus_q, self.remainder_q, self.modulus_q)) # type:ignore # pylint:disable=no-member
342
459
  h: int = (self.q_inverse_p * (m_p - m_q)) % self.modulus_p
343
460
  return (m_q + h * self.modulus_q) % self.public_modulus
344
461
 
345
- def Sign(self, message: int, /) -> int:
462
+ def Decrypt(self, ciphertext: bytes, /, *, associated_data: bytes | None = None) -> bytes:
463
+ """Decrypt `ciphertext` and return the original `plaintext`.
464
+
465
+ • Let k = ceil(log2(n))/8 be the modulus size in bytes.
466
+ • Split ciphertext in two parts: the first k bytes is ct, the rest is AES-256-GCM
467
+ • r = ct^d mod n
468
+ • return AES-256-GCM(key=SHA512(r)[32:], ciphertext,
469
+ associated_data="prefix" + len(aad) + aad + Padded(ct, k))
470
+
471
+ Args:
472
+ ciphertext (bytes): Data to decrypt; see Encrypt() above:
473
+ Padded(ct, k) + AES-256-GCM(key=SHA512(r)[32:], plaintext,
474
+ associated_data="prefix" + len(aad) + aad + Padded(ct, k))
475
+ associated_data (bytes, optional): Optional AAD (must match what was used during encrypt)
476
+
477
+ Returns:
478
+ bytes: Decrypted plaintext bytes
479
+
480
+ Raises:
481
+ InputError: invalid inputs
482
+ CryptoError: internal crypto failures, authentication failure, key mismatch, etc
483
+ """
484
+ k: int = self.modulus_size
485
+ if len(ciphertext) < (k + 32):
486
+ raise base.InputError(f'invalid ciphertext length: {len(ciphertext)} ; {k=}')
487
+ # split ciphertext in two parts: the first k bytes is ct, the rest is AES-256-GCM
488
+ rsa_ct, aes_ct = ciphertext[:k], ciphertext[k:]
489
+ r: int = self.RawDecrypt(base.BytesToInt(rsa_ct))
490
+ ss: bytes = base.Hash512(base.IntToFixedBytes(r, k))
491
+ aad: bytes = b'' if associated_data is None else associated_data
492
+ aad_prime: bytes = _RSA_ENCRYPTION_AAD_PREFIX + base.IntToFixedBytes(len(aad), 8) + aad + rsa_ct
493
+ return aes.AESKey(key256=ss[32:]).Decrypt(aes_ct, associated_data=aad_prime)
494
+
495
+ def RawSign(self, message: int, /) -> int:
346
496
  """Sign `message` with this private key.
347
497
 
498
+ BEWARE: This is raw RSA, no OAEP or PSS padding or validation!
499
+ These are pedagogical/raw primitives; do not use for new protocols.
348
500
  We explicitly disallow `message` to be zero.
349
501
 
350
502
  Args:
@@ -361,7 +513,41 @@ class RSAPrivateKey(RSAPublicKey):
361
513
  if not 0 < message < self.public_modulus:
362
514
  raise base.InputError(f'invalid message: {message=}')
363
515
  # call decryption
364
- return self.Decrypt(message)
516
+ return self.RawDecrypt(message)
517
+
518
+ def Sign(self, message: bytes, /, *, associated_data: bytes | None = None) -> bytes:
519
+ """Sign `message` and return the `signature`.
520
+
521
+ • Let k = ceil(log2(n))/8 be the modulus size in bytes.
522
+ • Pick random salt of 64 bytes
523
+ • s = (Hash512("prefix" || len(aad) || aad || message || salt))^d mod n
524
+ • return salt || Padded(s, k)
525
+
526
+ This is basically Full-Domain Hash RSA with a 512-bit hash and per-signature salt,
527
+ which is EUF-CMA secure in the ROM. Our domain-separation prefix and explicit AAD
528
+ length prefix are both correct and remove composition/ambiguity pitfalls.
529
+ There are no Bleichenbacher-style issue because we do not expose any padding semantics.
530
+
531
+ Args:
532
+ message (bytes): Data to sign.
533
+ associated_data (bytes, optional): Optional AAD for AEAD modes; must be
534
+ provided again on decrypt
535
+
536
+ Returns:
537
+ bytes: Signature; salt || Padded(s, k) - see above
538
+
539
+ Raises:
540
+ InputError: invalid inputs
541
+ CryptoError: internal crypto failures
542
+ """
543
+ k: int = self.modulus_size
544
+ if k <= 64:
545
+ raise base.InputError(f'modulus too small for signing operations: {k} bytes')
546
+ salt: bytes = base.RandBytes(64)
547
+ s_int: int = self.RawSign(self._DomainSeparatedHash(message, associated_data, salt))
548
+ s_bytes: bytes = base.IntToFixedBytes(s_int, k)
549
+ assert len(s_bytes) == k, 'should never happen: s_bytes should be exactly k bytes'
550
+ return salt + s_bytes
365
551
 
366
552
  @classmethod
367
553
  def New(cls, bit_length: int, /) -> Self:
@@ -384,21 +570,19 @@ class RSAPrivateKey(RSAPublicKey):
384
570
  failures: int = 0
385
571
  while True:
386
572
  try:
387
- primes: list[int] = [modmath.NBitRandomPrime(bit_length // 2),
388
- modmath.NBitRandomPrime(bit_length // 2)]
389
- modulus: int = primes[0] * primes[1]
390
- while modulus.bit_length() != bit_length or primes[0] == primes[1]:
391
- primes.remove(min(primes))
392
- primes.append(modmath.NBitRandomPrime(
393
- bit_length // 2 + (bit_length % 2 if modulus.bit_length() < bit_length else 0)))
394
- modulus = primes[0] * primes[1]
573
+ primes: set[int] = set()
574
+ modulus: int = 0
575
+ p: int = 0
576
+ q: int = 0
577
+ while modulus.bit_length() != bit_length:
578
+ primes = modmath.NBitRandomPrimes((bit_length + 1) // 2, n_primes=2)
579
+ p, q = min(primes), max(primes) # "p" is always the smaller, "q" the larger
580
+ modulus = p * q
395
581
  # build object
396
- phi: int = (primes[0] - 1) * (primes[1] - 1)
582
+ phi: int = (p - 1) * (q - 1)
397
583
  prime_exp: int = (_SMALL_ENCRYPTION_EXPONENT if phi <= _BIG_ENCRYPTION_EXPONENT else
398
584
  _BIG_ENCRYPTION_EXPONENT)
399
585
  decrypt_exp: int = modmath.ModInv(prime_exp, phi)
400
- p: int = min(primes) # "p" is always the smaller
401
- q: int = max(primes) # "q" is always the larger
402
586
  return cls(
403
587
  modulus_p=p,
404
588
  modulus_q=q,
transcrypto/sss.py CHANGED
@@ -14,25 +14,23 @@ import logging
14
14
  # import pdb
15
15
  from typing import Collection, Generator, Self
16
16
 
17
- from . import base
18
- from . import modmath
17
+ from . import base, modmath, aes
19
18
 
20
19
  __author__ = 'balparda@github.com'
21
20
  __version__: str = base.__version__ # version comes from base!
22
21
  __version_tuple__: tuple[int, ...] = base.__version_tuple__
23
22
 
24
23
 
24
+ # fixed prefixes: do NOT ever change! will break all encryption and signature schemes
25
+ _SSS_ENCRYPTION_AAD_PREFIX = b'transcrypto.SSS.Sharing.1.0\x00'
26
+
27
+
25
28
  @dataclasses.dataclass(kw_only=True, slots=True, frozen=True, repr=False)
26
29
  class ShamirSharedSecretPublic(base.CryptoKey):
27
30
  """Shamir Shared Secret (SSS) public part.
28
31
 
29
- BEWARE: This is raw SSS, no modern message wrapping, padding or validation!
30
- These are pedagogical/raw primitives; do not use for new protocols.
31
32
  No measures are taken here to prevent timing attacks.
32
-
33
- This is the information-theoretic SSS but with no authentication or binding between
34
- share and secret. Malicious share injection is possible! Add MAC or digital signature
35
- in hostile settings.
33
+ Malicious share injection is possible! Add MAC or digital signature in hostile settings.
36
34
 
37
35
  Attributes:
38
36
  minimum (int): minimum shares needed for recovery, ≥ 2
@@ -61,13 +59,24 @@ class ShamirSharedSecretPublic(base.CryptoKey):
61
59
  string representation of ShamirSharedSecretPublic
62
60
  """
63
61
  return ('ShamirSharedSecretPublic('
62
+ f'bits={self.modulus.bit_length()}, '
64
63
  f'minimum={self.minimum}, '
65
64
  f'modulus={base.IntToEncoded(self.modulus)})')
66
65
 
67
- def RecoverSecret(
66
+ @property
67
+ def modulus_size(self) -> int:
68
+ """Modulus size in bytes. The number of bytes used in MakeDataShares/RecoverData."""
69
+ return (self.modulus.bit_length() + 7) // 8
70
+
71
+ def RawRecoverSecret(
68
72
  self, shares: Collection[ShamirSharePrivate], /, *, force_recover: bool = False) -> int:
69
73
  """Recover the secret from ShamirSharePrivate objects.
70
74
 
75
+ BEWARE: This is raw SSS, no modern message wrapping, padding or validation!
76
+ These are pedagogical/raw primitives; do not use for new protocols.
77
+ This is the information-theoretic SSS but with no authentication or binding between
78
+ share and secret.
79
+
71
80
  Args:
72
81
  shares (Collection[ShamirSharePrivate]): shares to use to recover the secret
73
82
  force_recover (bool, optional): if True will try to recover (default: False)
@@ -114,9 +123,8 @@ class ShamirSharedSecretPublic(base.CryptoKey):
114
123
  class ShamirSharedSecretPrivate(ShamirSharedSecretPublic):
115
124
  """Shamir Shared Secret (SSS) private keys.
116
125
 
117
- BEWARE: This is raw SSS, no modern message wrapping, padding or validation!
118
- These are pedagogical/raw primitives; do not use for new protocols.
119
126
  No measures are taken here to prevent timing attacks.
127
+ Malicious share injection is possible! Add MAC or digital signature in hostile settings.
120
128
 
121
129
  We deliberately choose prime coefficients. This shrinks the key-space and leaks a bit of
122
130
  structure. It is "unusual", but with large enough modulus (bit length > ~ 500) it makes no
@@ -148,12 +156,18 @@ class ShamirSharedSecretPrivate(ShamirSharedSecretPublic):
148
156
  Returns:
149
157
  string representation of ShamirSharedSecretPrivate without leaking secrets
150
158
  """
151
- return (f'ShamirSharedSecretPrivate({super(ShamirSharedSecretPrivate, self).__str__()}, ' # pylint: disable=super-with-arguments
159
+ return ('ShamirSharedSecretPrivate('
160
+ f'{super(ShamirSharedSecretPrivate, self).__str__()}, ' # pylint: disable=super-with-arguments
152
161
  f'polynomial=[{", ".join(base.ObfuscateSecret(i) for i in self.polynomial)}])')
153
162
 
154
- def Share(self, secret: int, /, *, share_key: int = 0) -> ShamirSharePrivate:
163
+ def RawShare(self, secret: int, /, *, share_key: int = 0) -> ShamirSharePrivate:
155
164
  """Make a new ShamirSharePrivate for the `secret`.
156
165
 
166
+ BEWARE: This is raw SSS, no modern message wrapping, padding or validation!
167
+ These are pedagogical/raw primitives; do not use for new protocols.
168
+ This is the information-theoretic SSS but with no authentication or binding between
169
+ share and secret.
170
+
157
171
  Args:
158
172
  secret (int): secret message to encrypt and share, 0 ≤ s < modulus
159
173
  share_key (int, optional): if given, a random value to use, 1 ≤ r < modulus;
@@ -181,10 +195,15 @@ class ShamirSharedSecretPrivate(ShamirSharedSecretPublic):
181
195
  share_key=share_key,
182
196
  share_value=modmath.ModPolynomial(share_key, [secret] + self.polynomial, self.modulus))
183
197
 
184
- def Shares(
198
+ def RawShares(
185
199
  self, secret: int, /, *, max_shares: int = 0) -> Generator[ShamirSharePrivate, None, None]:
186
200
  """Make any number of ShamirSharePrivate for the `secret`.
187
201
 
202
+ BEWARE: This is raw SSS, no modern message wrapping, padding or validation!
203
+ These are pedagogical/raw primitives; do not use for new protocols.
204
+ This is the information-theoretic SSS but with no authentication or binding between
205
+ share and secret.
206
+
188
207
  Args:
189
208
  secret (int): secret message to encrypt and share, 0 ≤ s < modulus
190
209
  max_shares (int, optional): if given, number (≥ 2) of shares to generate; else infinite
@@ -206,16 +225,63 @@ class ShamirSharedSecretPrivate(ShamirSharedSecretPublic):
206
225
  while not share_key or share_key in self.polynomial or share_key in used_keys:
207
226
  share_key = base.RandBits(self.modulus.bit_length() - 1)
208
227
  try:
209
- yield self.Share(secret, share_key=share_key)
228
+ yield self.RawShare(secret, share_key=share_key)
210
229
  used_keys.add(share_key)
211
230
  count += 1
212
231
  except base.InputError as err:
213
232
  # it could happen, for example, that the share_key will generate a value of 0
214
233
  logging.warning(err)
215
234
 
216
- def VerifyShare(self, secret: int, share: ShamirSharePrivate, /) -> bool:
235
+ def MakeDataShares(self, secret: bytes, total_shares: int, /) -> list[ShamirShareData]:
236
+ """Make `total_shares` ShamirShareData objects with encrypted `secret`.
237
+
238
+ • Let k = ceil(log2(n))/8 be the modulus size in bytes
239
+ • r = random 32 bytes
240
+ • shares = SSS.Shares(r, total_shares)
241
+ • ct = AES-256-GCM(key=SHA512("prefix" + r)[32:], plaintext=secret,
242
+ associated_data="prefix" + minimum + modulus)
243
+ • return [share + ct for share in shares]
244
+
245
+ Args:
246
+ secret (bytes): Data to encrypt and distribute (encrypted) in each share.
247
+ total_shares (int): Number of shares to make, ≥ minimum
248
+
249
+ Returns:
250
+ list[ShamirShareData]: the list of shares with encrypted data
251
+
252
+ Raises:
253
+ InputError: invalid inputs
254
+ CryptoError: internal crypto failures
255
+ """
256
+ if total_shares < self.minimum:
257
+ raise base.InputError(f'invalid total_shares: {total_shares=} < {self.minimum=}')
258
+ k: int = self.modulus_size
259
+ if k <= 32:
260
+ raise base.InputError(f'modulus too small for key operations: {k} bytes')
261
+ key256: bytes = base.RandBytes(32)
262
+ shares: list[ShamirSharePrivate] = list(
263
+ self.RawShares(base.BytesToInt(key256), max_shares=total_shares))
264
+ aad: bytes = (
265
+ _SSS_ENCRYPTION_AAD_PREFIX +
266
+ base.IntToFixedBytes(self.minimum, 8) + base.IntToFixedBytes(self.modulus, k))
267
+ aead_key: bytes = base.Hash512(_SSS_ENCRYPTION_AAD_PREFIX + key256)
268
+ ct: bytes = aes.AESKey(key256=aead_key[32:]).Encrypt(secret, associated_data=aad)
269
+ return [ShamirShareData(
270
+ minimum=s.minimum,
271
+ modulus=s.modulus,
272
+ share_key=s.share_key,
273
+ share_value=s.share_value,
274
+ encrypted_data=ct,
275
+ ) for s in shares]
276
+
277
+ def RawVerifyShare(self, secret: int, share: ShamirSharePrivate, /) -> bool:
217
278
  """Verify a ShamirSharePrivate object for the `secret`.
218
279
 
280
+ BEWARE: This is raw SSS, no modern message wrapping, padding or validation!
281
+ These are pedagogical/raw primitives; do not use for new protocols.
282
+ This is the information-theoretic SSS but with no authentication or binding between
283
+ share and secret.
284
+
219
285
  Args:
220
286
  secret (int): secret message to encrypt and share, 0 ≤ s < modulus
221
287
  share (ShamirSharePrivate): share to verify
@@ -226,7 +292,7 @@ class ShamirSharedSecretPrivate(ShamirSharedSecretPublic):
226
292
  Raises:
227
293
  InputError: invalid inputs
228
294
  """
229
- return share == self.Share(secret, share_key=share.share_key)
295
+ return share == self.RawShare(secret, share_key=share.share_key)
230
296
 
231
297
  @classmethod
232
298
  def New(cls, minimum_shares: int, bit_length: int, /) -> Self:
@@ -248,9 +314,7 @@ class ShamirSharedSecretPrivate(ShamirSharedSecretPublic):
248
314
  if bit_length < 10:
249
315
  raise base.InputError(f'invalid bit length: {bit_length=}')
250
316
  # make the primes
251
- unique_primes: set[int] = set()
252
- while len(unique_primes) < minimum_shares:
253
- unique_primes.add(modmath.NBitRandomPrime(bit_length))
317
+ unique_primes: set[int] = modmath.NBitRandomPrimes(bit_length, n_primes=minimum_shares)
254
318
  # get the largest prime for the modulus
255
319
  ordered_primes: list[int] = list(unique_primes)
256
320
  modulus: int = max(ordered_primes)
@@ -265,9 +329,8 @@ class ShamirSharedSecretPrivate(ShamirSharedSecretPublic):
265
329
  class ShamirSharePrivate(ShamirSharedSecretPublic):
266
330
  """Shamir Shared Secret (SSS) one share.
267
331
 
268
- BEWARE: This is raw SSS, no modern message wrapping, padding or validation!
269
- These are pedagogical/raw primitives; do not use for new protocols.
270
332
  No measures are taken here to prevent timing attacks.
333
+ Malicious share injection is possible! Add MAC or digital signature in hostile settings.
271
334
 
272
335
  Attributes:
273
336
  share_key (int): share secret key; a randomly picked value, 1 ≤ k < modulus
@@ -294,6 +357,80 @@ class ShamirSharePrivate(ShamirSharedSecretPublic):
294
357
  Returns:
295
358
  string representation of ShamirSharePrivate without leaking secrets
296
359
  """
297
- return (f'ShamirSharePrivate({super(ShamirSharePrivate, self).__str__()}, ' # pylint: disable=super-with-arguments
360
+ return ('ShamirSharePrivate('
361
+ f'{super(ShamirSharePrivate, self).__str__()}, ' # pylint: disable=super-with-arguments
298
362
  f'share_key={base.ObfuscateSecret(self.share_key)}, '
299
363
  f'share_value={base.ObfuscateSecret(self.share_value)})')
364
+
365
+ @classmethod
366
+ def CopyShare(cls, other: ShamirSharePrivate, /) -> Self:
367
+ """Initialize a share taking the parts of another share."""
368
+ return cls(
369
+ minimum=other.minimum, modulus=other.modulus,
370
+ share_key=other.share_key, share_value=other.share_value)
371
+
372
+
373
+ @dataclasses.dataclass(kw_only=True, slots=True, frozen=True, repr=False)
374
+ class ShamirShareData(ShamirSharePrivate):
375
+ """Shamir Shared Secret (SSS) one share.
376
+
377
+ No measures are taken here to prevent timing attacks.
378
+ Malicious share injection is possible! Add MAC or digital signature in hostile settings.
379
+
380
+ Attributes:
381
+ share_key (int): share secret key; a randomly picked value, 1 ≤ k < modulus
382
+ share_value (int): share secret value, 1 ≤ v < modulus; (k, v) is a "point" of f(k)=v
383
+ """
384
+
385
+ encrypted_data: bytes
386
+
387
+ def __post_init__(self) -> None:
388
+ """Check data.
389
+
390
+ Raises:
391
+ InputError: invalid inputs
392
+ """
393
+ super(ShamirShareData, self).__post_init__() # pylint: disable=super-with-arguments # needed here b/c: dataclass
394
+ if len(self.encrypted_data) < 32:
395
+ raise base.InputError(f'AES256+GCM SSS should have ≥32 bytes IV/CT/tag: {self}')
396
+
397
+ def __str__(self) -> str:
398
+ """Safe (no secrets) string representation of the ShamirShareData.
399
+
400
+ Returns:
401
+ string representation of ShamirShareData without leaking secrets
402
+ """
403
+ return ('ShamirShareData('
404
+ f'{super(ShamirShareData, self).__str__()}, ' # pylint: disable=super-with-arguments
405
+ f'encrypted_data={base.ObfuscateSecret(self.encrypted_data)})')
406
+
407
+ def RecoverData(self, other_shares: list[ShamirSharePrivate]) -> bytes:
408
+ """Recover the encrypted data from ShamirSharePrivate objects.
409
+
410
+ * key256 = SSS.RecoverSecret([this] + other_shares)
411
+ * return AES-256-GCM(key=SHA512("prefix" + key256)[32:], ciphertext=encrypted_data,
412
+ associated_data="prefix" + minimum + modulus)
413
+
414
+ Args:
415
+ other_shares (list[ShamirSharePrivate]): Other shares to use to recover the secret
416
+
417
+ Returns:
418
+ bytes: Decrypted plaintext bytes
419
+
420
+ Raises:
421
+ InputError: invalid inputs
422
+ CryptoError: internal crypto failures, authentication failure, key mismatch, etc
423
+ """
424
+ k: int = self.modulus_size
425
+ if k <= 32:
426
+ raise base.InputError(f'modulus too small for key operations: {k} bytes')
427
+ # recover secret; raise if shares are invalid
428
+ secret: int = self.RawRecoverSecret([self] + other_shares)
429
+ if not 0 <= secret < (1 << 256):
430
+ raise base.CryptoError('recovered key out of range for 256-bit key')
431
+ key256: bytes = base.IntToFixedBytes(secret, 32)
432
+ aad: bytes = (
433
+ _SSS_ENCRYPTION_AAD_PREFIX +
434
+ base.IntToFixedBytes(self.minimum, 8) + base.IntToFixedBytes(self.modulus, k))
435
+ aead_key: bytes = base.Hash512(_SSS_ENCRYPTION_AAD_PREFIX + key256)
436
+ return aes.AESKey(key256=aead_key[32:]).Decrypt(self.encrypted_data, associated_data=aad)