tigrcorn-security 0.3.16.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,759 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from enum import IntEnum
5
+ from typing import Iterable, Sequence
6
+
7
+ from tigrcorn_core.errors import ProtocolError
8
+ from tigrcorn_core.utils.bytes import decode_quic_varint, encode_quic_varint
9
+
10
# Protocol version code points.
TLS_VERSION_1_3 = 0x0304
TLS_LEGACY_VERSION = 0x0303  # value of the legacy_version field

# TLS 1.3 cipher suites.
CIPHER_TLS_AES_128_GCM_SHA256 = 0x1301
CIPHER_TLS_AES_256_GCM_SHA384 = 0x1302

# Named groups for key exchange.
GROUP_SECP256R1 = 0x0017
GROUP_X25519 = 0x001D

# SignatureScheme code points.
SIG_RSA_PKCS1_SHA256 = 0x0401
SIG_ECDSA_SECP256R1_SHA256 = 0x0403
SIG_RSA_PSS_RSAE_SHA256 = 0x0804
SIG_ED25519 = 0x0807
SIG_RSA_PSS_PSS_SHA256 = 0x0809

# PskKeyExchangeMode values: PSK-only and PSK with (EC)DHE.
PSK_MODE_KE = 0
PSK_MODE_DHE_KE = 1

# max_early_data_size written by default into NewSessionTicket early_data.
QUIC_EARLY_DATA_SENTINEL = 0xFFFFFFFF
29
+
30
+
31
class ExtensionType(IntEnum):
    """TLS ExtensionType code points understood by this codec."""

    SERVER_NAME = 0
    SUPPORTED_GROUPS = 10
    SIGNATURE_ALGORITHMS = 13
    ALPN = 16  # application_layer_protocol_negotiation
    SIGNATURE_ALGORITHMS_CERT = 50
    PRE_SHARED_KEY = 41
    EARLY_DATA = 42
    SUPPORTED_VERSIONS = 43
    COOKIE = 44
    PSK_KEY_EXCHANGE_MODES = 45
    KEY_SHARE = 51
    QUIC_TRANSPORT_PARAMETERS = 57
44
+
45
+
46
+ @dataclass(slots=True)
47
+ class TlsExtension:
48
+ extension_type: int
49
+ value: object
50
+ raw_data: bytes | None = None
51
+
52
+
53
+ @dataclass(slots=True)
54
+ class PskIdentity:
55
+ identity: bytes
56
+ obfuscated_ticket_age: int
57
+
58
+
59
+ @dataclass(slots=True)
60
+ class OfferedPsks:
61
+ identities: tuple[PskIdentity, ...]
62
+ binders: tuple[bytes, ...]
63
+
64
+
65
+ @dataclass(frozen=True, slots=True)
66
+ class CipherSuiteParameters:
67
+ hash_name: str
68
+ key_length: int
69
+ hp_length: int
70
+ iv_length: int = 12
71
+
72
+
73
+ _TP_ORIGINAL_DESTINATION_CONNECTION_ID = 0x00
74
+ _TP_MAX_IDLE_TIMEOUT = 0x01
75
+ _TP_STATELESS_RESET_TOKEN = 0x02
76
+ _TP_MAX_UDP_PAYLOAD_SIZE = 0x03
77
+ _TP_INITIAL_MAX_DATA = 0x04
78
+ _TP_INITIAL_MAX_STREAM_DATA_BIDI_LOCAL = 0x05
79
+ _TP_INITIAL_MAX_STREAM_DATA_BIDI_REMOTE = 0x06
80
+ _TP_INITIAL_MAX_STREAM_DATA_UNI = 0x07
81
+ _TP_INITIAL_MAX_STREAMS_BIDI = 0x08
82
+ _TP_INITIAL_MAX_STREAMS_UNI = 0x09
83
+ _TP_ACK_DELAY_EXPONENT = 0x0A
84
+ _TP_MAX_ACK_DELAY = 0x0B
85
+ _TP_DISABLE_ACTIVE_MIGRATION = 0x0C
86
+ _TP_PREFERRED_ADDRESS = 0x0D
87
+ _TP_ACTIVE_CONNECTION_ID_LIMIT = 0x0E
88
+ _TP_INITIAL_SOURCE_CONNECTION_ID = 0x0F
89
+ _TP_RETRY_SOURCE_CONNECTION_ID = 0x10
90
+
91
+
92
+ @dataclass(slots=True)
93
+ class TransportParameters:
94
+ max_data: int = 65536
95
+ max_stream_data_bidi_local: int = 65536
96
+ max_stream_data_bidi_remote: int = 65536
97
+ max_stream_data_uni: int = 65536
98
+ max_streams_bidi: int = 128
99
+ max_streams_uni: int = 128
100
+ idle_timeout: int = 30000
101
+ active_connection_id_limit: int = 4
102
+ max_udp_payload_size: int = 1200
103
+ ack_delay_exponent: int = 3
104
+ max_ack_delay: int = 25
105
+ disable_active_migration: bool = False
106
+ original_destination_connection_id: bytes | None = None
107
+ stateless_reset_token: bytes | None = None
108
+ preferred_address: bytes | None = None
109
+ initial_source_connection_id: bytes | None = None
110
+ retry_source_connection_id: bytes | None = None
111
+ unknown_parameters: dict[int, bytes] = field(default_factory=dict)
112
+
113
+ def __post_init__(self) -> None:
114
+ if self.active_connection_id_limit < 2:
115
+ raise ValueError('active_connection_id_limit must be at least 2')
116
+ if self.ack_delay_exponent < 0:
117
+ raise ValueError('ack_delay_exponent must be non-negative')
118
+ if self.max_ack_delay < 0:
119
+ raise ValueError('max_ack_delay must be non-negative')
120
+ if self.max_udp_payload_size < 1200:
121
+ raise ValueError('max_udp_payload_size must be at least 1200')
122
+ if self.stateless_reset_token is not None and len(self.stateless_reset_token) != 16:
123
+ raise ValueError('stateless_reset_token must be exactly 16 bytes')
124
+
125
+ def to_bytes(self) -> bytes:
126
+ payload = bytearray()
127
+
128
+ def add_int(parameter_id: int, value: int | None) -> None:
129
+ if value is None:
130
+ return
131
+ encoded = encode_quic_varint(value)
132
+ payload.extend(encode_quic_varint(parameter_id))
133
+ payload.extend(encode_quic_varint(len(encoded)))
134
+ payload.extend(encoded)
135
+
136
+ def add_bytes(parameter_id: int, value: bytes | None) -> None:
137
+ if value is None:
138
+ return
139
+ payload.extend(encode_quic_varint(parameter_id))
140
+ payload.extend(encode_quic_varint(len(value)))
141
+ payload.extend(value)
142
+
143
+ add_bytes(_TP_ORIGINAL_DESTINATION_CONNECTION_ID, self.original_destination_connection_id)
144
+ add_int(_TP_MAX_IDLE_TIMEOUT, self.idle_timeout)
145
+ add_bytes(_TP_STATELESS_RESET_TOKEN, self.stateless_reset_token)
146
+ add_int(_TP_MAX_UDP_PAYLOAD_SIZE, self.max_udp_payload_size)
147
+ add_int(_TP_INITIAL_MAX_DATA, self.max_data)
148
+ add_int(_TP_INITIAL_MAX_STREAM_DATA_BIDI_LOCAL, self.max_stream_data_bidi_local)
149
+ add_int(_TP_INITIAL_MAX_STREAM_DATA_BIDI_REMOTE, self.max_stream_data_bidi_remote)
150
+ add_int(_TP_INITIAL_MAX_STREAM_DATA_UNI, self.max_stream_data_uni)
151
+ add_int(_TP_INITIAL_MAX_STREAMS_BIDI, self.max_streams_bidi)
152
+ add_int(_TP_INITIAL_MAX_STREAMS_UNI, self.max_streams_uni)
153
+ add_int(_TP_ACK_DELAY_EXPONENT, self.ack_delay_exponent)
154
+ add_int(_TP_MAX_ACK_DELAY, self.max_ack_delay)
155
+ if self.disable_active_migration:
156
+ payload.extend(encode_quic_varint(_TP_DISABLE_ACTIVE_MIGRATION))
157
+ payload.extend(encode_quic_varint(0))
158
+ add_bytes(_TP_PREFERRED_ADDRESS, self.preferred_address)
159
+ add_int(_TP_ACTIVE_CONNECTION_ID_LIMIT, self.active_connection_id_limit)
160
+ add_bytes(_TP_INITIAL_SOURCE_CONNECTION_ID, self.initial_source_connection_id)
161
+ add_bytes(_TP_RETRY_SOURCE_CONNECTION_ID, self.retry_source_connection_id)
162
+ for parameter_id, value in sorted(self.unknown_parameters.items()):
163
+ payload.extend(encode_quic_varint(parameter_id))
164
+ payload.extend(encode_quic_varint(len(value)))
165
+ payload.extend(value)
166
+ return bytes(payload)
167
+
168
+ @classmethod
169
+ def from_bytes(cls, data: bytes) -> 'TransportParameters':
170
+ values: dict[str, object] = {'unknown_parameters': {}}
171
+ seen: set[int] = set()
172
+ offset = 0
173
+ while offset < len(data):
174
+ parameter_id, offset = decode_quic_varint(data, offset)
175
+ if parameter_id in seen:
176
+ raise ProtocolError('duplicate QUIC transport parameter')
177
+ seen.add(parameter_id)
178
+ parameter_length, offset = decode_quic_varint(data, offset)
179
+ end = offset + parameter_length
180
+ if end > len(data):
181
+ raise ProtocolError('truncated QUIC transport parameter')
182
+ raw = data[offset:end]
183
+ offset = end
184
+
185
+ def decode_int(value: bytes) -> int:
186
+ decoded, inner_offset = decode_quic_varint(value, 0)
187
+ if inner_offset != len(value):
188
+ raise ProtocolError('invalid QUIC transport parameter encoding')
189
+ return decoded
190
+
191
+ if parameter_id == _TP_ORIGINAL_DESTINATION_CONNECTION_ID:
192
+ values['original_destination_connection_id'] = raw
193
+ elif parameter_id == _TP_MAX_IDLE_TIMEOUT:
194
+ values['idle_timeout'] = decode_int(raw)
195
+ elif parameter_id == _TP_STATELESS_RESET_TOKEN:
196
+ if len(raw) != 16:
197
+ raise ProtocolError('stateless_reset_token transport parameter must be 16 bytes')
198
+ values['stateless_reset_token'] = raw
199
+ elif parameter_id == _TP_MAX_UDP_PAYLOAD_SIZE:
200
+ values['max_udp_payload_size'] = decode_int(raw)
201
+ elif parameter_id == _TP_INITIAL_MAX_DATA:
202
+ values['max_data'] = decode_int(raw)
203
+ elif parameter_id == _TP_INITIAL_MAX_STREAM_DATA_BIDI_LOCAL:
204
+ values['max_stream_data_bidi_local'] = decode_int(raw)
205
+ elif parameter_id == _TP_INITIAL_MAX_STREAM_DATA_BIDI_REMOTE:
206
+ values['max_stream_data_bidi_remote'] = decode_int(raw)
207
+ elif parameter_id == _TP_INITIAL_MAX_STREAM_DATA_UNI:
208
+ values['max_stream_data_uni'] = decode_int(raw)
209
+ elif parameter_id == _TP_INITIAL_MAX_STREAMS_BIDI:
210
+ values['max_streams_bidi'] = decode_int(raw)
211
+ elif parameter_id == _TP_INITIAL_MAX_STREAMS_UNI:
212
+ values['max_streams_uni'] = decode_int(raw)
213
+ elif parameter_id == _TP_ACK_DELAY_EXPONENT:
214
+ values['ack_delay_exponent'] = decode_int(raw)
215
+ elif parameter_id == _TP_MAX_ACK_DELAY:
216
+ values['max_ack_delay'] = decode_int(raw)
217
+ elif parameter_id == _TP_DISABLE_ACTIVE_MIGRATION:
218
+ if raw:
219
+ raise ProtocolError('disable_active_migration transport parameter must be empty')
220
+ values['disable_active_migration'] = True
221
+ elif parameter_id == _TP_PREFERRED_ADDRESS:
222
+ values['preferred_address'] = raw
223
+ elif parameter_id == _TP_ACTIVE_CONNECTION_ID_LIMIT:
224
+ values['active_connection_id_limit'] = decode_int(raw)
225
+ elif parameter_id == _TP_INITIAL_SOURCE_CONNECTION_ID:
226
+ values['initial_source_connection_id'] = raw
227
+ elif parameter_id == _TP_RETRY_SOURCE_CONNECTION_ID:
228
+ values['retry_source_connection_id'] = raw
229
+ else:
230
+ values['unknown_parameters'][parameter_id] = raw
231
+ return cls(**values)
232
+
233
+ def is_0rtt_compatible_with(self, current: 'TransportParameters') -> bool:
234
+ return (
235
+ current.max_data >= self.max_data
236
+ and current.max_stream_data_bidi_local >= self.max_stream_data_bidi_local
237
+ and current.max_stream_data_bidi_remote >= self.max_stream_data_bidi_remote
238
+ and current.max_stream_data_uni >= self.max_stream_data_uni
239
+ and current.max_streams_bidi >= self.max_streams_bidi
240
+ and current.max_streams_uni >= self.max_streams_uni
241
+ and current.max_udp_payload_size >= self.max_udp_payload_size
242
+ and current.active_connection_id_limit >= self.active_connection_id_limit
243
+ and current.ack_delay_exponent == self.ack_delay_exponent
244
+ and current.max_ack_delay == self.max_ack_delay
245
+ and current.disable_active_migration == self.disable_active_migration
246
+ )
247
+
248
+
249
# Signature schemes offered for handshake signatures, as an ordered tuple.
SUPPORTED_SIGNATURE_SCHEMES = (SIG_ED25519, SIG_RSA_PSS_RSAE_SHA256, SIG_RSA_PSS_PSS_SHA256, SIG_ECDSA_SECP256R1_SHA256)

# Schemes for certificate signatures: the same list plus RSA PKCS#1 v1.5.
SUPPORTED_CERTIFICATE_SIGNATURE_SCHEMES = SUPPORTED_SIGNATURE_SCHEMES + (SIG_RSA_PKCS1_SHA256,)

# Key-exchange groups, X25519 listed first.
SUPPORTED_GROUPS = (GROUP_X25519, GROUP_SECP256R1)
266
+
267
# Key-schedule sizes per suite.  Dict order (AES-256 first) also fixes the
# order of SUPPORTED_CIPHER_SUITES.
_CIPHER_SUITE_PARAMETERS = {
    CIPHER_TLS_AES_256_GCM_SHA384: CipherSuiteParameters(hash_name='sha384', key_length=32, hp_length=32),
    CIPHER_TLS_AES_128_GCM_SHA256: CipherSuiteParameters(hash_name='sha256', key_length=16, hp_length=16),
}

SUPPORTED_CIPHER_SUITES = tuple(_CIPHER_SUITE_PARAMETERS.keys())

# Human-readable names used in configuration strings, plus the reverse map.
_CIPHER_SUITE_NAMES = {
    CIPHER_TLS_AES_128_GCM_SHA256: 'TLS_AES_128_GCM_SHA256',
    CIPHER_TLS_AES_256_GCM_SHA384: 'TLS_AES_256_GCM_SHA384',
}
_CIPHER_SUITE_NAME_TO_ID = dict(zip(_CIPHER_SUITE_NAMES.values(), _CIPHER_SUITE_NAMES.keys()))
278
+
279
+
280
def cipher_suite_name(cipher_suite: int) -> str:
    """Return the name of *cipher_suite*, or ``0x....`` hex for unknown codes."""
    name = _CIPHER_SUITE_NAMES.get(cipher_suite)
    if name is None:
        return f'0x{cipher_suite:04x}'
    return name
282
+
283
+
284
def parse_cipher_suite_allowlist(value: str | None) -> tuple[int, ...]:
    """Parse a ``:``- or ``,``-separated list of cipher suite names.

    Returns the suite IDs in first-seen order with duplicates dropped, or an
    empty tuple when *value* is None.

    Raises:
        ProtocolError: if the list is empty or names an unsupported suite.
    """
    if value is None:
        return ()
    pieces = (piece.strip() for piece in value.replace(',', ':').split(':'))
    tokens = [piece for piece in pieces if piece]
    if not tokens:
        raise ProtocolError('ssl_ciphers must contain at least one supported TLS 1.3 cipher suite')
    ordered: dict[int, None] = {}
    for token in tokens:
        suite_id = _CIPHER_SUITE_NAME_TO_ID.get(token)
        if suite_id is None:
            raise ProtocolError(f'unsupported TLS cipher suite: {token!r}')
        ordered.setdefault(suite_id, None)
    return tuple(ordered)
298
+
299
+
300
def format_cipher_suite_allowlist(cipher_suites: Sequence[int]) -> str:
    """Render cipher suite IDs as a ``:``-joined list of names."""
    names = [cipher_suite_name(suite) for suite in cipher_suites]
    return ':'.join(names)
302
+
303
+
304
def cipher_suite_parameters(cipher_suite: int) -> CipherSuiteParameters:
    """Look up key-schedule parameters for *cipher_suite*.

    Raises:
        ProtocolError: if the suite is not supported.
    """
    table = _CIPHER_SUITE_PARAMETERS
    try:
        params = table[cipher_suite]
    except KeyError as exc:
        raise ProtocolError(f'unsupported TLS cipher suite: {cipher_suite:#06x}') from exc
    return params
309
+
310
+
311
+ def _u8_vector(payload: bytes) -> bytes:
312
+ if len(payload) > 255:
313
+ raise ValueError('u8 vector too large')
314
+ return bytes([len(payload)]) + payload
315
+
316
+
317
+
318
+ def _u16_vector(payload: bytes) -> bytes:
319
+ if len(payload) > 0xFFFF:
320
+ raise ValueError('u16 vector too large')
321
+ return len(payload).to_bytes(2, 'big') + payload
322
+
323
+
324
+
325
+ def _u24_vector(payload: bytes) -> bytes:
326
+ if len(payload) > 0xFFFFFF:
327
+ raise ValueError('u24 vector too large')
328
+ return len(payload).to_bytes(3, 'big') + payload
329
+
330
+
331
+
332
+ def _read_exact(data: bytes, offset: int, length: int) -> tuple[bytes, int]:
333
+ end = offset + length
334
+ if end > len(data):
335
+ raise ProtocolError('truncated TLS extension payload')
336
+ return data[offset:end], end
337
+
338
+
339
+
340
def _read_u8(data: bytes, offset: int) -> tuple[int, int]:
    """Read one unsigned byte at *offset*; return ``(value, next_offset)``."""
    chunk, next_offset = _read_exact(data, offset, 1)
    value = chunk[0]
    return value, next_offset
343
+
344
+
345
+
346
def _read_u16(data: bytes, offset: int) -> tuple[int, int]:
    """Read a big-endian uint16 at *offset*; return ``(value, next_offset)``."""
    chunk, next_offset = _read_exact(data, offset, 2)
    value = int.from_bytes(chunk, byteorder='big')
    return value, next_offset
349
+
350
+
351
+
352
def _read_u24(data: bytes, offset: int) -> tuple[int, int]:
    """Read a big-endian uint24 at *offset*; return ``(value, next_offset)``."""
    chunk, next_offset = _read_exact(data, offset, 3)
    value = int.from_bytes(chunk, byteorder='big')
    return value, next_offset
355
+
356
+
357
+
358
def _read_u8_vector(data: bytes, offset: int) -> tuple[bytes, int]:
    """Read a vector with a 1-byte length prefix; return ``(body, next_offset)``."""
    size, body_start = _read_u8(data, offset)
    return _read_exact(data, body_start, size)
361
+
362
+
363
+
364
def _read_u16_vector(data: bytes, offset: int) -> tuple[bytes, int]:
    """Read a vector with a 2-byte length prefix; return ``(body, next_offset)``."""
    size, body_start = _read_u16(data, offset)
    return _read_exact(data, body_start, size)
367
+
368
+
369
+
370
def _read_u24_vector(data: bytes, offset: int) -> tuple[bytes, int]:
    """Read a vector with a 3-byte length prefix; return ``(body, next_offset)``."""
    size, body_start = _read_u24(data, offset)
    return _read_exact(data, body_start, size)
373
+
374
+
375
+
376
def encode_server_name(server_name: str) -> bytes:
    """Encode a server_name extension holding one host_name (type 0) entry."""
    host_name = server_name.encode('utf-8')
    name_entry = bytes([0]) + _u16_vector(host_name)
    return _u16_vector(name_entry)
380
+
381
+
382
+
383
def decode_server_name(data: bytes) -> str:
    """Return the first host_name (type 0) entry of a server_name extension.

    Raises:
        ProtocolError: if the extension is malformed, the host name is empty or
            not valid UTF-8, or no host_name entry is present.
    """
    names_raw, offset = _read_u16_vector(data, 0)
    if offset != len(data):
        raise ProtocolError('invalid server_name extension')
    inner = 0
    while inner < len(names_raw):
        name_type, inner = _read_u8(names_raw, inner)
        name, inner = _read_u16_vector(names_raw, inner)
        if name_type != 0:
            continue
        # HostName is opaque<1..2^16-1>, so a zero-length name is malformed.
        if not name:
            raise ProtocolError('invalid server_name extension')
        try:
            return name.decode('utf-8')
        except UnicodeDecodeError as exc:
            # The bytes come from the peer: report a protocol error instead of
            # leaking UnicodeDecodeError to callers expecting ProtocolError.
            raise ProtocolError('invalid server_name extension') from exc
    raise ProtocolError('server_name extension does not contain a host_name entry')
394
+
395
+
396
+
397
def encode_supported_versions_client(versions: Sequence[int]) -> bytes:
    """Encode the ClientHello supported_versions list (u8-length vector of uint16)."""
    encoded = bytearray()
    for version in versions:
        encoded += version.to_bytes(2, 'big')
    return _u8_vector(bytes(encoded))
400
+
401
+
402
+
403
def decode_supported_versions_client(data: bytes) -> tuple[int, ...]:
    """Decode the ClientHello supported_versions list into a tuple of versions.

    Raises:
        ProtocolError: on trailing bytes or an odd-length list.
    """
    body, end = _read_u8_vector(data, 0)
    if len(body) % 2 or end != len(data):
        raise ProtocolError('invalid supported_versions extension')
    # Pair even-indexed (high) and odd-indexed (low) bytes into uint16 values.
    return tuple((high << 8) | low for high, low in zip(body[0::2], body[1::2]))
408
+
409
+
410
+
411
def encode_supported_versions_server(version: int) -> bytes:
    """Encode the selected_version carried by ServerHello/HelloRetryRequest."""
    return int.to_bytes(version, 2, byteorder='big')
413
+
414
+
415
+
416
def decode_supported_versions_server(data: bytes) -> int:
    """Decode a 2-byte selected_version.

    Raises:
        ProtocolError: if *data* is not exactly two bytes.
    """
    if len(data) == 2:
        return int.from_bytes(data, byteorder='big')
    raise ProtocolError('invalid selected supported_versions extension')
420
+
421
+
422
+
423
def encode_supported_groups(groups: Sequence[int]) -> bytes:
    """Encode a supported_groups list (u16-length vector of uint16 group IDs)."""
    encoded = bytearray()
    for group in groups:
        encoded += group.to_bytes(2, 'big')
    return _u16_vector(bytes(encoded))
426
+
427
+
428
+
429
def decode_supported_groups(data: bytes) -> tuple[int, ...]:
    """Decode a supported_groups list into a tuple of group IDs.

    Raises:
        ProtocolError: on trailing bytes or an odd-length list.
    """
    body, end = _read_u16_vector(data, 0)
    if len(body) % 2 or end != len(data):
        raise ProtocolError('invalid supported_groups extension')
    return tuple((high << 8) | low for high, low in zip(body[0::2], body[1::2]))
434
+
435
+
436
+
437
def encode_signature_algorithms(schemes: Sequence[int]) -> bytes:
    """Encode a signature_algorithms(_cert) list (u16-length vector of uint16)."""
    encoded = bytearray()
    for scheme in schemes:
        encoded += scheme.to_bytes(2, 'big')
    return _u16_vector(bytes(encoded))
440
+
441
+
442
+
443
def decode_signature_algorithms(data: bytes) -> tuple[int, ...]:
    """Decode a signature_algorithms(_cert) list into a tuple of scheme IDs.

    Raises:
        ProtocolError: on trailing bytes or an odd-length list.
    """
    body, end = _read_u16_vector(data, 0)
    if len(body) % 2 or end != len(data):
        raise ProtocolError('invalid signature_algorithms extension')
    return tuple((high << 8) | low for high, low in zip(body[0::2], body[1::2]))
448
+
449
+
450
+
451
def encode_alpn(protocols: Sequence[str]) -> bytes:
    """Encode an ALPN protocol_name_list from ASCII protocol names."""
    entries = [_u8_vector(name.encode('ascii')) for name in protocols]
    return _u16_vector(b''.join(entries))
457
+
458
+
459
+
460
def decode_alpn(data: bytes) -> tuple[str, ...]:
    """Decode an ALPN protocol_name_list into a tuple of protocol names.

    Raises:
        ProtocolError: on trailing bytes, an empty list, an empty protocol
            name, or a name that is not ASCII.
    """
    payload, offset = _read_u16_vector(data, 0)
    if offset != len(data):
        raise ProtocolError('invalid ALPN extension')
    inner = 0
    protocols: list[str] = []
    while inner < len(payload):
        raw, inner = _read_u8_vector(payload, inner)
        # RFC 7301 §3.1: "Empty strings MUST NOT be included".
        if not raw:
            raise ProtocolError('invalid ALPN extension')
        try:
            protocols.append(raw.decode('ascii'))
        except UnicodeDecodeError as exc:
            # Peer-controlled bytes: surface as a protocol error, not a crash.
            raise ProtocolError('invalid ALPN extension') from exc
    if not protocols:
        raise ProtocolError('ALPN extension is empty')
    return tuple(protocols)
472
+
473
+
474
+
475
def encode_psk_key_exchange_modes(modes: Sequence[int]) -> bytes:
    """Encode psk_key_exchange_modes as a u8-length vector of one-byte modes."""
    mode_bytes = bytes(modes)
    return _u8_vector(mode_bytes)
477
+
478
+
479
+
480
def decode_psk_key_exchange_modes(data: bytes) -> tuple[int, ...]:
    """Decode psk_key_exchange_modes into a tuple of mode values.

    Raises:
        ProtocolError: on trailing bytes after the vector.
    """
    body, end = _read_u8_vector(data, 0)
    if end == len(data):
        return tuple(body)
    raise ProtocolError('invalid psk_key_exchange_modes extension')
485
+
486
+
487
+
488
def encode_keyshare_client(shares: Sequence[tuple[int, bytes]]) -> bytes:
    """Encode ClientHello key_share entries given as ``(group, key_exchange)`` pairs."""
    entries = b''.join(
        group.to_bytes(2, 'big') + _u16_vector(key_exchange)
        for group, key_exchange in shares
    )
    return _u16_vector(entries)
494
+
495
+
496
+
497
def decode_keyshare_client(data: bytes) -> dict[int, bytes]:
    """Decode ClientHello key_share entries into ``{group: key_exchange}``.

    Raises:
        ProtocolError: on malformed vectors, an empty key_exchange, or a
            group that appears more than once.
    """
    payload, offset = _read_u16_vector(data, 0)
    if offset != len(data):
        raise ProtocolError('invalid key_share extension')
    inner = 0
    shares: dict[int, bytes] = {}
    while inner < len(payload):
        group, inner = _read_u16(payload, inner)
        key_exchange, inner = _read_u16_vector(payload, inner)
        # KeyShareEntry.key_exchange is opaque<1..2^16-1>.
        if not key_exchange:
            raise ProtocolError('invalid key_share extension')
        # RFC 8446 §4.2.8 forbids duplicate groups.  Overwriting the earlier
        # entry would silently hide the violation.
        if group in shares:
            raise ProtocolError('duplicate key_share group')
        shares[group] = key_exchange
    return shares
508
+
509
+
510
+
511
def encode_keyshare_server(group: int, key_exchange: bytes) -> bytes:
    """Encode the ServerHello key_share: group ID followed by its public key vector."""
    parts = (group.to_bytes(2, 'big'), _u16_vector(key_exchange))
    return b''.join(parts)
513
+
514
+
515
+
516
def decode_keyshare_server(data: bytes) -> tuple[int, bytes]:
    """Decode a ServerHello key_share into ``(group, key_exchange)``.

    Raises:
        ProtocolError: on truncation or trailing bytes.
    """
    selected_group, cursor = _read_u16(data, 0)
    public_key, cursor = _read_u16_vector(data, cursor)
    if cursor == len(data):
        return selected_group, public_key
    raise ProtocolError('invalid server key_share extension')
522
+
523
+
524
+
525
def encode_keyshare_hrr(selected_group: int) -> bytes:
    """Encode the HelloRetryRequest key_share (just the selected group ID)."""
    return int.to_bytes(selected_group, 2, byteorder='big')
527
+
528
+
529
+
530
def decode_keyshare_hrr(data: bytes) -> int:
    """Decode the selected group from a HelloRetryRequest key_share.

    Raises:
        ProtocolError: if *data* is not exactly two bytes.
    """
    if len(data) == 2:
        return int.from_bytes(data, byteorder='big')
    raise ProtocolError('invalid HelloRetryRequest key_share extension')
534
+
535
+
536
+
537
def encode_cookie(cookie: bytes) -> bytes:
    """Encode a cookie extension body (u16-length-prefixed opaque cookie)."""
    encoded = _u16_vector(cookie)
    return encoded
539
+
540
+
541
+
542
def decode_cookie(data: bytes) -> bytes:
    """Decode a cookie extension body.

    Raises:
        ProtocolError: on truncation or trailing bytes.
    """
    body, end = _read_u16_vector(data, 0)
    if end == len(data):
        return body
    raise ProtocolError('invalid cookie extension')
547
+
548
+
549
+
550
def encode_early_data(message_context: str, max_early_data_size: int = QUIC_EARLY_DATA_SENTINEL) -> bytes:
    """Encode the early_data extension body for *message_context*.

    Empty in ClientHello/EncryptedExtensions.  In NewSessionTicket it is a
    4-byte max_early_data_size.

    Raises:
        ValueError: for any other message context.
    """
    if message_context == 'new_session_ticket':
        return max_early_data_size.to_bytes(4, 'big')
    if message_context in ('client_hello', 'encrypted_extensions'):
        return b''
    raise ValueError(f'unsupported early_data context: {message_context}')
556
+
557
+
558
+
559
def decode_early_data(data: bytes, message_context: str) -> object:
    """Decode an early_data extension body for *message_context*.

    Returns True for the empty ClientHello/EncryptedExtensions form, the
    max_early_data_size int for NewSessionTicket, and the raw bytes for any
    other context.

    Raises:
        ProtocolError: if the body has the wrong size for its context.
    """
    if message_context == 'new_session_ticket':
        if len(data) == 4:
            return int.from_bytes(data, 'big')
        raise ProtocolError('invalid early_data NewSessionTicket extension')
    if message_context not in ('client_hello', 'encrypted_extensions'):
        return data
    if data:
        raise ProtocolError('early_data extension must be empty in this context')
    return True
569
+
570
+
571
+
572
def encode_pre_shared_key_client(identities: Sequence[PskIdentity], binders: Sequence[bytes]) -> bytes:
    """Encode the ClientHello pre_shared_key extension (identities then binders).

    Raises:
        ValueError: if the identity and binder counts differ.
    """
    if len(binders) != len(identities):
        raise ValueError('PSK identities and binders must have matching counts')
    identity_parts: list[bytes] = []
    binder_parts: list[bytes] = []
    for entry, binder in zip(identities, binders):
        identity_parts.append(_u16_vector(entry.identity))
        identity_parts.append(entry.obfuscated_ticket_age.to_bytes(4, 'big'))
        binder_parts.append(_u8_vector(binder))
    return _u16_vector(b''.join(identity_parts)) + _u16_vector(b''.join(binder_parts))
582
+
583
+
584
+
585
def encode_pre_shared_key_client_without_binders(identities: Sequence[PskIdentity]) -> bytes:
    """Encode only the identities vector of a ClientHello pre_shared_key extension."""
    encoded = b''.join(
        _u16_vector(entry.identity) + entry.obfuscated_ticket_age.to_bytes(4, 'big')
        for entry in identities
    )
    return _u16_vector(encoded)
591
+
592
+
593
+
594
def decode_pre_shared_key_client(data: bytes) -> OfferedPsks:
    """Decode a ClientHello pre_shared_key extension.

    Raises:
        ProtocolError: on malformed vectors, an empty identity list, or
            mismatched identity/binder counts.
    """
    identities_raw, offset = _read_u16_vector(data, 0)
    binders_raw, offset = _read_u16_vector(data, offset)
    if offset != len(data):
        raise ProtocolError('invalid pre_shared_key extension')
    # RFC 8446 §4.2.11: identities<7..2^16-1> must hold at least one entry.
    # Without this check an empty offer would decode to empty tuples.
    if not identities_raw:
        raise ProtocolError('invalid pre_shared_key extension')
    identities: list[PskIdentity] = []
    inner = 0
    while inner < len(identities_raw):
        identity, inner = _read_u16_vector(identities_raw, inner)
        obfuscated_ticket_age, inner = _read_u32(identities_raw, inner)
        identities.append(PskIdentity(identity=identity, obfuscated_ticket_age=obfuscated_ticket_age))
    binders: list[bytes] = []
    inner = 0
    while inner < len(binders_raw):
        binder, inner = _read_u8_vector(binders_raw, inner)
        binders.append(binder)
    if len(identities) != len(binders):
        raise ProtocolError('mismatched PSK identities and binders')
    return OfferedPsks(identities=tuple(identities), binders=tuple(binders))
613
+
614
+
615
+
616
def encode_pre_shared_key_server(selected_identity: int) -> bytes:
    """Encode the ServerHello pre_shared_key (2-byte selected identity index)."""
    return int.to_bytes(selected_identity, 2, byteorder='big')
618
+
619
+
620
+
621
def decode_pre_shared_key_server(data: bytes) -> int:
    """Decode the ServerHello selected PSK identity index.

    Raises:
        ProtocolError: if *data* is not exactly two bytes.
    """
    if len(data) == 2:
        return int.from_bytes(data, byteorder='big')
    raise ProtocolError('invalid server pre_shared_key extension')
625
+
626
+
627
+
628
def _read_u32(data: bytes, offset: int) -> tuple[int, int]:
    """Read a big-endian uint32 at *offset*; return ``(value, next_offset)``."""
    chunk, next_offset = _read_exact(data, offset, 4)
    value = int.from_bytes(chunk, byteorder='big')
    return value, next_offset
631
+
632
+
633
+
634
def encode_quic_transport_parameters(parameters: TransportParameters) -> bytes:
    """Encode the quic_transport_parameters extension body (delegates to ``to_bytes``)."""
    encoded = parameters.to_bytes()
    return encoded
636
+
637
+
638
+
639
def decode_quic_transport_parameters(data: bytes) -> TransportParameters:
    """Decode a quic_transport_parameters extension body via ``TransportParameters.from_bytes``."""
    decoded = TransportParameters.from_bytes(data)
    return decoded
641
+
642
+
643
+
644
def encode_extensions(extensions: Sequence[TlsExtension], *, message_context: str) -> bytes:
    """Encode an extensions block (u16 vector of type/length/data entries).

    An extension's ``raw_data`` is emitted verbatim when present.  Otherwise its
    ``value`` is encoded for *message_context*.
    """
    chunks: list[bytes] = []
    for extension in extensions:
        body = extension.raw_data
        if body is None:
            body = encode_extension_value(extension.extension_type, extension.value, message_context=message_context)
        chunks.append(int(extension.extension_type).to_bytes(2, 'big') + len(body).to_bytes(2, 'big'))
        chunks.append(body)
    return _u16_vector(b''.join(chunks))
654
+
655
+
656
+
657
def decode_extensions(data: bytes, *, message_context: str) -> tuple[TlsExtension, ...]:
    """Decode an extensions block into TlsExtension items, keeping wire order.

    Raises:
        ProtocolError: on malformed vectors, a repeated extension type, or (in
            a ClientHello) a pre_shared_key extension that is not last.
    """
    payload, offset = _read_u16_vector(data, 0)
    if offset != len(data):
        raise ProtocolError('invalid TLS extensions vector')
    inner = 0
    seen: set[int] = set()
    items: list[TlsExtension] = []
    while inner < len(payload):
        extension_type, inner = _read_u16(payload, inner)
        # RFC 8446 §4.2: at most one extension of each type per block.  Later
        # dict-based lookups (extension_dict) would otherwise drop one silently.
        if extension_type in seen:
            raise ProtocolError(f'duplicate TLS extension type {extension_type}')
        seen.add(extension_type)
        extension_data, inner = _read_u16_vector(payload, inner)
        value = decode_extension_value(extension_type, extension_data, message_context=message_context)
        items.append(TlsExtension(extension_type=extension_type, value=value, raw_data=extension_data))
    # RFC 8446 §4.2.11: pre_shared_key MUST be the last ClientHello extension.
    # Binder computation depends on this ordering.
    if (
        message_context == 'client_hello'
        and ExtensionType.PRE_SHARED_KEY in seen
        and items[-1].extension_type != ExtensionType.PRE_SHARED_KEY
    ):
        raise ProtocolError('pre_shared_key must be the last ClientHello extension')
    return tuple(items)
669
+
670
+
671
+
672
def encode_extension_value(extension_type: int, value: object, *, message_context: str) -> bytes:
    """Encode a decoded extension *value* into its extension_data bytes.

    The wire layout for supported_versions, key_share, early_data and
    pre_shared_key depends on *message_context*.  Unknown extension types are
    accepted only when *value* is already bytes.

    Raises:
        TypeError: if server_name, ClientHello pre_shared_key, or
            quic_transport_parameters is given a value of the wrong type.
        ProtocolError: for an unknown extension type whose value is not bytes.
    """
    # Same lookup style as decode_extension_value; avoids rebuilding a set of
    # enum values on every call.
    try:
        ext: ExtensionType | None = ExtensionType(extension_type)
    except ValueError:
        ext = None
    if ext is ExtensionType.SERVER_NAME:
        # Explicit check instead of `assert`, which is stripped under -O.
        if not isinstance(value, str):
            raise TypeError('server_name extension value must be a str')
        return encode_server_name(value)
    if ext is ExtensionType.SUPPORTED_VERSIONS:
        if message_context == 'client_hello':
            return encode_supported_versions_client(tuple(int(item) for item in value))
        return encode_supported_versions_server(int(value))
    if ext is ExtensionType.SUPPORTED_GROUPS:
        return encode_supported_groups(tuple(int(item) for item in value))
    if ext in (ExtensionType.SIGNATURE_ALGORITHMS, ExtensionType.SIGNATURE_ALGORITHMS_CERT):
        return encode_signature_algorithms(tuple(int(item) for item in value))
    if ext is ExtensionType.ALPN:
        if isinstance(value, str):
            return encode_alpn((value,))
        return encode_alpn(tuple(str(item) for item in value))
    if ext is ExtensionType.PSK_KEY_EXCHANGE_MODES:
        return encode_psk_key_exchange_modes(tuple(int(item) for item in value))
    if ext is ExtensionType.KEY_SHARE:
        if message_context == 'client_hello':
            return encode_keyshare_client(tuple((int(group), bytes(key_exchange)) for group, key_exchange in value))
        if message_context == 'hello_retry_request':
            return encode_keyshare_hrr(int(value))
        group, key_exchange = value
        return encode_keyshare_server(int(group), bytes(key_exchange))
    if ext is ExtensionType.COOKIE:
        return encode_cookie(bytes(value))
    if ext is ExtensionType.EARLY_DATA:
        size = QUIC_EARLY_DATA_SENTINEL if value is True else int(value)
        return encode_early_data(message_context, size)
    if ext is ExtensionType.PRE_SHARED_KEY:
        if message_context == 'client_hello':
            if not isinstance(value, OfferedPsks):
                raise TypeError('ClientHello pre_shared_key extension value must be OfferedPsks')
            return encode_pre_shared_key_client(value.identities, value.binders)
        return encode_pre_shared_key_server(int(value))
    if ext is ExtensionType.QUIC_TRANSPORT_PARAMETERS:
        if not isinstance(value, TransportParameters):
            raise TypeError('quic_transport_parameters extension value must be TransportParameters')
        return encode_quic_transport_parameters(value)
    if isinstance(value, bytes):
        return value
    raise ProtocolError(f'unsupported TLS extension type {extension_type}')
715
+
716
+
717
+
718
def decode_extension_value(extension_type: int, data: bytes, *, message_context: str) -> object:
    """Decode extension_data for *extension_type* in *message_context*.

    Unknown extension types, and known types without a dedicated decoder, are
    returned as raw bytes.

    Raises:
        ProtocolError: if the data is malformed for its type/context, including
            a non-ClientHello ALPN list that does not hold exactly one protocol.
    """
    try:
        ext = ExtensionType(extension_type)
    except ValueError:
        return data
    if ext is ExtensionType.SERVER_NAME:
        return decode_server_name(data)
    if ext is ExtensionType.SUPPORTED_VERSIONS:
        if message_context == 'client_hello':
            return decode_supported_versions_client(data)
        return decode_supported_versions_server(data)
    if ext is ExtensionType.SUPPORTED_GROUPS:
        return decode_supported_groups(data)
    if ext in (ExtensionType.SIGNATURE_ALGORITHMS, ExtensionType.SIGNATURE_ALGORITHMS_CERT):
        return decode_signature_algorithms(data)
    if ext is ExtensionType.ALPN:
        protocols = decode_alpn(data)
        if message_context == 'client_hello':
            return protocols
        # RFC 7301 §3.1: the server's reply carries exactly one protocol.
        # Silently picking the first one would mask a misbehaving peer.
        if len(protocols) != 1:
            raise ProtocolError('ALPN response must contain exactly one protocol')
        return protocols[0]
    if ext is ExtensionType.PSK_KEY_EXCHANGE_MODES:
        return decode_psk_key_exchange_modes(data)
    if ext is ExtensionType.KEY_SHARE:
        if message_context == 'client_hello':
            return decode_keyshare_client(data)
        if message_context == 'hello_retry_request':
            return decode_keyshare_hrr(data)
        return decode_keyshare_server(data)
    if ext is ExtensionType.COOKIE:
        return decode_cookie(data)
    if ext is ExtensionType.EARLY_DATA:
        return decode_early_data(data, message_context)
    if ext is ExtensionType.PRE_SHARED_KEY:
        if message_context == 'client_hello':
            return decode_pre_shared_key_client(data)
        return decode_pre_shared_key_server(data)
    if ext is ExtensionType.QUIC_TRANSPORT_PARAMETERS:
        return decode_quic_transport_parameters(data)
    return data
755
+
756
+
757
+
758
def extension_dict(extensions: Iterable[TlsExtension]) -> dict[int, object]:
    """Map each extension's integer type to its decoded value (last one wins)."""
    by_type: dict[int, object] = {}
    for extension in extensions:
        by_type[int(extension.extension_type)] = extension.value
    return by_type