fast-cipher 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fast_cipher/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ from .cipher import FastCipher
2
+ from .params import calculate_recommended_params
3
+ from .types import (
4
+ FastError,
5
+ FastParams,
6
+ InvalidBranchDistError,
7
+ InvalidKeyError,
8
+ InvalidLengthError,
9
+ InvalidParametersError,
10
+ InvalidRadixError,
11
+ InvalidSBoxCountError,
12
+ InvalidValueError,
13
+ InvalidWordLengthError,
14
+ )
15
+
16
# Public names re-exported by ``from fast_cipher import *``: the cipher class,
# the parameter helper, the parameter container and the package's error types.
__all__ = [
    "FastCipher",
    "FastError",
    "FastParams",
    "InvalidBranchDistError",
    "InvalidKeyError",
    "InvalidLengthError",
    "InvalidParametersError",
    "InvalidRadixError",
    "InvalidSBoxCountError",
    "InvalidValueError",
    "InvalidWordLengthError",
    "calculate_recommended_params",
]
fast_cipher/cipher.py ADDED
@@ -0,0 +1,131 @@
1
+ from __future__ import annotations
2
+
3
+ from .core import cdec, cenc
4
+ from .encoding import build_setup1_input, build_setup2_input
5
+ from .prf import derive_key
6
+ from .prng import generate_sequence
7
+ from .sbox import generate_sbox_pool
8
+ from .types import (
9
+ FastParams,
10
+ InvalidBranchDistError,
11
+ InvalidKeyError,
12
+ InvalidLengthError,
13
+ InvalidRadixError,
14
+ InvalidSBoxCountError,
15
+ InvalidValueError,
16
+ InvalidWordLengthError,
17
+ )
18
+
19
DERIVED_KEY_SIZE = 32


def _zeroize(buffer: bytearray) -> None:
    """Overwrite a mutable buffer with zero bytes in place (best effort)."""
    buffer[:] = bytes(len(buffer))


class FastCipher:
    """FAST format-preserving encryption cipher.

    The S-box pool is derived once from the master key at construction time.
    The per-tweak layer sequence is derived on demand, and the sequence for
    the most recently used tweak is cached.

    Args:
        params: Cipher parameters, checked by ``_validate_params``.
        key: Master key of 16, 24 or 32 bytes.

    Raises:
        InvalidRadixError, InvalidWordLengthError, InvalidSBoxCountError,
        InvalidBranchDistError, InvalidKeyError: From parameter/key validation.
    """

    def __init__(self, params: FastParams, key: bytes) -> None:
        _validate_params(params, key)
        self.params = params
        # Held in a mutable buffer so destroy() can overwrite it in place.
        self._master_key = bytearray(key)

        pool_key_material = derive_key(
            key, build_setup1_input(params), DERIVED_KEY_SIZE
        )
        try:
            self._sboxes = generate_sbox_pool(
                params.radix, params.sbox_count, bytes(pool_key_material)
            )
        finally:
            # The pool key is only needed to build the S-boxes; wipe our copy.
            _zeroize(pool_key_material)

        self._cached_tweak: bytes | None = None
        self._cached_seq: list[int] | None = None
        self._destroyed = False

    def _ensure_sequence(self, tweak: bytes) -> list[int]:
        """Return the S-box index sequence for *tweak*, reusing the cache on a hit."""
        if self._cached_seq is not None and self._cached_tweak == tweak:
            return self._cached_seq

        seq_key_material = derive_key(
            bytes(self._master_key),
            build_setup2_input(self.params, tweak),
            DERIVED_KEY_SIZE,
        )
        try:
            seq = generate_sequence(
                self.params.num_layers,
                self.params.sbox_count,
                bytes(seq_key_material),
            )
        finally:
            _zeroize(seq_key_material)
        # Snapshot the tweak. Caching a caller-owned bytearray by reference
        # would let an in-place mutation hit the cache with a stale sequence.
        self._cached_tweak = bytes(tweak)
        self._cached_seq = seq
        return seq

    def _assert_alive(self) -> None:
        """Raise RuntimeError if destroy() has been called."""
        if self._destroyed:
            raise RuntimeError("FastCipher has been destroyed")

    def _validate_input(self, data: bytes | list[int]) -> list[int]:
        """Copy *data* into a list and check its length and symbol range.

        Raises:
            InvalidLengthError: If the length differs from ``params.word_length``.
            InvalidValueError: If any symbol is outside ``[0, params.radix)``.
        """
        values = list(data)
        if len(values) != self.params.word_length:
            raise InvalidLengthError(
                f"Expected {self.params.word_length} elements, got {len(values)}"
            )
        for v in values:
            if not (0 <= v < self.params.radix):
                raise InvalidValueError(
                    f"Value {v} out of range [0, {self.params.radix})"
                )
        return values

    def encrypt(self, plaintext: bytes | list[int], tweak: bytes = b"") -> list[int]:
        """Encrypt one word of symbols under *tweak* and return the ciphertext symbols.

        Raises:
            RuntimeError: If the cipher has been destroyed.
            InvalidLengthError, InvalidValueError: If *plaintext* is malformed.
        """
        self._assert_alive()
        values = self._validate_input(plaintext)
        seq = self._ensure_sequence(tweak)
        return cenc(self.params, self._sboxes, seq, values)

    def decrypt(self, ciphertext: bytes | list[int], tweak: bytes = b"") -> list[int]:
        """Decrypt one word of symbols under *tweak* and return the plaintext symbols.

        Raises:
            RuntimeError: If the cipher has been destroyed.
            InvalidLengthError, InvalidValueError: If *ciphertext* is malformed.
        """
        self._assert_alive()
        values = self._validate_input(ciphertext)
        seq = self._ensure_sequence(tweak)
        return cdec(self.params, self._sboxes, seq, values)

    def encrypt_bytes(self, plaintext: bytes, tweak: bytes = b"") -> bytes:
        """Like encrypt(), returning ``bytes`` (valid since radix is at most 256)."""
        return bytes(self.encrypt(plaintext, tweak))

    def decrypt_bytes(self, ciphertext: bytes, tweak: bytes = b"") -> bytes:
        """Like decrypt(), returning ``bytes`` (valid since radix is at most 256)."""
        return bytes(self.decrypt(ciphertext, tweak))

    def destroy(self) -> None:
        """Make the cipher unusable and drop its key-dependent state.

        This is best effort. The instance's own key buffer is overwritten in
        place, but copies held elsewhere (e.g. the caller's ``key`` object)
        are out of reach from here.
        """
        self._destroyed = True
        master_key = self._master_key
        if isinstance(master_key, bytearray):
            _zeroize(master_key)
        else:
            self._master_key = b"\x00" * len(master_key)
        self._sboxes = []
        self._cached_seq = None
        self._cached_tweak = None
100
+
101
+
102
def _validate_params(params: FastParams, key: bytes) -> None:
    """Reject parameter sets or keys the cipher cannot operate on.

    Checks run in a fixed order and the first failure raises:
    radix in [4, 256]; word_length >= 1; num_layers >= 1 and (for multi-symbol
    words) a multiple of word_length; sbox_count >= 1; non-negative branch
    distances; for multi-symbol words branch_dist1 <= word_length - 2 and
    1 <= branch_dist2 <= word_length - 1; finally a 16/24/32-byte key.
    """
    if params.radix < 4 or 256 < params.radix:
        raise InvalidRadixError("Radix must be between 4 and 256")

    if params.word_length < 1:
        raise InvalidWordLengthError("Word length must be >= 1")
    if params.num_layers < 1:
        raise InvalidWordLengthError("num_layers must be >= 1")

    multi_symbol = params.word_length > 1
    if multi_symbol and params.num_layers % params.word_length != 0:
        raise InvalidWordLengthError("num_layers must be a multiple of word_length")

    if params.sbox_count < 1:
        raise InvalidSBoxCountError("S-box count must be >= 1")

    for field_name in ("branch_dist1", "branch_dist2"):
        if getattr(params, field_name) < 0:
            raise InvalidBranchDistError(f"{field_name} must be >= 0")

    if multi_symbol:
        ell = params.word_length
        if params.branch_dist1 > ell - 2:
            raise InvalidBranchDistError("branch_dist1 must be <= word_length - 2")
        if params.branch_dist2 == 0 or params.branch_dist2 > ell - 1:
            raise InvalidBranchDistError("branch_dist2 is out of valid range")

    if len(key) not in (16, 24, 32):
        raise InvalidKeyError("Key must be 16, 24, or 32 bytes")
fast_cipher/core.py ADDED
@@ -0,0 +1,39 @@
1
+ from __future__ import annotations
2
+
3
+ from .layers import ds_layer, es_layer
4
+ from .sbox import SBox
5
+ from .types import FastParams
6
+
7
+
8
def cenc(
    params: FastParams,
    sboxes: list[SBox],
    seq: list[int],
    plaintext: list[int],
) -> list[int]:
    """Component encryption: apply every ES layer in forward order.

    Layer ``i`` uses the S-box ``sboxes[seq[i]]``. For single-symbol words
    each layer is a plain S-box substitution of ``state[0]``; otherwise
    ``es_layer`` is applied. The input is copied and never modified.
    """
    state = list(plaintext)
    layers = range(params.num_layers)
    if params.word_length != 1:
        for layer in layers:
            es_layer(params, sboxes[seq[layer]], state)
        return state
    for layer in layers:
        state[0] = sboxes[seq[layer]].perm[state[0]]
    return state
23
+
24
+
25
def cdec(
    params: FastParams,
    sboxes: list[SBox],
    seq: list[int],
    ciphertext: list[int],
) -> list[int]:
    """Component decryption: undo the ES layers by applying DS layers last-to-first.

    Layer ``i`` uses the S-box ``sboxes[seq[i]]``. For single-symbol words
    each layer is an inverse S-box lookup on ``state[0]``; otherwise
    ``ds_layer`` is applied. The input is copied and never modified.
    """
    state = list(ciphertext)
    backwards = reversed(range(params.num_layers))
    if params.word_length != 1:
        for layer in backwards:
            ds_layer(params, sboxes[seq[layer]], state)
        return state
    for layer in backwards:
        state[0] = sboxes[seq[layer]].inv[state[0]]
    return state
fast_cipher/encoding.py ADDED
@@ -0,0 +1,52 @@
1
+ from __future__ import annotations
2
+
3
+ import struct
4
+
5
+ from .types import FastParams
6
+
7
# Fixed domain-separation labels embedded in the PRF inputs built below.
# They are hashed into the derived-key input, so changing any of them changes
# the resulting key material.
_LABEL_INSTANCE1 = b"instance1"  # opens the radix / S-box-count fields
_LABEL_INSTANCE2 = b"instance2"  # opens the word-shape fields (setup 2 only)
_LABEL_FPE_POOL = b"FPE Pool"  # S-box pool derivation (setup 1)
_LABEL_FPE_SEQ = b"FPE SEQ"  # layer-sequence derivation (setup 2)
_LABEL_TWEAK = b"tweak"  # precedes the caller-supplied tweak
12
+
13
+
14
def _u32be(value: int) -> bytes:
    """Encode *value* as a 4-byte big-endian unsigned integer.

    ``"!"`` is struct's network (big-endian, standard-size) byte order.
    Values outside [0, 2**32) raise ``struct.error``.
    """
    return struct.pack("!I", value)
16
+
17
+
18
def encode_parts(parts: list[bytes]) -> bytes:
    """Serialize *parts* as ``be32(count)`` followed by ``be32(len) || part`` per part.

    Every field is length-prefixed, so adjacent fields cannot run into one
    another in the encoded output.
    """
    pack_u32 = struct.Struct(">I").pack
    chunks = [pack_u32(len(parts))]
    for part in parts:
        chunks.append(pack_u32(len(part)))
        chunks.append(bytes(part))
    return b"".join(chunks)
24
+
25
+
26
def build_setup1_input(params: FastParams) -> bytes:
    """Build the PRF input used to derive the S-box pool key.

    Fields, in order: instance-1 label, radix, S-box count, pool label.
    """
    radix_field = _u32be(params.radix)
    pool_size_field = _u32be(params.sbox_count)
    fields = [_LABEL_INSTANCE1, radix_field, pool_size_field, _LABEL_FPE_POOL]
    return encode_parts(fields)
35
+
36
+
37
def build_setup2_input(params: FastParams, tweak: bytes) -> bytes:
    """Build the PRF input used to derive the per-tweak layer-sequence key.

    Repeats the instance-1 fields (radix, S-box count), then the instance-2
    fields (word length, layer count, both branch distances), then the
    sequence label, the tweak label and the raw tweak bytes.
    """
    instance1 = [_LABEL_INSTANCE1, _u32be(params.radix), _u32be(params.sbox_count)]
    shape_fields = (
        params.word_length,
        params.num_layers,
        params.branch_dist1,
        params.branch_dist2,
    )
    instance2 = [_LABEL_INSTANCE2, *(_u32be(field) for field in shape_fields)]
    trailer = [_LABEL_FPE_SEQ, _LABEL_TWEAK, tweak]
    return encode_parts(instance1 + instance2 + trailer)
fast_cipher/layers.py ADDED
@@ -0,0 +1,55 @@
1
+ from __future__ import annotations
2
+
3
+ from .sbox import SBox
4
+ from .types import FastParams
5
+
6
+
7
def _mod_add(a: int, b: int, radix: int) -> int:
    """Return ``(a + b) mod radix``, using a byte mask as the radix-256 shortcut."""
    total = a + b
    return total & 0xFF if radix == 256 else total % radix
11
+
12
+
13
def _mod_sub(a: int, b: int, radix: int) -> int:
    """Return ``(a - b) mod radix`` (always non-negative), byte-masked for radix 256."""
    difference = a - b
    return difference & 0xFF if radix == 256 else difference % radix
17
+
18
+
19
def es_layer(params: FastParams, sbox: SBox, data: list[int]) -> None:
    """ES (Expansion-Substitution) forward layer, applied to *data* in place.

    With ``ell = word_length``, ``w = branch_dist1``, ``w' = branch_dist2`` and
    S-box ``S``::

        s   = S[(x[0] + x[ell - w']) mod radix]
        new = S[(s - x[w]) mod radix]    if w > 0
        new = S[s]                       if w == 0

    The word then shifts left by one and ``new`` is written at index ``ell - 1``.
    ``ds_layer`` inverts this step.
    """
    forward = sbox.perm
    ell = params.word_length
    radix = params.radix
    tap = params.branch_dist1

    mixed = forward[_mod_add(data[0], data[ell - params.branch_dist2], radix)]
    if tap > 0:
        mixed = _mod_sub(mixed, data[tap], radix)
    new_symbol = forward[mixed]

    data[:-1] = data[1:]
    data[ell - 1] = new_symbol
36
+
37
+
38
def ds_layer(params: FastParams, sbox: SBox, data: list[int]) -> None:
    """DS (De-Substitution) backward layer: undo one ``es_layer`` on *data* in place.

    With ``ell = word_length``, ``w = branch_dist1``, ``w' = branch_dist2`` and
    inverse S-box ``S^-1``, the symbol that ES dropped from the front is::

        t     = S^-1[x[ell - 1]]
        t     = (t + x[w - 1]) mod radix     if w > 0
        x0    = (S^-1[t] - x[ell - w' - 1]) mod radix

    The word then shifts right by one and ``x0`` is written at index 0.
    """
    backward = sbox.inv
    ell = params.word_length
    radix = params.radix
    tap = params.branch_dist1

    undone = backward[data[ell - 1]]
    if tap > 0:
        undone = _mod_add(undone, data[tap - 1], radix)
    undone = backward[undone]
    recovered = _mod_sub(undone, data[ell - params.branch_dist2 - 1], radix)

    data[1:] = data[:-1]
    data[0] = recovered
fast_cipher/params.py ADDED
@@ -0,0 +1,146 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+
5
+ from .types import FastParams, InvalidParametersError
6
+
7
# Number of S-boxes in the pool that calculate_recommended_params requests.
SBOX_POOL_SIZE = 256

# Word lengths (columns of ROUND_TABLE) at which round counts are tabulated.
ROUND_L_VALUES = [*range(2, 11), 12, 16, 32, 50, 64, 100]
# Radices (rows of ROUND_TABLE) at which round counts are tabulated.
ROUND_RADICES = [*range(4, 17), 100, 128, 256, 1000, 1024, 10000, 65536]

# Recommended round counts: ROUND_TABLE[r][c] is the value for radix
# ROUND_RADICES[r] and word length ROUND_L_VALUES[c].
ROUND_TABLE = [
    [165, 135, 117, 105, 96, 89, 83, 78, 74, 68, 59, 52, 52, 53, 57],  # radix 4
    [131, 107, 93, 83, 76, 70, 66, 62, 59, 54, 48, 46, 47, 48, 53],  # radix 5
    [113, 92, 80, 72, 65, 61, 57, 54, 51, 46, 44, 43, 44, 46, 52],  # radix 6
    [102, 83, 72, 64, 59, 55, 51, 48, 46, 43, 41, 41, 43, 45, 50],  # radix 7
    [94, 76, 66, 59, 54, 50, 47, 44, 42, 41, 39, 39, 42, 44, 50],  # radix 8
    [88, 72, 62, 56, 51, 47, 44, 42, 40, 39, 38, 38, 41, 43, 49],  # radix 9
    [83, 68, 59, 53, 48, 45, 42, 39, 39, 38, 37, 37, 40, 43, 49],  # radix 10
    [79, 65, 56, 50, 46, 43, 40, 38, 38, 37, 36, 37, 40, 42, 48],  # radix 11
    [76, 62, 54, 48, 44, 41, 38, 37, 37, 36, 35, 36, 39, 42, 48],  # radix 12
    [73, 60, 52, 47, 43, 39, 37, 36, 36, 35, 34, 36, 39, 41, 48],  # radix 13
    [71, 58, 50, 45, 41, 38, 36, 36, 35, 34, 34, 35, 39, 41, 47],  # radix 14
    [69, 57, 49, 44, 40, 37, 36, 35, 34, 34, 33, 35, 38, 41, 47],  # radix 15
    [67, 55, 48, 43, 39, 36, 35, 34, 34, 33, 33, 35, 38, 41, 47],  # radix 16
    [40, 33, 28, 27, 26, 26, 25, 25, 25, 26, 26, 30, 34, 37, 44],  # radix 100
    [38, 31, 27, 26, 25, 25, 25, 25, 25, 25, 26, 30, 34, 37, 44],  # radix 128
    [33, 27, 25, 24, 23, 23, 23, 23, 23, 24, 25, 29, 33, 37, 44],  # radix 256
    [32, 22, 21, 21, 21, 21, 21, 21, 21, 22, 23, 28, 32, 36, 43],  # radix 1000
    [32, 22, 21, 21, 21, 21, 21, 21, 21, 22, 23, 28, 32, 36, 43],  # radix 1024
    [32, 22, 18, 18, 18, 18, 19, 19, 19, 20, 21, 27, 32, 35, 42],  # radix 10000
    [32, 22, 17, 17, 17, 17, 17, 18, 18, 19, 21, 26, 31, 35, 42],  # radix 65536
]
55
+
56
+
57
def _interpolate(x: float, x0: float, x1: float, y0: float, y1: float) -> float:
    """Linearly interpolate between (x0, y0) and (x1, y1), clamped to the endpoints.

    A degenerate interval (``x0 == x1``) yields ``y0``. Positions at or before
    ``x0`` yield ``y0`` and positions at or past ``x1`` yield ``y1`` exactly.
    """
    if x1 == x0:
        return y0
    fraction = (x - x0) / (x1 - x0)
    if fraction <= 0:
        return y0
    return y1 if fraction >= 1 else y0 + fraction * (y1 - y0)
66
+
67
+
68
def _rounds_for_row(row_index: int, ell: int) -> float:
    """Recommended round count for one radix row of ROUND_TABLE at word length *ell*.

    Word lengths at or below the first tabulated length use the first column.
    Lengths between tabulated points are linearly interpolated. Lengths at or
    beyond the last tabulated length scale the last column by
    ``sqrt(ell / last_length)``, never dropping below the tabulated value.
    """
    row = ROUND_TABLE[row_index]
    lengths = ROUND_L_VALUES
    last = len(lengths) - 1
    longest = lengths[last]

    if ell <= lengths[0]:
        return row[0]

    if ell >= longest:
        base = row[last]
        return max(base, base * math.sqrt(ell / longest))

    for upper_col, (lower, upper) in enumerate(zip(lengths, lengths[1:]), start=1):
        if ell <= upper:
            return _interpolate(ell, lower, upper, row[upper_col - 1], row[upper_col])

    return row[last]
91
+
92
+
93
def _lookup_recommended_rounds(radix: int, ell: int) -> float:
    """Recommended round count for (*radix*, *ell*).

    Radices outside the tabulated range use the first or last row. Otherwise
    the two bracketing rows are evaluated at *ell* and interpolated linearly
    in ``log(radix)``.
    """
    radices = ROUND_RADICES
    last = len(radices) - 1

    if radix <= radices[0]:
        return _rounds_for_row(0, ell)
    if radix >= radices[last]:
        return _rounds_for_row(last, ell)

    log_radix = math.log(radix)
    for upper_row, (lower, upper) in enumerate(zip(radices, radices[1:]), start=1):
        if radix <= upper:
            return _interpolate(
                log_radix,
                math.log(lower),
                math.log(upper),
                _rounds_for_row(upper_row - 1, ell),
                _rounds_for_row(upper_row, ell),
            )

    return _rounds_for_row(last, ell)
114
+
115
+
116
def calculate_recommended_params(
    radix: int,
    word_length: int,
    security_level: int = 128,
) -> FastParams:
    """Build a recommended FastParams for the given alphabet size and word length.

    Args:
        radix: Alphabet size, 4..256.
        word_length: Symbols per word, at least 1.
        security_level: Stored on the result, with 0 treated as 128. It does
            not influence any other parameter chosen here.

    Returns:
        Parameters with an S-box pool of SBOX_POOL_SIZE. branch_dist1 is
        about ``ceil(sqrt(word_length))``, clamped to ``[0, word_length - 2]``.
        branch_dist2 is ``max(branch_dist1 - 1, 1)``. num_layers is
        ``ceil(rounds) * word_length``, where rounds (at least 1) come from
        the interpolated round table.

    Raises:
        InvalidParametersError: If *radix* or *word_length* is out of range.
    """
    if radix < 4 or radix > 256:
        raise InvalidParametersError("radix must be between 4 and 256")
    if word_length < 1:
        raise InvalidParametersError("word_length must be >= 1")

    sec_level = 128 if security_level == 0 else security_level

    branch_dist1 = max(0, min(math.ceil(math.sqrt(word_length)), word_length - 2))
    branch_dist2 = max(1, branch_dist1 - 1)

    rounds = max(_lookup_recommended_rounds(radix, word_length), 1.0)
    num_layers = math.ceil(rounds) * word_length

    return FastParams(
        radix=radix,
        word_length=word_length,
        sbox_count=SBOX_POOL_SIZE,
        num_layers=num_layers,
        branch_dist1=branch_dist1,
        branch_dist2=branch_dist2,
        security_level=sec_level,
    )
fast_cipher/prf.py ADDED
@@ -0,0 +1,82 @@
1
+ from __future__ import annotations
2
+
3
+ import struct
4
+
5
+ from cryptography.hazmat.primitives.ciphers import Cipher
6
+ from cryptography.hazmat.primitives.ciphers.algorithms import AES
7
+ from cryptography.hazmat.primitives.ciphers.modes import ECB
8
+
9
AES_BLOCK_SIZE = 16
# Reduction constant for doubling in GF(2^128) (CMAC subkey generation).
CMAC_RB = 0x87


def _left_shift_and_xor(data: bytes, xor_byte: int) -> bytes:
    """Shift a 16-byte block left by one bit and conditionally reduce.

    If the block's most significant bit was set, *xor_byte* is XORed into the
    last byte. With ``CMAC_RB`` this is the CMAC subkey doubling step.
    """
    shifted = bytearray(AES_BLOCK_SIZE)
    carry = 0
    for index in reversed(range(AES_BLOCK_SIZE)):
        byte = data[index]
        shifted[index] = (byte << 1 | carry) & 0xFF
        carry = byte >> 7 & 1
    if data[0] & 0x80:
        shifted[-1] ^= xor_byte
    return bytes(shifted)
22
+
23
+
24
def aes_cmac(key: bytes, message: bytes) -> bytes:
    """Return the AES-CMAC tag (RFC 4493 construction) of *message* under *key*.

    Args:
        key: AES key passed directly to ``AES``. Its length selects the variant.
        message: Bytes-like data to authenticate. It may be empty.

    Returns:
        The 16-byte tag.
    """
    block_cipher = Cipher(AES(key), ECB()).encryptor()

    def xor_at(source, offset: int, mask) -> bytearray:
        # One block: source[offset:offset + 16] XOR mask, byte by byte.
        return bytearray(source[offset + i] ^ mask[i] for i in range(AES_BLOCK_SIZE))

    # Subkeys: K1 = dbl(E_K(0^128)), K2 = dbl(K1).
    zero_block = bytes(AES_BLOCK_SIZE)
    subkey1 = _left_shift_and_xor(block_cipher.update(zero_block), CMAC_RB)
    subkey2 = _left_shift_and_xor(subkey1, CMAC_RB)

    size = len(message)
    # An empty message is still processed as one (padded) block.
    n_blocks = max(1, -(-size // AES_BLOCK_SIZE))
    tail_offset = (n_blocks - 1) * AES_BLOCK_SIZE

    if size and size % AES_BLOCK_SIZE == 0:
        # Complete final block: mask with K1.
        final_block = xor_at(message, tail_offset, subkey1)
    else:
        # Incomplete final block: append 0x80 then zeros, and mask with K2.
        tail = message[tail_offset:]
        padded = bytearray(AES_BLOCK_SIZE)
        padded[: len(tail)] = tail
        padded[len(tail)] = 0x80
        final_block = xor_at(padded, 0, subkey2)

    # CBC-MAC over every block but the last, starting from an all-zero chain value.
    chaining = zero_block
    for offset in range(0, tail_offset, AES_BLOCK_SIZE):
        chaining = block_cipher.update(xor_at(message, offset, chaining))

    return block_cipher.update(xor_at(final_block, 0, chaining))
58
+
59
+
60
def derive_key(
    master_key: bytes, input_data: bytes, output_length: int = 32
) -> bytearray:
    """Expand *master_key* into *output_length* bytes with AES-CMAC in counter mode.

    Output chunk ``i`` is ``aes_cmac(master_key, be32(i) || input_data)``, with
    ``i`` starting at 0. Chunks are concatenated and the last one is truncated.

    Args:
        master_key: AES key of 16, 24 or 32 bytes.
        input_data: Context bytes placed after the 4-byte big-endian counter.
        output_length: Number of bytes to produce. Must be positive.

    Returns:
        A new ``bytearray`` of exactly *output_length* bytes.

    Raises:
        ValueError: If the key size is unsupported or *output_length* is not positive.
    """
    if len(master_key) not in (16, 24, 32):
        raise ValueError("Master key must be 16, 24, or 32 bytes")
    # Reject negatives here as well. Otherwise bytearray(n) below fails with
    # an unrelated "negative count" message.
    if output_length <= 0:
        raise ValueError("Output length must be > 0")

    output = bytearray(output_length)
    # Layout is [4-byte counter][input_data]. Only the counter changes per chunk.
    buffer = bytearray(4 + len(input_data))
    buffer[4:] = input_data

    bytes_generated = 0
    counter = 0
    while bytes_generated < output_length:
        struct.pack_into(">I", buffer, 0, counter)
        cmac_output = aes_cmac(master_key, buffer)
        to_copy = min(output_length - bytes_generated, AES_BLOCK_SIZE)
        output[bytes_generated : bytes_generated + to_copy] = cmac_output[:to_copy]
        bytes_generated += to_copy
        counter += 1

    return output
fast_cipher/prng.py ADDED
@@ -0,0 +1,79 @@
1
+ from __future__ import annotations
2
+
3
+ import struct
4
+
5
+ from cryptography.hazmat.primitives.ciphers import Cipher
6
+ from cryptography.hazmat.primitives.ciphers.algorithms import AES
7
+ from cryptography.hazmat.primitives.ciphers.modes import ECB
8
+
9
AES_BLOCK_SIZE = 16


class PrngState:
    """Deterministic keystream PRNG: AES-ECB over an incrementing counter block.

    The original docstring states this matches C/Zig/JS reference
    implementations. It is AES-128 when keyed through split_key_material,
    which supplies a 16-byte key. The counter is incremented *before* each
    encryption, so the nonce value itself is never encrypted.
    """

    def __init__(self, key: bytes, nonce: bytes) -> None:
        self._encryptor = Cipher(AES(key), ECB()).encryptor()
        self._counter = bytearray(nonce)
        self._buffer = bytearray(AES_BLOCK_SIZE)
        # Start with an exhausted buffer so the first read triggers a refill.
        self._buffer_pos = AES_BLOCK_SIZE

    def _increment_counter(self) -> None:
        """Add 1 to the counter as a big-endian integer, wrapping to zero on overflow."""
        counter = self._counter
        index = AES_BLOCK_SIZE - 1
        while index >= 0:
            counter[index] = (counter[index] + 1) & 0xFF
            if counter[index] != 0:
                return
            index -= 1

    def _encrypt_block(self) -> None:
        """Refill the keystream buffer with the encryption of the current counter."""
        keystream = self._encryptor.update(self._counter)
        self._buffer[:] = keystream

    def get_bytes(self, n: int) -> bytes:
        """Return the next *n* keystream bytes. A new block is made only when needed."""
        out = bytearray(n)
        filled = 0
        while filled < n:
            if self._buffer_pos == AES_BLOCK_SIZE:
                # Buffer drained: advance the counter first, then encrypt it.
                self._increment_counter()
                self._encrypt_block()
                self._buffer_pos = 0
            start = self._buffer_pos
            take = min(n - filled, AES_BLOCK_SIZE - start)
            out[filled : filled + take] = self._buffer[start : start + take]
            self._buffer_pos = start + take
            filled += take
        return bytes(out)

    def next_u32(self) -> int:
        """Return the next 4 keystream bytes as a big-endian unsigned 32-bit integer."""
        (value,) = struct.unpack(">I", self.get_bytes(4))
        return value

    def uniform(self, bound: int) -> int:
        """Return an unbiased integer in ``[0, bound)`` using Lemire's multiply-and-reject method.

        Bounds of 1 or less return 0 without consuming keystream.
        """
        if bound <= 1:
            return 0
        # Equals 2**32 mod bound. Low products below this would bias the result.
        reject_below = (0x100000000 - bound) % bound
        while True:
            product = self.next_u32() * bound
            if (product & 0xFFFFFFFF) >= reject_below:
                return product >> 32
61
+
62
+
63
def split_key_material(
    key_material: bytes, zeroize_iv_suffix: bool
) -> tuple[bytes, bytes]:
    """Split derived key material into an AES key and a counter nonce.

    Bytes 0-15 become the key and bytes 16-31 the nonce. When
    *zeroize_iv_suffix* is true, the nonce's last two bytes are cleared.
    Those are the bytes PrngState._increment_counter bumps first.
    """
    nonce = bytearray(key_material[16:32])
    if zeroize_iv_suffix:
        for position in (AES_BLOCK_SIZE - 1, AES_BLOCK_SIZE - 2):
            nonce[position] = 0
    return key_material[:16], bytes(nonce)
72
+
73
+
74
def generate_sequence(
    num_layers: int, pool_size: int, key_material: bytes
) -> list[int]:
    """Draw one S-box index in ``[0, pool_size)`` for each of *num_layers* layers.

    The PRNG is keyed from *key_material* with the nonce's last two bytes zeroed.
    """
    key, nonce = split_key_material(key_material, zeroize_iv_suffix=True)
    draw = PrngState(key, nonce).uniform
    return [draw(pool_size) for _layer in range(num_layers)]
fast_cipher/py.typed ADDED
File without changes
fast_cipher/sbox.py ADDED
@@ -0,0 +1,28 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ from .prng import PrngState, split_key_material
6
+
7
+
8
@dataclass
class SBox:
    """A substitution box over the symbols ``range(radix)``."""

    # Forward permutation: perm[x] is the substitute for symbol x.
    perm: list[int]
    # Inverse permutation, built so that inv[perm[x]] == x.
    inv: list[int]
12
+
13
+
14
def generate_sbox(radix: int, prng: PrngState) -> SBox:
    """Build a random permutation of ``range(radix)`` and its inverse.

    Uses a Fisher-Yates shuffle driven by ``prng.uniform``, swapping from the
    top index down to 1.
    """
    forward = list(range(radix))
    for top in reversed(range(1, radix)):
        pick = prng.uniform(top + 1)
        forward[top], forward[pick] = forward[pick], forward[top]
    inverse = [0] * radix
    for position, symbol in enumerate(forward):
        inverse[symbol] = position
    return SBox(perm=forward, inv=inverse)
23
+
24
+
25
def generate_sbox_pool(radix: int, count: int, key_material: bytes) -> list[SBox]:
    """Generate *count* S-boxes in sequence from one PRNG keyed by *key_material*.

    The nonce is used as-is here (no suffix zeroing).
    """
    key, nonce = split_key_material(key_material, zeroize_iv_suffix=False)
    shared_prng = PrngState(key, nonce)
    return [generate_sbox(radix, shared_prng) for _box in range(count)]