tigrcorn-security 0.3.16.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tigrcorn_security/__init__.py +1 -0
- tigrcorn_security/alpn.py +29 -0
- tigrcorn_security/certs.py +10 -0
- tigrcorn_security/policies.py +68 -0
- tigrcorn_security/py.typed +1 -0
- tigrcorn_security/tls.py +583 -0
- tigrcorn_security/tls13/__init__.py +95 -0
- tigrcorn_security/tls13/extensions.py +759 -0
- tigrcorn_security/tls13/handshake.py +1411 -0
- tigrcorn_security/tls13/key_schedule.py +108 -0
- tigrcorn_security/tls13/messages.py +428 -0
- tigrcorn_security/tls13/transcript.py +51 -0
- tigrcorn_security/tls_cipher_policy.py +43 -0
- tigrcorn_security/x509/__init__.py +31 -0
- tigrcorn_security/x509/path.py +1284 -0
- tigrcorn_security-0.3.16.dev5.dist-info/METADATA +239 -0
- tigrcorn_security-0.3.16.dev5.dist-info/RECORD +20 -0
- tigrcorn_security-0.3.16.dev5.dist-info/WHEEL +5 -0
- tigrcorn_security-0.3.16.dev5.dist-info/licenses/LICENSE +163 -0
- tigrcorn_security-0.3.16.dev5.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,759 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from enum import IntEnum
|
|
5
|
+
from typing import Iterable, Sequence
|
|
6
|
+
|
|
7
|
+
from tigrcorn_core.errors import ProtocolError
|
|
8
|
+
from tigrcorn_core.utils.bytes import decode_quic_varint, encode_quic_varint
|
|
9
|
+
|
|
10
|
+
# Protocol version code points (RFC 8446 section 4.2.1); the legacy value is
# the TLS 1.2 number that TLS 1.3 hellos carry in their legacy_version field.
TLS_VERSION_1_3: int = 0x0304
TLS_LEGACY_VERSION: int = 0x0303

# TLS 1.3 cipher suites (RFC 8446 appendix B.4).
CIPHER_TLS_AES_128_GCM_SHA256: int = 0x1301
CIPHER_TLS_AES_256_GCM_SHA384: int = 0x1302

# NamedGroup code points for (EC)DHE key exchange (RFC 8446 section 4.2.7).
GROUP_SECP256R1: int = 0x0017
GROUP_X25519: int = 0x001D

# SignatureScheme code points (RFC 8446 section 4.2.3).
SIG_RSA_PKCS1_SHA256: int = 0x0401
SIG_ECDSA_SECP256R1_SHA256: int = 0x0403
SIG_RSA_PSS_RSAE_SHA256: int = 0x0804
SIG_ED25519: int = 0x0807
SIG_RSA_PSS_PSS_SHA256: int = 0x0809

# PskKeyExchangeMode values (RFC 8446 section 4.2.9).
PSK_MODE_KE: int = 0
PSK_MODE_DHE_KE: int = 1

# max_early_data_size written into NewSessionTicket early_data extensions by
# default; RFC 9001 section 4.6.1 requires this value for QUIC 0-RTT tickets.
QUIC_EARLY_DATA_SENTINEL: int = 0xFFFFFFFF
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ExtensionType(IntEnum):
    """TLS extension code points this module knows how to (de)serialise.

    Values are from the IANA "TLS ExtensionType Values" registry. Member order
    is kept as originally declared because iterating the enum is observable.
    """

    SERVER_NAME = 0  # RFC 6066
    SUPPORTED_GROUPS = 10  # RFC 8446 section 4.2.7
    SIGNATURE_ALGORITHMS = 13  # RFC 8446 section 4.2.3
    ALPN = 16  # RFC 7301
    SIGNATURE_ALGORITHMS_CERT = 50  # RFC 8446 section 4.2.3
    PRE_SHARED_KEY = 41  # RFC 8446 section 4.2.11
    EARLY_DATA = 42  # RFC 8446 section 4.2.10
    SUPPORTED_VERSIONS = 43  # RFC 8446 section 4.2.1
    COOKIE = 44  # RFC 8446 section 4.2.2
    PSK_KEY_EXCHANGE_MODES = 45  # RFC 8446 section 4.2.9
    KEY_SHARE = 51  # RFC 8446 section 4.2.8
    QUIC_TRANSPORT_PARAMETERS = 57  # RFC 9001 section 8.2
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass(slots=True)
|
|
47
|
+
class TlsExtension:
|
|
48
|
+
extension_type: int
|
|
49
|
+
value: object
|
|
50
|
+
raw_data: bytes | None = None
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass(slots=True)
|
|
54
|
+
class PskIdentity:
|
|
55
|
+
identity: bytes
|
|
56
|
+
obfuscated_ticket_age: int
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dataclass(slots=True)
|
|
60
|
+
class OfferedPsks:
|
|
61
|
+
identities: tuple[PskIdentity, ...]
|
|
62
|
+
binders: tuple[bytes, ...]
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@dataclass(frozen=True, slots=True)
|
|
66
|
+
class CipherSuiteParameters:
|
|
67
|
+
hash_name: str
|
|
68
|
+
key_length: int
|
|
69
|
+
hp_length: int
|
|
70
|
+
iv_length: int = 12
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# QUIC transport parameter identifiers (RFC 9000 section 18.2), used as the
# varint parameter ids by TransportParameters.to_bytes / from_bytes.
_TP_ORIGINAL_DESTINATION_CONNECTION_ID: int = 0x00
_TP_MAX_IDLE_TIMEOUT: int = 0x01
_TP_STATELESS_RESET_TOKEN: int = 0x02
_TP_MAX_UDP_PAYLOAD_SIZE: int = 0x03
_TP_INITIAL_MAX_DATA: int = 0x04
_TP_INITIAL_MAX_STREAM_DATA_BIDI_LOCAL: int = 0x05
_TP_INITIAL_MAX_STREAM_DATA_BIDI_REMOTE: int = 0x06
_TP_INITIAL_MAX_STREAM_DATA_UNI: int = 0x07
_TP_INITIAL_MAX_STREAMS_BIDI: int = 0x08
_TP_INITIAL_MAX_STREAMS_UNI: int = 0x09
_TP_ACK_DELAY_EXPONENT: int = 0x0A
_TP_MAX_ACK_DELAY: int = 0x0B
_TP_DISABLE_ACTIVE_MIGRATION: int = 0x0C
_TP_PREFERRED_ADDRESS: int = 0x0D
_TP_ACTIVE_CONNECTION_ID_LIMIT: int = 0x0E
_TP_INITIAL_SOURCE_CONNECTION_ID: int = 0x0F
_TP_RETRY_SOURCE_CONNECTION_ID: int = 0x10
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@dataclass(slots=True)
class TransportParameters:
    """QUIC transport parameters (RFC 9000 section 18) carried in the TLS
    quic_transport_parameters extension.

    The field defaults are what a locally constructed instance advertises.
    ``from_bytes`` decodes a peer's parameters and applies the RFC 9000
    defaults for omitted flow-control limits (see its docstring).
    """

    max_data: int = 65536
    max_stream_data_bidi_local: int = 65536
    max_stream_data_bidi_remote: int = 65536
    max_stream_data_uni: int = 65536
    max_streams_bidi: int = 128
    max_streams_uni: int = 128
    idle_timeout: int = 30000  # max_idle_timeout; RFC 9000 unit is milliseconds
    active_connection_id_limit: int = 4
    max_udp_payload_size: int = 1200
    ack_delay_exponent: int = 3
    max_ack_delay: int = 25  # RFC 9000 unit is milliseconds
    disable_active_migration: bool = False
    original_destination_connection_id: bytes | None = None
    stateless_reset_token: bytes | None = None
    preferred_address: bytes | None = None  # kept as the raw encoded value
    initial_source_connection_id: bytes | None = None
    retry_source_connection_id: bytes | None = None
    unknown_parameters: dict[int, bytes] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Reject values RFC 9000 forbids.

        Raises:
            ValueError: if a limit is below its RFC 9000 minimum or the
                stateless reset token is not 16 bytes.
        """
        if self.active_connection_id_limit < 2:
            raise ValueError('active_connection_id_limit must be at least 2')
        if self.ack_delay_exponent < 0:
            raise ValueError('ack_delay_exponent must be non-negative')
        if self.max_ack_delay < 0:
            raise ValueError('max_ack_delay must be non-negative')
        if self.max_udp_payload_size < 1200:
            raise ValueError('max_udp_payload_size must be at least 1200')
        if self.stateless_reset_token is not None and len(self.stateless_reset_token) != 16:
            raise ValueError('stateless_reset_token must be exactly 16 bytes')

    def to_bytes(self) -> bytes:
        """Serialise to the transport-parameter wire format.

        Each parameter is ``varint(id) || varint(len) || value``. Integer
        parameters are always emitted. Optional byte-string parameters are
        emitted only when set, ``disable_active_migration`` as a zero-length
        value when true, and unknown parameters last in ascending id order.
        """
        payload = bytearray()

        def emit(parameter_id: int, value: bytes) -> None:
            payload.extend(encode_quic_varint(parameter_id))
            payload.extend(encode_quic_varint(len(value)))
            payload.extend(value)

        def add_int(parameter_id: int, value: int | None) -> None:
            if value is not None:
                # Encode first so a bad value leaves payload untouched.
                emit(parameter_id, encode_quic_varint(value))

        def add_bytes(parameter_id: int, value: bytes | None) -> None:
            if value is not None:
                emit(parameter_id, value)

        add_bytes(_TP_ORIGINAL_DESTINATION_CONNECTION_ID, self.original_destination_connection_id)
        add_int(_TP_MAX_IDLE_TIMEOUT, self.idle_timeout)
        add_bytes(_TP_STATELESS_RESET_TOKEN, self.stateless_reset_token)
        add_int(_TP_MAX_UDP_PAYLOAD_SIZE, self.max_udp_payload_size)
        add_int(_TP_INITIAL_MAX_DATA, self.max_data)
        add_int(_TP_INITIAL_MAX_STREAM_DATA_BIDI_LOCAL, self.max_stream_data_bidi_local)
        add_int(_TP_INITIAL_MAX_STREAM_DATA_BIDI_REMOTE, self.max_stream_data_bidi_remote)
        add_int(_TP_INITIAL_MAX_STREAM_DATA_UNI, self.max_stream_data_uni)
        add_int(_TP_INITIAL_MAX_STREAMS_BIDI, self.max_streams_bidi)
        add_int(_TP_INITIAL_MAX_STREAMS_UNI, self.max_streams_uni)
        add_int(_TP_ACK_DELAY_EXPONENT, self.ack_delay_exponent)
        add_int(_TP_MAX_ACK_DELAY, self.max_ack_delay)
        if self.disable_active_migration:
            emit(_TP_DISABLE_ACTIVE_MIGRATION, b'')
        add_bytes(_TP_PREFERRED_ADDRESS, self.preferred_address)
        add_int(_TP_ACTIVE_CONNECTION_ID_LIMIT, self.active_connection_id_limit)
        add_bytes(_TP_INITIAL_SOURCE_CONNECTION_ID, self.initial_source_connection_id)
        add_bytes(_TP_RETRY_SOURCE_CONNECTION_ID, self.retry_source_connection_id)
        for parameter_id, value in sorted(self.unknown_parameters.items()):
            emit(parameter_id, value)
        return bytes(payload)

    @classmethod
    def from_bytes(cls, data: bytes) -> 'TransportParameters':
        """Decode a peer's transport parameters from *data*.

        If the peer omits a flow-control or stream-count limit, it defaults to
        0. An omitted active_connection_id_limit defaults to 2. Both follow
        RFC 9000 section 18.2, so the local class defaults are never mistaken
        for credit the peer granted. Other omitted parameters keep the class
        defaults.

        NOTE(review): RFC 9000 also defaults max_idle_timeout to 0 (disabled)
        and max_udp_payload_size to 65527. Those are left at the class
        defaults until callers' handling of 0 is confirmed.

        Raises:
            ProtocolError: on duplicate, truncated or malformed parameters, or
                on values RFC 9000 declares invalid.
        """
        int_fields = {
            _TP_MAX_IDLE_TIMEOUT: 'idle_timeout',
            _TP_MAX_UDP_PAYLOAD_SIZE: 'max_udp_payload_size',
            _TP_INITIAL_MAX_DATA: 'max_data',
            _TP_INITIAL_MAX_STREAM_DATA_BIDI_LOCAL: 'max_stream_data_bidi_local',
            _TP_INITIAL_MAX_STREAM_DATA_BIDI_REMOTE: 'max_stream_data_bidi_remote',
            _TP_INITIAL_MAX_STREAM_DATA_UNI: 'max_stream_data_uni',
            _TP_INITIAL_MAX_STREAMS_BIDI: 'max_streams_bidi',
            _TP_INITIAL_MAX_STREAMS_UNI: 'max_streams_uni',
            _TP_ACK_DELAY_EXPONENT: 'ack_delay_exponent',
            _TP_MAX_ACK_DELAY: 'max_ack_delay',
            _TP_ACTIVE_CONNECTION_ID_LIMIT: 'active_connection_id_limit',
        }
        bytes_fields = {
            _TP_ORIGINAL_DESTINATION_CONNECTION_ID: 'original_destination_connection_id',
            _TP_PREFERRED_ADDRESS: 'preferred_address',
            _TP_INITIAL_SOURCE_CONNECTION_ID: 'initial_source_connection_id',
            _TP_RETRY_SOURCE_CONNECTION_ID: 'retry_source_connection_id',
        }
        # Largest values RFC 9000 sections 18.2 / 4.6 permit for these ids.
        upper_bounds = {
            _TP_ACK_DELAY_EXPONENT: 20,
            _TP_MAX_ACK_DELAY: (1 << 14) - 1,
            _TP_INITIAL_MAX_STREAMS_BIDI: 1 << 60,
            _TP_INITIAL_MAX_STREAMS_UNI: 1 << 60,
        }

        def decode_int(value: bytes) -> int:
            # An integer parameter must be exactly one varint.
            if not value:
                raise ProtocolError('invalid QUIC transport parameter encoding')
            decoded, consumed = decode_quic_varint(value, 0)
            if consumed != len(value):
                raise ProtocolError('invalid QUIC transport parameter encoding')
            return decoded

        unknown_parameters: dict[int, bytes] = {}
        values: dict[str, object] = {
            # RFC 9000 section 18.2 defaults for parameters a peer may omit.
            'max_data': 0,
            'max_stream_data_bidi_local': 0,
            'max_stream_data_bidi_remote': 0,
            'max_stream_data_uni': 0,
            'max_streams_bidi': 0,
            'max_streams_uni': 0,
            'active_connection_id_limit': 2,
            'unknown_parameters': unknown_parameters,
        }
        seen: set[int] = set()
        offset = 0
        while offset < len(data):
            parameter_id, offset = decode_quic_varint(data, offset)
            if parameter_id in seen:
                raise ProtocolError('duplicate QUIC transport parameter')
            seen.add(parameter_id)
            parameter_length, offset = decode_quic_varint(data, offset)
            end = offset + parameter_length
            if end > len(data):
                raise ProtocolError('truncated QUIC transport parameter')
            raw = data[offset:end]
            offset = end

            if parameter_id in int_fields:
                decoded = decode_int(raw)
                bound = upper_bounds.get(parameter_id)
                if bound is not None and decoded > bound:
                    raise ProtocolError(f'QUIC transport parameter {parameter_id:#x} exceeds {bound}')
                values[int_fields[parameter_id]] = decoded
            elif parameter_id in bytes_fields:
                values[bytes_fields[parameter_id]] = raw
            elif parameter_id == _TP_STATELESS_RESET_TOKEN:
                if len(raw) != 16:
                    raise ProtocolError('stateless_reset_token transport parameter must be 16 bytes')
                values['stateless_reset_token'] = raw
            elif parameter_id == _TP_DISABLE_ACTIVE_MIGRATION:
                if raw:
                    raise ProtocolError('disable_active_migration transport parameter must be empty')
                values['disable_active_migration'] = True
            else:
                unknown_parameters[parameter_id] = raw
        try:
            return cls(**values)
        except ValueError as exc:
            # __post_init__ rejects e.g. max_udp_payload_size < 1200 or
            # active_connection_id_limit < 2; surface that as a peer error.
            raise ProtocolError(f'invalid QUIC transport parameter: {exc}') from exc

    def is_0rtt_compatible_with(self, current: 'TransportParameters') -> bool:
        """Return whether *current* is compatible with these remembered parameters for 0-RTT.

        Every limit in *current* must be at least the remembered value, and
        ack_delay_exponent, max_ack_delay and disable_active_migration must be
        unchanged.
        """
        return (
            current.max_data >= self.max_data
            and current.max_stream_data_bidi_local >= self.max_stream_data_bidi_local
            and current.max_stream_data_bidi_remote >= self.max_stream_data_bidi_remote
            and current.max_stream_data_uni >= self.max_stream_data_uni
            and current.max_streams_bidi >= self.max_streams_bidi
            and current.max_streams_uni >= self.max_streams_uni
            and current.max_udp_payload_size >= self.max_udp_payload_size
            and current.active_connection_id_limit >= self.active_connection_id_limit
            and current.ack_delay_exponent == self.ack_delay_exponent
            and current.max_ack_delay == self.max_ack_delay
            and current.disable_active_migration == self.disable_active_migration
        )
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
# Signature schemes advertised in signature_algorithms, in advertised order.
SUPPORTED_SIGNATURE_SCHEMES = (SIG_ED25519, SIG_RSA_PSS_RSAE_SHA256, SIG_RSA_PSS_PSS_SHA256, SIG_ECDSA_SECP256R1_SHA256)
# Certificate signature schemes: the list above plus RSA PKCS#1 v1.5, which
# TLS 1.3 permits only for certificate signatures (RFC 8446 section 4.2.3).
SUPPORTED_CERTIFICATE_SIGNATURE_SCHEMES = SUPPORTED_SIGNATURE_SCHEMES + (SIG_RSA_PKCS1_SHA256,)
# Key-exchange groups in advertised order.
SUPPORTED_GROUPS = (GROUP_X25519, GROUP_SECP256R1)
|
|
266
|
+
|
|
267
|
+
# Per-suite parameters. Dict insertion order fixes SUPPORTED_CIPHER_SUITES
# order (AES-256-GCM listed first).
_CIPHER_SUITE_PARAMETERS = {
    CIPHER_TLS_AES_256_GCM_SHA384: CipherSuiteParameters(hash_name='sha384', key_length=32, hp_length=32),
    CIPHER_TLS_AES_128_GCM_SHA256: CipherSuiteParameters(hash_name='sha256', key_length=16, hp_length=16),
}

SUPPORTED_CIPHER_SUITES = tuple(_CIPHER_SUITE_PARAMETERS.keys())
# Names accepted by parse_cipher_suite_allowlist and produced by
# cipher_suite_name.
_CIPHER_SUITE_NAMES = {
    CIPHER_TLS_AES_128_GCM_SHA256: 'TLS_AES_128_GCM_SHA256',
    CIPHER_TLS_AES_256_GCM_SHA384: 'TLS_AES_256_GCM_SHA384',
}
_CIPHER_SUITE_NAME_TO_ID = dict(zip(_CIPHER_SUITE_NAMES.values(), _CIPHER_SUITE_NAMES.keys()))
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def cipher_suite_name(cipher_suite: int) -> str:
    """Return the IANA name of *cipher_suite*, or ``0x%04x`` hex text when unknown."""
    hex_form = f'0x{cipher_suite:04x}'
    known = _CIPHER_SUITE_NAMES.get(cipher_suite)
    return hex_form if known is None else known
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def parse_cipher_suite_allowlist(value: str | None) -> tuple[int, ...]:
    """Parse an ``ssl_ciphers``-style allowlist into cipher suite ids.

    Names may be separated by ``:`` or ``,``. Surrounding whitespace and empty
    entries are ignored, and duplicates keep their first position.

    Returns:
        ``()`` when *value* is None, otherwise the ids in first-seen order.

    Raises:
        ProtocolError: if no names are given or a name is not supported.
    """
    if value is None:
        return ()
    pieces = (piece.strip() for piece in value.replace(',', ':').split(':'))
    names = [name for name in pieces if name]
    if not names:
        raise ProtocolError('ssl_ciphers must contain at least one supported TLS 1.3 cipher suite')
    ordered: dict[int, None] = {}
    for name in names:
        suite_id = _CIPHER_SUITE_NAME_TO_ID.get(name)
        if suite_id is None:
            raise ProtocolError(f'unsupported TLS cipher suite: {name!r}')
        ordered.setdefault(suite_id, None)
    return tuple(ordered)
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def format_cipher_suite_allowlist(cipher_suites: Sequence[int]) -> str:
    """Render *cipher_suites* as a colon-separated list of ``cipher_suite_name`` values."""
    names = (cipher_suite_name(suite) for suite in cipher_suites)
    return ':'.join(names)
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def cipher_suite_parameters(cipher_suite: int) -> CipherSuiteParameters:
    """Look up the parameters for *cipher_suite*.

    Raises:
        ProtocolError: if the suite is not supported (chained from the KeyError).
    """
    try:
        found = _CIPHER_SUITE_PARAMETERS[cipher_suite]
    except KeyError as missing:
        raise ProtocolError(f'unsupported TLS cipher suite: {cipher_suite:#06x}') from missing
    return found
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def _u8_vector(payload: bytes) -> bytes:
    """Prefix *payload* with a one-byte length (TLS ``<0..2^8-1>`` vector).

    Raises:
        ValueError: if *payload* is longer than 255 bytes.
    """
    size = len(payload)
    if size <= 255:
        return size.to_bytes(1, 'big') + payload
    raise ValueError('u8 vector too large')
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def _u16_vector(payload: bytes) -> bytes:
    """Prefix *payload* with a two-byte big-endian length (TLS ``<0..2^16-1>`` vector).

    Raises:
        ValueError: if *payload* is longer than 65535 bytes.
    """
    size = len(payload)
    if size <= 0xFFFF:
        return size.to_bytes(2, 'big') + payload
    raise ValueError('u16 vector too large')
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def _u24_vector(payload: bytes) -> bytes:
    """Prefix *payload* with a three-byte big-endian length (TLS ``<0..2^24-1>`` vector).

    Raises:
        ValueError: if *payload* is longer than 2^24-1 bytes.
    """
    size = len(payload)
    if size <= 0xFFFFFF:
        return size.to_bytes(3, 'big') + payload
    raise ValueError('u24 vector too large')
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def _read_exact(data: bytes, offset: int, length: int) -> tuple[bytes, int]:
    """Return ``(data[offset:offset+length], offset+length)``.

    Raises:
        ProtocolError: if fewer than *length* bytes remain after *offset*.
    """
    stop = offset + length
    if stop <= len(data):
        return data[offset:stop], stop
    raise ProtocolError('truncated TLS extension payload')
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def _read_u8(data: bytes, offset: int) -> tuple[int, int]:
    """Read one unsigned byte at *offset*; return ``(value, next_offset)``."""
    chunk, next_offset = _read_exact(data, offset, 1)
    value = chunk[0]
    return value, next_offset
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
def _read_u16(data: bytes, offset: int) -> tuple[int, int]:
    """Read a big-endian uint16 at *offset*; return ``(value, next_offset)``."""
    chunk, next_offset = _read_exact(data, offset, 2)
    value = int.from_bytes(chunk, byteorder='big')
    return value, next_offset
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def _read_u24(data: bytes, offset: int) -> tuple[int, int]:
    """Read a big-endian uint24 at *offset*; return ``(value, next_offset)``."""
    chunk, next_offset = _read_exact(data, offset, 3)
    value = int.from_bytes(chunk, byteorder='big')
    return value, next_offset
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
def _read_u8_vector(data: bytes, offset: int) -> tuple[bytes, int]:
    """Read a vector with a one-byte length prefix; return ``(body, next_offset)``."""
    size, body_start = _read_u8(data, offset)
    return _read_exact(data, body_start, size)
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
def _read_u16_vector(data: bytes, offset: int) -> tuple[bytes, int]:
    """Read a vector with a two-byte length prefix; return ``(body, next_offset)``."""
    size, body_start = _read_u16(data, offset)
    return _read_exact(data, body_start, size)
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def _read_u24_vector(data: bytes, offset: int) -> tuple[bytes, int]:
    """Read a vector with a three-byte length prefix; return ``(body, next_offset)``."""
    size, body_start = _read_u24(data, offset)
    return _read_exact(data, body_start, size)
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
def encode_server_name(server_name: str) -> bytes:
    """Encode a server_name extension body holding one host_name entry (RFC 6066 section 3)."""
    host_bytes = server_name.encode('utf-8')
    host_name_type = b'\x00'
    return _u16_vector(host_name_type + _u16_vector(host_bytes))
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
def decode_server_name(data: bytes) -> str:
    """Decode a server_name extension body and return its first host_name.

    Entries of other name types are skipped.

    Raises:
        ProtocolError: if the vector is malformed, has no host_name entry, or
            the host_name is empty or not valid UTF-8.
    """
    names_raw, offset = _read_u16_vector(data, 0)
    if offset != len(data):
        raise ProtocolError('invalid server_name extension')
    inner = 0
    while inner < len(names_raw):
        name_type, inner = _read_u8(names_raw, inner)
        name, inner = _read_u16_vector(names_raw, inner)
        if name_type != 0:  # 0 is host_name; other types are ignored
            continue
        if not name:
            # RFC 6066: HostName is opaque<1..2^16-1>.
            raise ProtocolError('server_name host_name entry is empty')
        try:
            return name.decode('utf-8')
        except UnicodeDecodeError as exc:
            # Surface peer garbage as a protocol error, not a UnicodeDecodeError.
            raise ProtocolError('server_name host_name is not valid UTF-8') from exc
    raise ProtocolError('server_name extension does not contain a host_name entry')
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def encode_supported_versions_client(versions: Sequence[int]) -> bytes:
    """Encode a ClientHello supported_versions body (u8-prefixed list of uint16)."""
    wire_versions = b''.join([version.to_bytes(2, byteorder='big') for version in versions])
    return _u8_vector(wire_versions)
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
def decode_supported_versions_client(data: bytes) -> tuple[int, ...]:
    """Decode a ClientHello supported_versions body into version numbers.

    Raises:
        ProtocolError: on trailing bytes or an odd-length version list.
    """
    versions_blob, consumed = _read_u8_vector(data, 0)
    well_formed = consumed == len(data) and len(versions_blob) % 2 == 0
    if not well_formed:
        raise ProtocolError('invalid supported_versions extension')
    return tuple(
        (versions_blob[pos] << 8) | versions_blob[pos + 1]
        for pos in range(0, len(versions_blob), 2)
    )
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def encode_supported_versions_server(version: int) -> bytes:
    """Encode the server's selected_version as a bare uint16."""
    return version.to_bytes(length=2, byteorder='big')
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
def decode_supported_versions_server(data: bytes) -> int:
    """Decode a ServerHello/HRR supported_versions body (exactly one uint16).

    Raises:
        ProtocolError: if *data* is not exactly two bytes.
    """
    if len(data) == 2:
        return int.from_bytes(data, byteorder='big')
    raise ProtocolError('invalid selected supported_versions extension')
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
def encode_supported_groups(groups: Sequence[int]) -> bytes:
    """Encode a supported_groups body (u16-prefixed list of NamedGroup uint16s)."""
    body = bytearray()
    for named_group in groups:
        body += named_group.to_bytes(2, 'big')
    return _u16_vector(bytes(body))
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def decode_supported_groups(data: bytes) -> tuple[int, ...]:
    """Decode a supported_groups body into NamedGroup code points.

    Raises:
        ProtocolError: on trailing bytes or an odd-length group list.
    """
    groups_blob, consumed = _read_u16_vector(data, 0)
    if consumed != len(data) or len(groups_blob) & 1:
        raise ProtocolError('invalid supported_groups extension')
    highs, lows = groups_blob[0::2], groups_blob[1::2]
    return tuple((high << 8) | low for high, low in zip(highs, lows))
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
def encode_signature_algorithms(schemes: Sequence[int]) -> bytes:
    """Encode a signature_algorithms(_cert) body (u16-prefixed list of uint16)."""
    encoded_schemes = [scheme.to_bytes(2, 'big') for scheme in schemes]
    return _u16_vector(b''.join(encoded_schemes))
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
def decode_signature_algorithms(data: bytes) -> tuple[int, ...]:
    """Decode a signature_algorithms(_cert) body into SignatureScheme code points.

    Raises:
        ProtocolError: on trailing bytes or an odd-length scheme list.
    """
    schemes_blob, end = _read_u16_vector(data, 0)
    if end != len(data) or len(schemes_blob) % 2 == 1:
        raise ProtocolError('invalid signature_algorithms extension')
    schemes: list[int] = []
    for start in range(0, len(schemes_blob), 2):
        schemes.append(int.from_bytes(schemes_blob[start:start + 2], 'big'))
    return tuple(schemes)
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
def encode_alpn(protocols: Sequence[str]) -> bytes:
    """Encode an ALPN extension body (RFC 7301 ProtocolNameList).

    Raises:
        ValueError: if a protocol name is empty, not ASCII (UnicodeEncodeError),
            longer than 255 bytes, or the encoded list exceeds 65535 bytes.
    """
    payload = bytearray()
    for protocol in protocols:
        raw = protocol.encode('ascii')
        if not raw:
            # RFC 7301: ProtocolName is opaque<1..2^8-1>; empty names are invalid
            # and would be rejected by conforming peers.
            raise ValueError('ALPN protocol name must not be empty')
        payload.extend(_u8_vector(raw))
    return _u16_vector(bytes(payload))
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
def decode_alpn(data: bytes) -> tuple[str, ...]:
    """Decode an ALPN extension body into protocol names, in wire order.

    Raises:
        ProtocolError: if the vector is malformed or empty, or a protocol name
            is empty or not ASCII.
    """
    payload, offset = _read_u16_vector(data, 0)
    if offset != len(data):
        raise ProtocolError('invalid ALPN extension')
    protocols: list[str] = []
    inner = 0
    while inner < len(payload):
        raw, inner = _read_u8_vector(payload, inner)
        if not raw:
            # RFC 7301: ProtocolName is opaque<1..2^8-1>.
            raise ProtocolError('ALPN protocol name must not be empty')
        try:
            protocols.append(raw.decode('ascii'))
        except UnicodeDecodeError as exc:
            # NOTE(review): RFC 7301 names are opaque bytes (e.g. RFC 8701 GREASE
            # values are non-ASCII); they are rejected here as a ProtocolError
            # rather than escaping as UnicodeDecodeError.
            raise ProtocolError('ALPN protocol name is not ASCII') from exc
    if not protocols:
        raise ProtocolError('ALPN extension is empty')
    return tuple(protocols)
|
|
472
|
+
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
def encode_psk_key_exchange_modes(modes: Sequence[int]) -> bytes:
    """Encode a psk_key_exchange_modes body (u8-prefixed list of one-byte modes)."""
    mode_bytes = bytes(modes)
    return _u8_vector(mode_bytes)
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def decode_psk_key_exchange_modes(data: bytes) -> tuple[int, ...]:
    """Decode a psk_key_exchange_modes body into mode values.

    Raises:
        ProtocolError: on trailing bytes or an empty mode list.
    """
    payload, offset = _read_u8_vector(data, 0)
    if offset != len(data):
        raise ProtocolError('invalid psk_key_exchange_modes extension')
    if not payload:
        # RFC 8446 section 4.2.9: ke_modes<1..255>.
        raise ProtocolError('psk_key_exchange_modes extension is empty')
    return tuple(payload)
|
|
485
|
+
|
|
486
|
+
|
|
487
|
+
|
|
488
|
+
def encode_keyshare_client(shares: Sequence[tuple[int, bytes]]) -> bytes:
    """Encode a ClientHello key_share body from ``(group, key_exchange)`` pairs."""
    entries = b''.join(
        group.to_bytes(2, 'big') + _u16_vector(key_exchange)
        for group, key_exchange in shares
    )
    return _u16_vector(entries)
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
|
|
497
|
+
def decode_keyshare_client(data: bytes) -> dict[int, bytes]:
    """Decode a ClientHello key_share body into ``{group: key_exchange}``.

    Entries keep their wire order (dict insertion order).

    Raises:
        ProtocolError: if the vector is malformed or a group appears twice.
    """
    payload, offset = _read_u16_vector(data, 0)
    if offset != len(data):
        raise ProtocolError('invalid key_share extension')
    shares: dict[int, bytes] = {}
    inner = 0
    while inner < len(payload):
        group, inner = _read_u16(payload, inner)
        key_exchange, inner = _read_u16_vector(payload, inner)
        if group in shares:
            # RFC 8446 section 4.2.8: clients MUST NOT offer two entries for the
            # same group; previously the later entry silently replaced the first.
            raise ProtocolError(f'duplicate key_share entry for group {group:#06x}')
        shares[group] = key_exchange
    return shares
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
def encode_keyshare_server(group: int, key_exchange: bytes) -> bytes:
    """Encode a ServerHello key_share body: uint16 group + u16-prefixed key."""
    group_bytes = group.to_bytes(2, 'big')
    return group_bytes + _u16_vector(key_exchange)
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
def decode_keyshare_server(data: bytes) -> tuple[int, bytes]:
    """Decode a ServerHello key_share body into ``(group, key_exchange)``.

    Raises:
        ProtocolError: if truncated or followed by trailing bytes.
    """
    selected_group, cursor = _read_u16(data, 0)
    share_bytes, cursor = _read_u16_vector(data, cursor)
    if cursor == len(data):
        return selected_group, share_bytes
    raise ProtocolError('invalid server key_share extension')
|
|
522
|
+
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
def encode_keyshare_hrr(selected_group: int) -> bytes:
    """Encode a HelloRetryRequest key_share body: the selected group as uint16."""
    return selected_group.to_bytes(length=2, byteorder='big')
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
def decode_keyshare_hrr(data: bytes) -> int:
    """Decode a HelloRetryRequest key_share body (exactly one uint16 group).

    Raises:
        ProtocolError: if *data* is not exactly two bytes.
    """
    if len(data) == 2:
        return int.from_bytes(data, byteorder='big')
    raise ProtocolError('invalid HelloRetryRequest key_share extension')
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
|
|
537
|
+
def encode_cookie(cookie: bytes) -> bytes:
    """Encode a cookie extension body (u16-prefixed opaque cookie)."""
    body = _u16_vector(cookie)
    return body
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
def decode_cookie(data: bytes) -> bytes:
    """Decode a cookie extension body and return the opaque cookie.

    Raises:
        ProtocolError: if truncated or followed by trailing bytes.
    """
    cookie_value, end = _read_u16_vector(data, 0)
    if end == len(data):
        return cookie_value
    raise ProtocolError('invalid cookie extension')
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
def encode_early_data(message_context: str, max_early_data_size: int = QUIC_EARLY_DATA_SENTINEL) -> bytes:
    """Encode an early_data extension body for *message_context*.

    NewSessionTicket carries a uint32 max_early_data_size. ClientHello and
    EncryptedExtensions carry an empty body.

    Raises:
        ValueError: for any other message context.
    """
    if message_context == 'new_session_ticket':
        return max_early_data_size.to_bytes(4, 'big')
    if message_context in {'client_hello', 'encrypted_extensions'}:
        return b''
    raise ValueError(f'unsupported early_data context: {message_context}')
|
|
556
|
+
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
def decode_early_data(data: bytes, message_context: str) -> object:
    """Decode an early_data extension body for *message_context*.

    Returns:
        ``True`` for ClientHello/EncryptedExtensions, the uint32
        max_early_data_size for NewSessionTicket, and *data* unchanged for
        any other context.

    Raises:
        ProtocolError: if the body has the wrong length for its context.
    """
    if message_context == 'new_session_ticket':
        if len(data) == 4:
            return int.from_bytes(data, 'big')
        raise ProtocolError('invalid early_data NewSessionTicket extension')
    if message_context not in {'client_hello', 'encrypted_extensions'}:
        return data
    if data:
        raise ProtocolError('early_data extension must be empty in this context')
    return True
|
|
569
|
+
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
def encode_pre_shared_key_client(identities: Sequence[PskIdentity], binders: Sequence[bytes]) -> bytes:
    """Encode a ClientHello pre_shared_key body (identities list + binders list).

    Raises:
        ValueError: if the counts differ or a vector is too large.
    """
    if len(binders) != len(identities):
        raise ValueError('PSK identities and binders must have matching counts')
    identity_blob = bytearray()
    binder_blob = bytearray()
    for entry, binder_value in zip(identities, binders):
        identity_blob += _u16_vector(entry.identity)
        identity_blob += entry.obfuscated_ticket_age.to_bytes(4, 'big')
        binder_blob += _u8_vector(binder_value)
    return _u16_vector(bytes(identity_blob)) + _u16_vector(bytes(binder_blob))
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
|
|
585
|
+
def encode_pre_shared_key_client_without_binders(identities: Sequence[PskIdentity]) -> bytes:
    """Encode only the u16-prefixed identities list of a pre_shared_key body."""
    identity_blob = b''.join(
        _u16_vector(entry.identity) + entry.obfuscated_ticket_age.to_bytes(4, 'big')
        for entry in identities
    )
    return _u16_vector(identity_blob)
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
|
|
594
|
+
def decode_pre_shared_key_client(data: bytes) -> OfferedPsks:
    """Decode a ClientHello pre_shared_key body into identities and binders.

    Raises:
        ProtocolError: if a vector is malformed, no identity is offered, or the
            identity and binder counts differ.
    """
    identities_raw, offset = _read_u16_vector(data, 0)
    binders_raw, offset = _read_u16_vector(data, offset)
    if offset != len(data):
        raise ProtocolError('invalid pre_shared_key extension')
    identities: list[PskIdentity] = []
    inner = 0
    while inner < len(identities_raw):
        identity, inner = _read_u16_vector(identities_raw, inner)
        obfuscated_ticket_age, inner = _read_u32(identities_raw, inner)
        identities.append(PskIdentity(identity=identity, obfuscated_ticket_age=obfuscated_ticket_age))
    if not identities:
        # RFC 8446 section 4.2.11: identities<7..2^16-1> must be non-empty.
        raise ProtocolError('pre_shared_key extension offers no identities')
    binders: list[bytes] = []
    inner = 0
    while inner < len(binders_raw):
        binder, inner = _read_u8_vector(binders_raw, inner)
        binders.append(binder)
    if len(identities) != len(binders):
        raise ProtocolError('mismatched PSK identities and binders')
    return OfferedPsks(identities=tuple(identities), binders=tuple(binders))
|
|
613
|
+
|
|
614
|
+
|
|
615
|
+
|
|
616
|
+
def encode_pre_shared_key_server(selected_identity: int) -> bytes:
    """Encode a ServerHello pre_shared_key body: the selected identity index as uint16."""
    return selected_identity.to_bytes(length=2, byteorder='big')
|
|
618
|
+
|
|
619
|
+
|
|
620
|
+
|
|
621
|
+
def decode_pre_shared_key_server(data: bytes) -> int:
    """Decode a ServerHello pre_shared_key body (selected identity index).

    Raises:
        ProtocolError: if *data* is not exactly two bytes.
    """
    if len(data) == 2:
        return int.from_bytes(data, byteorder='big')
    raise ProtocolError('invalid server pre_shared_key extension')
|
|
625
|
+
|
|
626
|
+
|
|
627
|
+
|
|
628
|
+
def _read_u32(data: bytes, offset: int) -> tuple[int, int]:
    """Read a big-endian uint32 at *offset*; return ``(value, next_offset)``."""
    chunk, next_offset = _read_exact(data, offset, 4)
    value = int.from_bytes(chunk, byteorder='big')
    return value, next_offset
|
|
631
|
+
|
|
632
|
+
|
|
633
|
+
|
|
634
|
+
def encode_quic_transport_parameters(parameters: TransportParameters) -> bytes:
    """Serialise *parameters* as a quic_transport_parameters extension body."""
    serialised = parameters.to_bytes()
    return serialised
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
def decode_quic_transport_parameters(data: bytes) -> TransportParameters:
    """Parse a quic_transport_parameters extension body via ``TransportParameters.from_bytes``."""
    parsed = TransportParameters.from_bytes(data)
    return parsed
|
|
641
|
+
|
|
642
|
+
|
|
643
|
+
|
|
644
|
+
def encode_extensions(extensions: Sequence[TlsExtension], *, message_context: str) -> bytes:
    """Serialise *extensions* into a u16-length-prefixed TLS extensions block.

    An extension's ``raw_data`` is written verbatim when present. Otherwise
    its ``value`` is encoded with ``encode_extension_value`` for
    *message_context*.

    Raises:
        ValueError: if an extension body or the whole block exceeds 65535 bytes.
    """
    payload = bytearray()
    for extension in extensions:
        raw = extension.raw_data
        if raw is None:
            raw = encode_extension_value(extension.extension_type, extension.value, message_context=message_context)
        payload.extend(int(extension.extension_type).to_bytes(2, 'big'))
        # _u16_vector reports an oversized body as ValueError, matching the
        # outer block, instead of int.to_bytes' OverflowError.
        payload.extend(_u16_vector(raw))
    return _u16_vector(bytes(payload))
|
|
654
|
+
|
|
655
|
+
|
|
656
|
+
|
|
657
|
+
def decode_extensions(data: bytes, *, message_context: str) -> tuple[TlsExtension, ...]:
    """Parse a u16-length-prefixed TLS extensions block.

    Each extension is decoded with ``decode_extension_value`` for
    *message_context* and keeps its wire bytes in ``raw_data``.

    Raises:
        ProtocolError: if the block is malformed, an extension type repeats
            (RFC 8446 section 4.2), or, in a ClientHello, pre_shared_key is not
            the last extension (RFC 8446 section 4.2.11).
    """
    payload, offset = _read_u16_vector(data, 0)
    if offset != len(data):
        raise ProtocolError('invalid TLS extensions vector')
    inner = 0
    seen: set[int] = set()
    items: list[TlsExtension] = []
    while inner < len(payload):
        extension_type, inner = _read_u16(payload, inner)
        if extension_type in seen:
            # Duplicates would also be silently collapsed by extension_dict.
            raise ProtocolError(f'duplicate TLS extension type {extension_type}')
        seen.add(extension_type)
        extension_data, inner = _read_u16_vector(payload, inner)
        value = decode_extension_value(extension_type, extension_data, message_context=message_context)
        items.append(TlsExtension(extension_type=extension_type, value=value, raw_data=extension_data))
    psk_type = int(ExtensionType.PRE_SHARED_KEY)
    if message_context == 'client_hello' and psk_type in seen and items[-1].extension_type != psk_type:
        # Binder computation covers the ClientHello up to the binders, so the
        # extension MUST be last.
        raise ProtocolError('pre_shared_key must be the last ClientHello extension')
    return tuple(items)
|
|
669
|
+
|
|
670
|
+
|
|
671
|
+
|
|
672
|
+
def encode_extension_value(extension_type: int, value: object, *, message_context: str) -> bytes:
    """Encode *value* as the extension_data for *extension_type*.

    *message_context* selects between the ClientHello and the server-side or
    HelloRetryRequest wire formats of supported_versions, key_share,
    pre_shared_key and early_data. Unknown extension types must already be
    bytes-like and are passed through.

    Raises:
        TypeError: if server_name, pre_shared_key (ClientHello) or
            quic_transport_parameters values have the wrong type.
        ProtocolError: if an unknown extension type has a non-bytes value.
    """
    # Same enum lookup as decode_extension_value; avoids rebuilding a set of
    # all enum values on every call.
    try:
        ext: ExtensionType | None = ExtensionType(extension_type)
    except ValueError:
        ext = None
    if ext == ExtensionType.SERVER_NAME:
        # Explicit checks instead of `assert`, which is stripped under -O.
        if not isinstance(value, str):
            raise TypeError('server_name extension value must be a str')
        return encode_server_name(value)
    if ext == ExtensionType.SUPPORTED_VERSIONS:
        if message_context == 'client_hello':
            return encode_supported_versions_client(tuple(int(item) for item in value))
        return encode_supported_versions_server(int(value))
    if ext == ExtensionType.SUPPORTED_GROUPS:
        return encode_supported_groups(tuple(int(item) for item in value))
    if ext in {ExtensionType.SIGNATURE_ALGORITHMS, ExtensionType.SIGNATURE_ALGORITHMS_CERT}:
        return encode_signature_algorithms(tuple(int(item) for item in value))
    if ext == ExtensionType.ALPN:
        if isinstance(value, str):
            return encode_alpn((value,))
        return encode_alpn(tuple(str(item) for item in value))
    if ext == ExtensionType.PSK_KEY_EXCHANGE_MODES:
        return encode_psk_key_exchange_modes(tuple(int(item) for item in value))
    if ext == ExtensionType.KEY_SHARE:
        if message_context == 'client_hello':
            return encode_keyshare_client(tuple((int(group), bytes(key_exchange)) for group, key_exchange in value))
        if message_context == 'hello_retry_request':
            return encode_keyshare_hrr(int(value))
        group, key_exchange = value
        return encode_keyshare_server(int(group), bytes(key_exchange))
    if ext == ExtensionType.COOKIE:
        return encode_cookie(bytes(value))
    if ext == ExtensionType.EARLY_DATA:
        # `True` means "use the QUIC sentinel"; only NewSessionTicket writes it.
        size = QUIC_EARLY_DATA_SENTINEL if value is True else int(value)
        return encode_early_data(message_context, size)
    if ext == ExtensionType.PRE_SHARED_KEY:
        if message_context == 'client_hello':
            if not isinstance(value, OfferedPsks):
                raise TypeError('ClientHello pre_shared_key value must be OfferedPsks')
            return encode_pre_shared_key_client(value.identities, value.binders)
        return encode_pre_shared_key_server(int(value))
    if ext == ExtensionType.QUIC_TRANSPORT_PARAMETERS:
        if not isinstance(value, TransportParameters):
            raise TypeError('quic_transport_parameters value must be TransportParameters')
        return encode_quic_transport_parameters(value)
    if isinstance(value, (bytes, bytearray, memoryview)):
        return bytes(value)
    raise ProtocolError(f'unsupported TLS extension type {extension_type}')
|
|
715
|
+
|
|
716
|
+
|
|
717
|
+
|
|
718
|
+
def decode_extension_value(extension_type: int, data: bytes, *, message_context: str) -> object:
    """Decode extension_data for *extension_type* in *message_context*.

    Unknown extension types (and unhandled known ones) are returned as the
    raw *data* bytes.

    Raises:
        ProtocolError: if the body is malformed for its type and context, or a
            server ALPN extension does not contain exactly one protocol.
    """
    try:
        ext = ExtensionType(extension_type)
    except ValueError:
        return data
    if ext == ExtensionType.SERVER_NAME:
        return decode_server_name(data)
    if ext == ExtensionType.SUPPORTED_VERSIONS:
        if message_context == 'client_hello':
            return decode_supported_versions_client(data)
        return decode_supported_versions_server(data)
    if ext == ExtensionType.SUPPORTED_GROUPS:
        return decode_supported_groups(data)
    if ext in {ExtensionType.SIGNATURE_ALGORITHMS, ExtensionType.SIGNATURE_ALGORITHMS_CERT}:
        return decode_signature_algorithms(data)
    if ext == ExtensionType.ALPN:
        protocols = decode_alpn(data)
        if message_context == 'client_hello':
            return protocols
        # RFC 7301 section 3.1: the server's list MUST contain exactly one
        # protocol; previously extra entries were silently dropped.
        if len(protocols) != 1:
            raise ProtocolError('server ALPN extension must select exactly one protocol')
        return protocols[0]
    if ext == ExtensionType.PSK_KEY_EXCHANGE_MODES:
        return decode_psk_key_exchange_modes(data)
    if ext == ExtensionType.KEY_SHARE:
        if message_context == 'client_hello':
            return decode_keyshare_client(data)
        if message_context == 'hello_retry_request':
            return decode_keyshare_hrr(data)
        return decode_keyshare_server(data)
    if ext == ExtensionType.COOKIE:
        return decode_cookie(data)
    if ext == ExtensionType.EARLY_DATA:
        return decode_early_data(data, message_context)
    if ext == ExtensionType.PRE_SHARED_KEY:
        if message_context == 'client_hello':
            return decode_pre_shared_key_client(data)
        return decode_pre_shared_key_server(data)
    if ext == ExtensionType.QUIC_TRANSPORT_PARAMETERS:
        return decode_quic_transport_parameters(data)
    return data
|
|
755
|
+
|
|
756
|
+
|
|
757
|
+
|
|
758
|
+
def extension_dict(extensions: Iterable[TlsExtension]) -> dict[int, object]:
    """Map each extension's numeric type to its value; a later duplicate overwrites an earlier one."""
    mapping: dict[int, object] = {}
    for extension in extensions:
        mapping[int(extension.extension_type)] = extension.value
    return mapping
|