@ajksunkang-aios/kgraph-linux-x64 0.1.2 → 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. package/bin/kgraph-launcher +15 -3
  2. package/lib/kgraph/scripts/build-bundle.sh +17 -4
  3. package/lib/site-packages/outcome/__init__.py +20 -0
  4. package/lib/site-packages/outcome/_impl.py +239 -0
  5. package/lib/site-packages/outcome/_util.py +33 -0
  6. package/lib/site-packages/outcome/_version.py +7 -0
  7. package/lib/site-packages/outcome/py.typed +0 -0
  8. package/lib/site-packages/outcome-1.3.0.post0.dist-info/INSTALLER +1 -0
  9. package/lib/site-packages/outcome-1.3.0.post0.dist-info/LICENSE +3 -0
  10. package/lib/site-packages/outcome-1.3.0.post0.dist-info/LICENSE.APACHE2 +202 -0
  11. package/lib/site-packages/outcome-1.3.0.post0.dist-info/LICENSE.MIT +20 -0
  12. package/lib/site-packages/outcome-1.3.0.post0.dist-info/METADATA +63 -0
  13. package/lib/site-packages/outcome-1.3.0.post0.dist-info/RECORD +13 -0
  14. package/lib/site-packages/outcome-1.3.0.post0.dist-info/WHEEL +6 -0
  15. package/lib/site-packages/outcome-1.3.0.post0.dist-info/top_level.txt +1 -0
  16. package/lib/site-packages/sniffio/__init__.py +17 -0
  17. package/lib/site-packages/sniffio/_impl.py +95 -0
  18. package/lib/site-packages/sniffio/_tests/__init__.py +0 -0
  19. package/lib/site-packages/sniffio/_tests/test_sniffio.py +84 -0
  20. package/lib/site-packages/sniffio/_version.py +3 -0
  21. package/lib/site-packages/sniffio/py.typed +0 -0
  22. package/lib/site-packages/sniffio-1.3.1.dist-info/INSTALLER +1 -0
  23. package/lib/site-packages/sniffio-1.3.1.dist-info/LICENSE +3 -0
  24. package/lib/site-packages/sniffio-1.3.1.dist-info/LICENSE.APACHE2 +202 -0
  25. package/lib/site-packages/sniffio-1.3.1.dist-info/LICENSE.MIT +20 -0
  26. package/lib/site-packages/sniffio-1.3.1.dist-info/METADATA +104 -0
  27. package/lib/site-packages/sniffio-1.3.1.dist-info/RECORD +14 -0
  28. package/lib/site-packages/sniffio-1.3.1.dist-info/WHEEL +5 -0
  29. package/lib/site-packages/sniffio-1.3.1.dist-info/top_level.txt +1 -0
  30. package/lib/site-packages/sortedcontainers/__init__.py +74 -0
  31. package/lib/site-packages/sortedcontainers/sorteddict.py +812 -0
  32. package/lib/site-packages/sortedcontainers/sortedlist.py +2646 -0
  33. package/lib/site-packages/sortedcontainers/sortedset.py +733 -0
  34. package/lib/site-packages/sortedcontainers-2.4.0.dist-info/INSTALLER +1 -0
  35. package/lib/site-packages/sortedcontainers-2.4.0.dist-info/LICENSE +13 -0
  36. package/lib/site-packages/sortedcontainers-2.4.0.dist-info/METADATA +264 -0
  37. package/lib/site-packages/sortedcontainers-2.4.0.dist-info/RECORD +10 -0
  38. package/lib/site-packages/sortedcontainers-2.4.0.dist-info/WHEEL +6 -0
  39. package/lib/site-packages/sortedcontainers-2.4.0.dist-info/top_level.txt +1 -0
  40. package/lib/site-packages/trio/__init__.py +133 -0
  41. package/lib/site-packages/trio/__main__.py +3 -0
  42. package/lib/site-packages/trio/_abc.py +714 -0
  43. package/lib/site-packages/trio/_channel.py +610 -0
  44. package/lib/site-packages/trio/_core/__init__.py +94 -0
  45. package/lib/site-packages/trio/_core/_asyncgens.py +243 -0
  46. package/lib/site-packages/trio/_core/_concat_tb.py +26 -0
  47. package/lib/site-packages/trio/_core/_entry_queue.py +223 -0
  48. package/lib/site-packages/trio/_core/_exceptions.py +169 -0
  49. package/lib/site-packages/trio/_core/_generated_instrumentation.py +50 -0
  50. package/lib/site-packages/trio/_core/_generated_io_epoll.py +98 -0
  51. package/lib/site-packages/trio/_core/_generated_io_kqueue.py +153 -0
  52. package/lib/site-packages/trio/_core/_generated_io_windows.py +204 -0
  53. package/lib/site-packages/trio/_core/_generated_run.py +269 -0
  54. package/lib/site-packages/trio/_core/_generated_windows_ffi.py +10 -0
  55. package/lib/site-packages/trio/_core/_instrumentation.py +117 -0
  56. package/lib/site-packages/trio/_core/_io_common.py +31 -0
  57. package/lib/site-packages/trio/_core/_io_epoll.py +385 -0
  58. package/lib/site-packages/trio/_core/_io_kqueue.py +292 -0
  59. package/lib/site-packages/trio/_core/_io_windows.py +1036 -0
  60. package/lib/site-packages/trio/_core/_ki.py +271 -0
  61. package/lib/site-packages/trio/_core/_local.py +104 -0
  62. package/lib/site-packages/trio/_core/_mock_clock.py +165 -0
  63. package/lib/site-packages/trio/_core/_parking_lot.py +317 -0
  64. package/lib/site-packages/trio/_core/_run.py +3148 -0
  65. package/lib/site-packages/trio/_core/_run_context.py +15 -0
  66. package/lib/site-packages/trio/_core/_tests/__init__.py +0 -0
  67. package/lib/site-packages/trio/_core/_tests/test_asyncgen.py +339 -0
  68. package/lib/site-packages/trio/_core/_tests/test_cancelled.py +222 -0
  69. package/lib/site-packages/trio/_core/_tests/test_exceptiongroup_gc.py +103 -0
  70. package/lib/site-packages/trio/_core/_tests/test_guest_mode.py +755 -0
  71. package/lib/site-packages/trio/_core/_tests/test_instrumentation.py +315 -0
  72. package/lib/site-packages/trio/_core/_tests/test_io.py +522 -0
  73. package/lib/site-packages/trio/_core/_tests/test_ki.py +703 -0
  74. package/lib/site-packages/trio/_core/_tests/test_local.py +118 -0
  75. package/lib/site-packages/trio/_core/_tests/test_mock_clock.py +193 -0
  76. package/lib/site-packages/trio/_core/_tests/test_parking_lot.py +389 -0
  77. package/lib/site-packages/trio/_core/_tests/test_run.py +3024 -0
  78. package/lib/site-packages/trio/_core/_tests/test_thread_cache.py +227 -0
  79. package/lib/site-packages/trio/_core/_tests/test_tutil.py +13 -0
  80. package/lib/site-packages/trio/_core/_tests/test_unbounded_queue.py +154 -0
  81. package/lib/site-packages/trio/_core/_tests/test_windows.py +305 -0
  82. package/lib/site-packages/trio/_core/_tests/tutil.py +117 -0
  83. package/lib/site-packages/trio/_core/_tests/type_tests/nursery_start.py +79 -0
  84. package/lib/site-packages/trio/_core/_tests/type_tests/run.py +51 -0
  85. package/lib/site-packages/trio/_core/_thread_cache.py +317 -0
  86. package/lib/site-packages/trio/_core/_traps.py +318 -0
  87. package/lib/site-packages/trio/_core/_unbounded_queue.py +163 -0
  88. package/lib/site-packages/trio/_core/_wakeup_socketpair.py +75 -0
  89. package/lib/site-packages/trio/_core/_windows_cffi.py +313 -0
  90. package/lib/site-packages/trio/_deprecate.py +171 -0
  91. package/lib/site-packages/trio/_dtls.py +1380 -0
  92. package/lib/site-packages/trio/_file_io.py +513 -0
  93. package/lib/site-packages/trio/_highlevel_generic.py +125 -0
  94. package/lib/site-packages/trio/_highlevel_open_tcp_listeners.py +251 -0
  95. package/lib/site-packages/trio/_highlevel_open_tcp_stream.py +397 -0
  96. package/lib/site-packages/trio/_highlevel_open_unix_stream.py +65 -0
  97. package/lib/site-packages/trio/_highlevel_serve_listeners.py +148 -0
  98. package/lib/site-packages/trio/_highlevel_socket.py +423 -0
  99. package/lib/site-packages/trio/_highlevel_ssl_helpers.py +180 -0
  100. package/lib/site-packages/trio/_path.py +289 -0
  101. package/lib/site-packages/trio/_repl.py +159 -0
  102. package/lib/site-packages/trio/_signals.py +185 -0
  103. package/lib/site-packages/trio/_socket.py +1326 -0
  104. package/lib/site-packages/trio/_ssl.py +964 -0
  105. package/lib/site-packages/trio/_subprocess.py +1178 -0
  106. package/lib/site-packages/trio/_subprocess_platform/__init__.py +123 -0
  107. package/lib/site-packages/trio/_subprocess_platform/kqueue.py +48 -0
  108. package/lib/site-packages/trio/_subprocess_platform/waitid.py +113 -0
  109. package/lib/site-packages/trio/_subprocess_platform/windows.py +11 -0
  110. package/lib/site-packages/trio/_sync.py +908 -0
  111. package/lib/site-packages/trio/_tests/__init__.py +0 -0
  112. package/lib/site-packages/trio/_tests/astrill-codesigning-cert.cer +0 -0
  113. package/lib/site-packages/trio/_tests/check_type_completeness.py +247 -0
  114. package/lib/site-packages/trio/_tests/module_with_deprecations.py +22 -0
  115. package/lib/site-packages/trio/_tests/pytest_plugin.py +54 -0
  116. package/lib/site-packages/trio/_tests/test_abc.py +72 -0
  117. package/lib/site-packages/trio/_tests/test_channel.py +750 -0
  118. package/lib/site-packages/trio/_tests/test_contextvars.py +56 -0
  119. package/lib/site-packages/trio/_tests/test_deprecate.py +277 -0
  120. package/lib/site-packages/trio/_tests/test_deprecate_strict_exception_groups_false.py +64 -0
  121. package/lib/site-packages/trio/_tests/test_dtls.py +950 -0
  122. package/lib/site-packages/trio/_tests/test_exports.py +626 -0
  123. package/lib/site-packages/trio/_tests/test_fakenet.py +317 -0
  124. package/lib/site-packages/trio/_tests/test_file_io.py +269 -0
  125. package/lib/site-packages/trio/_tests/test_highlevel_generic.py +98 -0
  126. package/lib/site-packages/trio/_tests/test_highlevel_open_tcp_listeners.py +419 -0
  127. package/lib/site-packages/trio/_tests/test_highlevel_open_tcp_stream.py +693 -0
  128. package/lib/site-packages/trio/_tests/test_highlevel_open_unix_stream.py +86 -0
  129. package/lib/site-packages/trio/_tests/test_highlevel_serve_listeners.py +186 -0
  130. package/lib/site-packages/trio/_tests/test_highlevel_socket.py +336 -0
  131. package/lib/site-packages/trio/_tests/test_highlevel_ssl_helpers.py +169 -0
  132. package/lib/site-packages/trio/_tests/test_path.py +279 -0
  133. package/lib/site-packages/trio/_tests/test_repl.py +428 -0
  134. package/lib/site-packages/trio/_tests/test_scheduler_determinism.py +47 -0
  135. package/lib/site-packages/trio/_tests/test_signals.py +186 -0
  136. package/lib/site-packages/trio/_tests/test_socket.py +1253 -0
  137. package/lib/site-packages/trio/_tests/test_ssl.py +1371 -0
  138. package/lib/site-packages/trio/_tests/test_subprocess.py +767 -0
  139. package/lib/site-packages/trio/_tests/test_sync.py +735 -0
  140. package/lib/site-packages/trio/_tests/test_testing.py +682 -0
  141. package/lib/site-packages/trio/_tests/test_testing_raisesgroup.py +1128 -0
  142. package/lib/site-packages/trio/_tests/test_threads.py +1173 -0
  143. package/lib/site-packages/trio/_tests/test_timeouts.py +281 -0
  144. package/lib/site-packages/trio/_tests/test_tracing.py +88 -0
  145. package/lib/site-packages/trio/_tests/test_trio.py +8 -0
  146. package/lib/site-packages/trio/_tests/test_unix_pipes.py +288 -0
  147. package/lib/site-packages/trio/_tests/test_util.py +349 -0
  148. package/lib/site-packages/trio/_tests/test_wait_for_object.py +225 -0
  149. package/lib/site-packages/trio/_tests/test_windows_pipes.py +112 -0
  150. package/lib/site-packages/trio/_tests/tools/__init__.py +0 -0
  151. package/lib/site-packages/trio/_tests/tools/test_gen_exports.py +179 -0
  152. package/lib/site-packages/trio/_tests/tools/test_mypy_annotate.py +140 -0
  153. package/lib/site-packages/trio/_tests/tools/test_sync_requirements.py +80 -0
  154. package/lib/site-packages/trio/_tests/type_tests/check_wraps.py +9 -0
  155. package/lib/site-packages/trio/_tests/type_tests/open_memory_channel.py +4 -0
  156. package/lib/site-packages/trio/_tests/type_tests/path.py +140 -0
  157. package/lib/site-packages/trio/_tests/type_tests/subprocesses.py +23 -0
  158. package/lib/site-packages/trio/_tests/type_tests/task_status.py +29 -0
  159. package/lib/site-packages/trio/_threads.py +610 -0
  160. package/lib/site-packages/trio/_timeouts.py +197 -0
  161. package/lib/site-packages/trio/_tools/__init__.py +0 -0
  162. package/lib/site-packages/trio/_tools/gen_exports.py +401 -0
  163. package/lib/site-packages/trio/_tools/mypy_annotate.py +126 -0
  164. package/lib/site-packages/trio/_tools/sync_requirements.py +98 -0
  165. package/lib/site-packages/trio/_tools/windows_ffi_build.py +220 -0
  166. package/lib/site-packages/trio/_unix_pipes.py +197 -0
  167. package/lib/site-packages/trio/_util.py +385 -0
  168. package/lib/site-packages/trio/_version.py +3 -0
  169. package/lib/site-packages/trio/_wait_for_object.py +67 -0
  170. package/lib/site-packages/trio/_windows_pipes.py +144 -0
  171. package/lib/site-packages/trio/abc.py +23 -0
  172. package/lib/site-packages/trio/from_thread.py +13 -0
  173. package/lib/site-packages/trio/lowlevel.py +95 -0
  174. package/lib/site-packages/trio/py.typed +0 -0
  175. package/lib/site-packages/trio/socket.py +602 -0
  176. package/lib/site-packages/trio/testing/__init__.py +58 -0
  177. package/lib/site-packages/trio/testing/_check_streams.py +570 -0
  178. package/lib/site-packages/trio/testing/_checkpoints.py +69 -0
  179. package/lib/site-packages/trio/testing/_fake_net.py +584 -0
  180. package/lib/site-packages/trio/testing/_memory_streams.py +633 -0
  181. package/lib/site-packages/trio/testing/_network.py +36 -0
  182. package/lib/site-packages/trio/testing/_raises_group.py +1015 -0
  183. package/lib/site-packages/trio/testing/_sequencer.py +87 -0
  184. package/lib/site-packages/trio/testing/_trio_test.py +50 -0
  185. package/lib/site-packages/trio/to_thread.py +4 -0
  186. package/lib/site-packages/trio-0.33.0.dist-info/INSTALLER +1 -0
  187. package/lib/site-packages/trio-0.33.0.dist-info/METADATA +186 -0
  188. package/lib/site-packages/trio-0.33.0.dist-info/RECORD +156 -0
  189. package/lib/site-packages/trio-0.33.0.dist-info/REQUESTED +0 -0
  190. package/lib/site-packages/trio-0.33.0.dist-info/WHEEL +5 -0
  191. package/lib/site-packages/trio-0.33.0.dist-info/entry_points.txt +2 -0
  192. package/lib/site-packages/trio-0.33.0.dist-info/licenses/LICENSE +3 -0
  193. package/lib/site-packages/trio-0.33.0.dist-info/licenses/LICENSE.APACHE2 +202 -0
  194. package/lib/site-packages/trio-0.33.0.dist-info/licenses/LICENSE.MIT +22 -0
  195. package/lib/site-packages/trio-0.33.0.dist-info/top_level.txt +1 -0
  196. package/package.json +1 -1
@@ -0,0 +1,1380 @@
1
+ # Implementation of DTLS 1.2, using pyopenssl
2
+ # https://datatracker.ietf.org/doc/html/rfc6347
3
+ #
4
+ # OpenSSL's APIs for DTLS are extremely awkward and limited, which forces us to jump
5
+ # through a *lot* of hoops and implement important chunks of the protocol ourselves.
6
+ # Hopefully they fix this before implementing DTLS 1.3, because it's a very different
7
+ # protocol, and it's probably impossible to pull tricks like we do here.
8
+
9
+ from __future__ import annotations
10
+
11
+ import contextlib
12
+ import enum
13
+ import errno
14
+ import hmac
15
+ import os
16
+ import struct
17
+ import warnings
18
+ import weakref
19
+ from itertools import count
20
+ from typing import TYPE_CHECKING, Generic, TypeAlias, TypeVar
21
+ from weakref import ReferenceType, WeakValueDictionary
22
+
23
+ import attrs
24
+
25
+ import trio
26
+
27
+ from ._util import NoPublicConstructor, final
28
+
29
+ if TYPE_CHECKING:
30
+ from collections.abc import Awaitable, Callable, Iterable, Iterator
31
+ from types import TracebackType
32
+
33
+ # See DTLSEndpoint.__init__ for why this is imported here
34
+ from OpenSSL import SSL # noqa: TC004
35
+ from typing_extensions import Self, TypeVarTuple, Unpack
36
+
37
+ from trio._socket import AddressFormat
38
+ from trio.socket import SocketType
39
+
40
+ PosArgsT = TypeVarTuple("PosArgsT")
41
+
42
# Upper bound on the size of a single UDP datagram payload. 65527 == 65535 - 8,
# i.e. presumably the 16-bit maximum minus the 8-byte UDP header.
# NOTE(review): over IPv4 the practical limit is 20 bytes smaller; this looks
# like it is meant as a receive-buffer size, where over-estimating is harmless
# -- confirm against callers.
MAX_UDP_PACKET_SIZE = 65527
43
+
44
+
45
def packet_header_overhead(sock: SocketType) -> int:
    """Return how many bytes of per-packet header overhead to budget for *sock*.

    28 bytes when the socket family is AF_INET, otherwise 48 bytes
    (presumably IPv4+UDP and IPv6+UDP headers respectively).
    """
    return 28 if sock.family == trio.socket.AF_INET else 48
50
+
51
+
52
def worst_case_mtu(sock: SocketType) -> int:
    """Return a conservative payload size that should fit in one packet on *sock*.

    Starts from 576 bytes for AF_INET sockets and 1280 bytes for any other
    family, then subtracts packet_header_overhead(sock).
    """
    if sock.family != trio.socket.AF_INET:
        link_minimum = 1280  # TODO: test this line
    else:
        link_minimum = 576
    return link_minimum - packet_header_overhead(sock)
57
+
58
+
59
def best_guess_mtu(sock: SocketType) -> int:
    """Return an optimistic payload size for *sock*: 1500 bytes minus header overhead."""
    overhead = packet_header_overhead(sock)
    return 1500 - overhead
61
+
62
+
63
# There are a bunch of different RFCs that define these codes, so for a
# comprehensive collection look here:
# https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml
class ContentType(enum.IntEnum):
    """Record-layer content type: the first byte of every DTLS record header."""

    change_cipher_spec = 20
    alert = 21
    handshake = 22
    application_data = 23
    heartbeat = 24
72
+
73
+
74
@enum.unique
class HandshakeType(enum.IntEnum):
    """Handshake message type: the first byte of every handshake message header.

    Values follow the IANA "TLS HandshakeType" registry. ``@enum.unique``
    guards against two names accidentally sharing a value. IntEnum would
    otherwise silently turn the second one into an alias of the first, and
    ``HandshakeType(<real value>)`` would then raise ValueError.
    """

    hello_request = 0
    client_hello = 1
    server_hello = 2
    hello_verify_request = 3
    new_session_ticket = 4
    # RFC 8446 assigns 5 (this was previously mis-set to 4, making it an alias
    # of new_session_ticket).
    end_of_early_data = 5
    encrypted_extensions = 8
    certificate = 11
    server_key_exchange = 12
    certificate_request = 13
    server_hello_done = 14
    certificate_verify = 15
    client_key_exchange = 16
    finished = 20
    certificate_url = 21
    certificate_status = 22
    supplemental_data = 23
    key_update = 24
    compressed_certificate = 25
    ekt_key = 26
    message_hash = 254
96
+
97
+
98
class ProtocolVersion:
    """Two-byte wire encodings of the record-layer version field."""

    DTLS10 = b"\xfe\xff"  # == bytes([254, 255])
    DTLS12 = b"\xfe\xfd"  # == bytes([254, 253])
101
+
102
+
103
# Selects the 16-bit epoch in the top two bytes of the 64-bit epoch+seqno value
# (the remaining 48 low bits are the sequence number).
EPOCH_MASK = 0xFFFF_0000_0000_0000
104
+
105
+
106
+ # Conventions:
107
+ # - All functions that handle network data end in _untrusted.
108
+ # - All functions that end in _untrusted MUST make sure that bad data from the
108
+ # network can *only* cause BadPacket to be raised. No IndexError or
110
+ # struct.error or whatever.
111
class BadPacket(Exception):
    """Raised when a packet or record received from the network is malformed."""
113
+
114
+
115
def part_of_handshake_untrusted(packet: bytes) -> bool:
    """Return True iff the first record header in *packet* has epoch 0.

    The DTLS epoch is 0 exactly during the initial handshake. The content type
    is deliberately not checked, because not every handshake-phase record has
    ContentType.handshake; ChangeCipherSpec, for instance, has its own.

    Cannot fail: if *packet* is too short, the slice below comes back short
    and simply compares unequal.
    """
    # Header layout: 1 byte content type, 2 bytes version, then the 2-byte epoch.
    epoch_field = packet[3:5]
    return epoch_field == bytes(2)
125
+
126
+
127
def is_client_hello_untrusted(packet: bytes) -> bool:
    """Cheaply check whether *packet* starts with a ClientHello handshake record.

    Cannot fail: anything too short to contain both bytes we inspect is
    reported as not-a-ClientHello.
    """
    # Byte 0 is the record content type. The record header is 13 bytes, so
    # byte 13 is the type of the first handshake message inside the record.
    if len(packet) <= 13:
        # Invalid DTLS record
        return False
    is_handshake = packet[0] == ContentType.handshake
    return is_handshake and packet[13] == HandshakeType.client_hello
137
+
138
+
139
# DTLS records are:
# - 1 byte content type
# - 2 bytes version
# - 8 bytes epoch+seqno
#   Technically this is 2 bytes epoch then 6 bytes seqno, but we treat it as
#   a single 8-byte integer, where epoch changes are represented as jumping
#   forward by 2**(6*8).
# - 2 bytes payload length (unsigned big-endian)
# - payload
# The header struct below is therefore 1 + 2 + 8 + 2 = 13 bytes long.
RECORD_HEADER = struct.Struct("!B2sQH")
149
+
150
+
151
def to_hex(data: bytes) -> str:  # pragma: no cover
    """Render *data* as lowercase hex; used as the attrs ``repr`` for byte fields."""
    hex_text = data.hex()
    return hex_text
153
+
154
+
155
@attrs.frozen
class Record:
    """One DTLS record, as produced by records_untrusted and consumed by encode_record."""

    # Record-layer content type; records_untrusted stores the raw byte without
    # validating it against ContentType.
    content_type: int
    # 2-byte protocol version field (e.g. ProtocolVersion.DTLS12).
    version: bytes = attrs.field(repr=to_hex)
    # Epoch (top 16 bits) and sequence number (low 48 bits) as one integer.
    epoch_seqno: int
    # Record payload; encode_record derives the length field from len(payload).
    payload: bytes = attrs.field(repr=to_hex)
161
+
162
+
163
def records_untrusted(packet: bytes) -> Iterator[Record]:
    """Yield every DTLS record in *packet*, in order.

    Raises BadPacket if a record header is truncated, or if a record's
    declared payload length runs past the end of *packet*.
    """
    offset = 0
    while offset < len(packet):
        try:
            content_type, version, epoch_seqno, length = RECORD_HEADER.unpack_from(
                packet, offset
            )
        # Marked as no-cover because at time of writing, this code is unreachable
        # (records_untrusted only gets called on packets that are either trusted or that
        # have passed is_client_hello_untrusted, which filters out short packets)
        except struct.error as exc:  # pragma: no cover
            raise BadPacket("invalid record header") from exc
        payload_start = offset + RECORD_HEADER.size
        offset = payload_start + length
        payload = packet[payload_start:offset]
        if len(payload) != length:
            raise BadPacket("short record")
        yield Record(content_type, version, epoch_seqno, payload)
179
+
180
+
181
def encode_record(record: Record) -> bytes:
    """Serialize *record*: the 13-byte record header followed by its payload."""
    payload = record.payload
    header = RECORD_HEADER.pack(
        record.content_type,
        record.version,
        record.epoch_seqno,
        len(payload),
    )
    return header + payload
189
+
190
+
191
# Handshake messages are:
# - 1 byte message type
# - 3 bytes total message length
# - 2 bytes message sequence number
# - 3 bytes fragment offset
# - 3 bytes fragment length
# The 24-bit fields are packed/unpacked as raw 3-byte strings ("3s") because
# struct has no 24-bit integer code; the whole header is 12 bytes long.
HANDSHAKE_MESSAGE_HEADER = struct.Struct("!B3sH3s3s")
198
+
199
+
200
@attrs.frozen
class HandshakeFragment:
    """One fragment of a handshake message, as carried in a single handshake record."""

    # Handshake message type (normally a HandshakeType value; not validated here).
    msg_type: int
    # Length of the complete (reassembled) message body.
    msg_len: int
    # Handshake message sequence number.
    msg_seq: int
    # Offset of this fragment within the complete message body.
    frag_offset: int
    # Number of body bytes carried in this fragment.
    frag_len: int
    # The fragment's body bytes.
    frag: bytes = attrs.field(repr=to_hex)
208
+
209
+
210
def decode_handshake_fragment_untrusted(payload: bytes) -> HandshakeFragment:
    """Parse the single handshake fragment that fills a handshake record's payload.

    Raises BadPacket if the 12-byte fragment header is truncated, or if the
    fragment length in the header disagrees with the number of bytes that
    follow it.
    """
    try:
        header_fields = HANDSHAKE_MESSAGE_HEADER.unpack_from(payload)
    except struct.error as exc:  # TODO: test this line
        raise BadPacket("bad handshake message header") from exc
    msg_type, raw_msg_len, msg_seq, raw_frag_offset, raw_frag_len = header_fields
    # 'struct' doesn't have built-in support for 24-bit integers, so the three
    # length/offset fields arrive as 3-byte strings. Converting them can't fail.
    msg_len, frag_offset, frag_len = (
        int.from_bytes(raw, "big")
        for raw in (raw_msg_len, raw_frag_offset, raw_frag_len)
    )
    frag = payload[HANDSHAKE_MESSAGE_HEADER.size :]
    if len(frag) != frag_len:
        raise BadPacket("handshake fragment length doesn't match record length")
    return HandshakeFragment(
        msg_type=msg_type,
        msg_len=msg_len,
        msg_seq=msg_seq,
        frag_offset=frag_offset,
        frag_len=frag_len,
        frag=frag,
    )
238
+
239
+
240
def encode_handshake_fragment(hsf: HandshakeFragment) -> bytes:
    """Serialize *hsf*: the 12-byte handshake fragment header followed by its bytes."""

    def uint24(value: int) -> bytes:
        # Big-endian 24-bit encoding; OverflowError if value doesn't fit.
        return value.to_bytes(3, "big")

    header = HANDSHAKE_MESSAGE_HEADER.pack(
        hsf.msg_type,
        uint24(hsf.msg_len),
        hsf.msg_seq,
        uint24(hsf.frag_offset),
        uint24(hsf.frag_len),
    )
    return header + hsf.frag
249
+
250
+
251
def decode_client_hello_untrusted(packet: bytes) -> tuple[int, bytes, bytes]:
    """Parse a ClientHello received from the network and split out its cookie.

    Returns a 3-tuple ``(epoch_seqno, cookie, hashable)``:

    - *epoch_seqno*: the record-layer epoch+seqno of the record carrying the
      ClientHello.
    - *cookie*: the ClientHello's cookie field (empty if the client sent none).
    - *hashable*: the ClientHello body with the cookie-length byte and the
      cookie itself removed; this is the data that should be hashed into the
      cookie.

    Raises BadPacket on any malformed input, including an empty packet, as
    required for every ``_untrusted`` function.
    """
    try:
        # ClientHello has to be the first record in the packet. Passing a default
        # to next() keeps an empty packet from escaping as StopIteration instead
        # of BadPacket.
        record = next(records_untrusted(packet), None)
        # no-cover: callers only pass packets that passed is_client_hello_untrusted,
        # which guarantees at least one (possibly bad) record.
        if record is None:  # pragma: no cover
            raise BadPacket("empty packet")
        # no-cover because at time of writing, this is unreachable:
        # decode_client_hello_untrusted is only called on packets that have passed
        # is_client_hello_untrusted, which confirms the content type.
        if record.content_type != ContentType.handshake:  # pragma: no cover
            raise BadPacket("not a handshake record")
        fragment = decode_handshake_fragment_untrusted(record.payload)
        if fragment.msg_type != HandshakeType.client_hello:
            raise BadPacket("not a ClientHello")
        # ClientHello can't be fragmented, because reassembly requires holding
        # per-connection state, and we refuse to allocate per-connection state
        # until after we get a valid ClientHello.
        if fragment.frag_offset != 0:
            raise BadPacket("fragmented ClientHello")
        if fragment.frag_len != fragment.msg_len:
            raise BadPacket("fragmented ClientHello")

        # As per RFC 6347:
        #
        #   When responding to a HelloVerifyRequest, the client MUST use the
        #   same parameter values (version, random, session_id, cipher_suites,
        #   compression_method) as it did in the original ClientHello. The
        #   server SHOULD use those values to generate its cookie and verify that
        #   they are correct upon cookie receipt.
        #
        # However, the record-layer framing can and will change (e.g. the
        # second ClientHello will have a new record-layer sequence number). So
        # we need to pull out the handshake message alone, discarding the
        # record-layer stuff, and then we're going to hash all of it *except*
        # the cookie.

        body = fragment.frag
        # ClientHello is:
        #
        # - 2 bytes client_version
        # - 32 bytes random
        # - 1 byte session_id length
        # - session_id
        # - 1 byte cookie length
        # - cookie
        # - everything else
        #
        # To find the cookie, we need to figure out how long the session_id is
        # and skip past it. Any of these indexes may be out of range on a short
        # body; the IndexError is converted to BadPacket below.
        session_id_len = body[2 + 32]
        cookie_len_offset = 2 + 32 + 1 + session_id_len
        cookie_len = body[cookie_len_offset]

        cookie_start = cookie_len_offset + 1
        cookie_end = cookie_start + cookie_len

        before_cookie = body[:cookie_len_offset]
        cookie = body[cookie_start:cookie_end]
        after_cookie = body[cookie_end:]

        # Slicing never fails, so a truncated cookie shows up as a short slice.
        if len(cookie) != cookie_len:
            raise BadPacket("short cookie")
        return (record.epoch_seqno, cookie, before_cookie + after_cookie)

    except (struct.error, IndexError) as exc:
        raise BadPacket("bad ClientHello") from exc
318
+
319
+
320
@attrs.frozen
class HandshakeMessage:
    """A complete (reassembled) plaintext handshake message from an outgoing volley."""

    # Version field of the record(s) that carried this message.
    record_version: bytes = attrs.field(repr=to_hex)
    msg_type: HandshakeType
    # Handshake message sequence number.
    msg_seq: int
    # Message body. A bytearray so decode_volley_trusted can fill it in
    # fragment by fragment even though the attrs instance itself is frozen.
    body: bytearray = attrs.field(repr=to_hex)
326
+
327
+
328
# ChangeCipherSpec is part of the handshake, but it's not a "handshake
# message" and can't be fragmented the same way. Sigh.
@attrs.frozen
class PseudoHandshakeMessage:
    """A non-fragmentable handshake-phase record (ChangeCipherSpec or alert).

    RecordEncoder.encode_volley re-wraps the payload in a single new record
    with a fresh record sequence number.
    """

    # Version field of the original record.
    record_version: bytes = attrs.field(repr=to_hex)
    # Record content type (decode_volley_trusted uses change_cipher_spec or alert).
    content_type: int
    # The original record's payload, re-sent unchanged.
    payload: bytes = attrs.field(repr=to_hex)
335
+
336
+
337
# The final record in a handshake is Finished, which is encrypted, can't be fragmented
# (at least by us), and keeps its record number (because it's in a new epoch). So we
# just pass it through unchanged. (Fortunately, the payload is only a single hash value,
# so the largest it will ever be is 64 bytes for a 512-bit hash, which is small enough
# that it never requires fragmenting to fit into a UDP packet.)
@attrs.frozen
class OpaqueHandshakeMessage:
    """A record from a nonzero epoch, re-sent byte-for-byte by encode_volley."""

    # The original record, including its original epoch+seqno.
    record: Record
345
+
346
+
347
# Any item that decode_volley_trusted can produce and RecordEncoder.encode_volley
# can consume.
_AnyHandshakeMessage: TypeAlias = (
    HandshakeMessage | PseudoHandshakeMessage | OpaqueHandshakeMessage
)
350
+
351
+
352
def decode_volley_trusted(
    volley: bytes,
) -> list[_AnyHandshakeMessage]:
    """Reconstruct the handshake messages inside an OpenSSL-generated volley.

    *volley* is raw outgoing handshake data that OpenSSL produced (not network
    input), so it is expected to be well-formed. The messages are rebuilt so
    that they can be repacked into records when retransmitting. Order of first
    appearance is preserved, and fragments sharing a handshake sequence number
    are merged into a single HandshakeMessage.
    """
    messages: list[_AnyHandshakeMessage] = []
    in_progress: dict[int, HandshakeMessage] = {}
    for record in records_untrusted(volley):
        # Records with epoch > 0 are encrypted, so we can't fragment them.
        # Fortunately the only encrypted handshake message is Finished, whose
        # payload is a single hash value (32 bytes for SHA-256, 64 for SHA-512,
        # ...), so it never needs fragmenting to fit into a packet.
        if record.epoch_seqno & EPOCH_MASK:
            messages.append(OpaqueHandshakeMessage(record))
            continue
        # ChangeCipherSpec isn't a handshake message, so it can't be fragmented
        # either; its payload is a single byte.
        if record.content_type in (ContentType.change_cipher_spec, ContentType.alert):
            messages.append(
                PseudoHandshakeMessage(
                    record.version,
                    record.content_type,
                    record.payload,
                ),
            )
            continue

        assert record.content_type == ContentType.handshake
        fragment = decode_handshake_fragment_untrusted(record.payload)
        msg_type = HandshakeType(fragment.msg_type)
        msg = in_progress.get(fragment.msg_seq)
        if msg is None:
            # First fragment seen for this message: allocate the full body.
            msg = HandshakeMessage(
                record.version,
                msg_type,
                fragment.msg_seq,
                bytearray(fragment.msg_len),
            )
            messages.append(msg)
            in_progress[fragment.msg_seq] = msg
        else:
            assert msg.msg_type == fragment.msg_type
            assert msg.msg_seq == fragment.msg_seq
            assert len(msg.body) == fragment.msg_len

        start = fragment.frag_offset
        msg.body[start : start + fragment.frag_len] = fragment.frag

    return messages
402
+
403
+
404
class RecordEncoder:
    """Packs handshake messages into DTLS records, and records into packets.

    Keeps a running record sequence number (an ``itertools.count``), so each
    record it builds gets the next number. OpaqueHandshakeMessage records keep
    their own number and do not consume one.
    """

    def __init__(self) -> None:
        self._record_seq = count()

    def set_first_record_number(self, n: int) -> None:
        """Make the next record built by encode_volley use sequence number *n*."""
        self._record_seq = count(n)

    def encode_volley(
        self,
        messages: Iterable[_AnyHandshakeMessage],
        mtu: int,
    ) -> list[bytearray]:
        """Encode *messages* into a list of packets, each at most *mtu* bytes.

        - OpaqueHandshakeMessage: its record is copied verbatim.
        - PseudoHandshakeMessage: its payload goes into one new record.
        - HandshakeMessage: its body is split into as many fragments as needed
          (at least one, even for an empty body), one record per fragment.

        A new packet is started only when the current one is non-empty and the
        next record would not fit, so no empty packets are ever produced.

        Raises ValueError if *mtu* is too small to hold even an empty packet's
        record header plus handshake fragment header (previously this case
        looped forever).
        """
        packets: list[bytearray] = []
        packet = bytearray()
        for message in messages:
            if isinstance(message, OpaqueHandshakeMessage):
                encoded = encode_record(message.record)
                if packet and len(packet) + len(encoded) > mtu:  # TODO: test this line
                    packets.append(packet)
                    packet = bytearray()
                packet += encoded
                assert len(packet) <= mtu
            elif isinstance(message, PseudoHandshakeMessage):
                needed = RECORD_HEADER.size + len(message.payload)
                if packet and len(packet) + needed > mtu:  # TODO: test this line
                    packets.append(packet)
                    packet = bytearray()
                packet += RECORD_HEADER.pack(
                    message.content_type,
                    message.record_version,
                    next(self._record_seq),
                    len(message.payload),
                )
                packet += message.payload
                assert len(packet) <= mtu
            else:
                msg_len_bytes = len(message.body).to_bytes(3, "big")
                frag_offset = 0
                frags_encoded = 0
                # If message.body is empty, then we still want to encode it in one
                # fragment, not zero.
                while frag_offset < len(message.body) or not frags_encoded:
                    # Bytes of body that fit after a record header + fragment header.
                    space = (
                        mtu
                        - len(packet)
                        - RECORD_HEADER.size
                        - HANDSHAKE_MESSAGE_HEADER.size
                    )
                    if space <= 0:
                        if not packet:
                            # Flushing an empty packet can't make room; without
                            # this check the loop would never terminate.
                            raise ValueError(
                                f"mtu {mtu} is too small to hold a handshake fragment",
                            )
                        packets.append(packet)
                        packet = bytearray()
                        continue
                    frag = message.body[frag_offset : frag_offset + space]
                    frag_offset_bytes = frag_offset.to_bytes(3, "big")
                    frag_len_bytes = len(frag).to_bytes(3, "big")
                    frag_offset += len(frag)

                    packet += RECORD_HEADER.pack(
                        ContentType.handshake,
                        message.record_version,
                        next(self._record_seq),
                        HANDSHAKE_MESSAGE_HEADER.size + len(frag),
                    )

                    packet += HANDSHAKE_MESSAGE_HEADER.pack(
                        message.msg_type,
                        msg_len_bytes,
                        message.msg_seq,
                        frag_offset_bytes,
                        frag_len_bytes,
                    )

                    packet += frag

                    frags_encoded += 1
                    assert len(packet) <= mtu

        if packet:
            packets.append(packet)

        return packets
485
+
486
+
487
# This bit requires implementing a bona fide cryptographic protocol, so even though it's
# a simple one let's take a moment to discuss the design.
#
# Our goal is to force new incoming handshakes that claim to be coming from a
# given ip:port to prove that they can also receive packets sent to that
# ip:port. (There's nothing in UDP to stop someone from forging the return
# address, and it's often used for stuff like DoS reflection attacks, where
# an attacker tries to trick us into sending data at some innocent victim.)
# For more details, see:
#
#    https://datatracker.ietf.org/doc/html/rfc6347#section-4.2.1
#
# To do this, when we receive an initial ClientHello, we calculate a magic
# cookie, and send it back as a HelloVerifyRequest. Then the client sends us a
# second ClientHello, this time with the magic cookie included, and after we
# check that this cookie is valid we go ahead and start the handshake proper.
#
# So the magic cookie needs the following properties:
# - No-one can forge it without knowing our secret key
# - It ensures that the ip, port, and ClientHello contents from the response
#   match those in the challenge
# - It expires after a short-ish period (so that if an attacker manages to steal one, it
#   won't be useful for long)
# - It doesn't require storing any peer-specific state on our side
#
# To do that, we take the ip/port/ClientHello data and compute an HMAC of them, using a
# secret key we generate on startup. We also include:
#
# - The current "tick": Trio's clock divided by COOKIE_REFRESH_INTERVAL (30 seconds)
#   and truncated to an int, i.e. which 30-second window we are currently in
# - A random salt
#
# Then the cookie is the salt and the HMAC digest concatenated together, truncated to
# COOKIE_LENGTH bytes.
#
# When verifying a cookie, we use the salt + new ip/port/ClientHello data to recompute
# the HMAC digest, for both the current tick and the previous tick, and if either of
# them match, we consider the cookie good.
#
# Including the tick like this means that each cookie is good for at least 30 seconds,
# and possibly as much as 60 seconds (depending on where in its 30-second window the
# cookie was issued).
#
# The salt is probably not necessary -- I'm pretty sure that all it does is make it hard
# for an attacker to figure out when our clock ticks over a 30 second boundary. Which is
# probably pretty harmless? But it's easier to add the salt than to convince myself that
# it's *completely* harmless, so, salt it is.

# Length of one cookie validity window, in seconds of Trio's clock.
COOKIE_REFRESH_INTERVAL = 30  # seconds
# Size of the per-endpoint HMAC secret, in bytes.
KEY_BYTES = 32
# Digest name passed to hmac.digest when computing cookies.
COOKIE_HASH = "sha256"
# Size of the random salt that prefixes every cookie, in bytes.
SALT_BYTES = 8
# 32 bytes was the maximum cookie length in DTLS 1.0. DTLS 1.2 raised it to 255. I doubt
# there are any DTLS 1.0 implementations still in the wild, but really 32 bytes is
# plenty, and it also gets rid of a confusing warning in Wireshark output.
#
# We truncate the cookie to 32 bytes, of which 8 bytes is salt, so that leaves 24 bytes
# of truncated HMAC = 192 bit security, which is still massive overkill. (TCP uses 32
# *bits* for this.) HMAC truncation is explicitly noted as safe in RFC 2104:
#   https://datatracker.ietf.org/doc/html/rfc2104#section-5
COOKIE_LENGTH = 32
545
+
546
+
547
def _current_cookie_tick() -> int:
    """Return the index of the cookie-validity window Trio's clock is currently in.

    The clock value is divided by COOKIE_REFRESH_INTERVAL and truncated to an int,
    so the result advances by one every COOKIE_REFRESH_INTERVAL seconds.
    """
    now = trio.current_time()
    window_index = int(now / COOKIE_REFRESH_INTERVAL)
    return window_index
549
+
550
+
551
+ # Simple deterministic and invertible serializer -- i.e., a useful tool for converting
552
+ # structured data into something we can cryptographically sign.
553
+ def _signable(*fields: bytes) -> bytes:
554
+ out: list[bytes] = []
555
+ for field in fields:
556
+ out.extend((struct.pack("!Q", len(field)), field))
557
+ return b"".join(out)
558
+
559
+
560
def _make_cookie(
    key: bytes,
    salt: bytes,
    tick: int,
    address: AddressFormat,
    client_hello_bits: bytes,
) -> bytes:
    """Compute the stateless cookie that binds a peer to its ClientHello.

    The result is ``salt`` followed by an HMAC (keyed with ``key``, using
    COOKIE_HASH) over the salt, ``tick``, every component of ``address`` and
    ``client_hello_bits``; the whole thing is truncated to COOKIE_LENGTH bytes.

    Args:
        key: Endpoint secret; must be exactly KEY_BYTES long.
        salt: Random salt; must be exactly SALT_BYTES long.
        tick: Cookie window index (see ``_current_cookie_tick``).
        address: Peer address; each component is stringified before signing.
        client_hello_bits: ClientHello data covered by the cookie.

    Returns:
        The cookie bytes.
    """
    assert len(salt) == SALT_BYTES
    assert len(key) == KEY_BYTES

    tick_field = struct.pack("!Q", tick)
    # The address mixes strings and ints and has variable length, so it is
    # serialized as a single nested field.
    address_field = _signable(*(str(component).encode() for component in address))
    mac_input = _signable(salt, tick_field, address_field, client_hello_bits)

    digest = hmac.digest(key, mac_input, COOKIE_HASH)
    return (salt + digest)[:COOKIE_LENGTH]
580
+
581
+
582
def valid_cookie(
    key: bytes,
    cookie: bytes,
    address: AddressFormat,
    client_hello_bits: bytes,
) -> bool:
    """Return True if ``cookie`` is one we issued to ``address`` for this ClientHello.

    The salt is read from the front of the cookie, and the expected cookie is
    recomputed for the current tick and for the previous tick (clamped at 0).
    A match against either is accepted. Cookies that are no longer than the
    salt are rejected immediately.
    """
    if len(cookie) <= SALT_BYTES:
        return False

    salt = cookie[:SALT_BYTES]
    now_tick = _current_cookie_tick()

    expected_now = _make_cookie(key, salt, now_tick, address, client_hello_bits)
    expected_prev = _make_cookie(
        key,
        salt,
        max(now_tick - 1, 0),
        address,
        client_hello_bits,
    )

    # Both comparisons always run and are combined with a non-short-circuiting
    # '|', so timing does not reveal which window (if any) matched.
    matches_now = hmac.compare_digest(cookie, expected_now)
    matches_prev = hmac.compare_digest(cookie, expected_prev)
    return matches_now | matches_prev
610
+
611
+
612
def challenge_for(
    key: bytes,
    address: AddressFormat,
    epoch_seqno: int,
    client_hello_bits: bytes,
) -> bytes:
    """Build an encoded HelloVerifyRequest record that challenges ``address``.

    A fresh random salt and the current tick are used to derive the cookie.
    The record reuses the client's ``epoch_seqno`` and has handshake message
    sequence number 0.
    """
    fresh_salt = os.urandom(SALT_BYTES)
    cookie = _make_cookie(
        key, fresh_salt, _current_cookie_tick(), address, client_hello_bits
    )

    # HelloVerifyRequest body layout:
    #   - 2-byte protocol version
    #   - 1-byte cookie length, then the cookie
    #
    # DTLS 1.2 says this particular message should carry the DTLS 1.0 version.
    # (The spec also says the opposite somewhere, but that is a known mistake:
    #   https://www.rfc-editor.org/errata/eid4103
    # ). Nothing has been negotiated yet, so DTLS 1.0 is used for both the
    # message-level and the record-level version.
    body = ProtocolVersion.DTLS10 + bytes([len(cookie)]) + cookie

    # The RFC says to copy the client's record number, while the errata says to
    # copy the handshake message number. OpenSSL echoes the record sequence
    # number and always uses message seq 0, so we do the same.
    verify_request = HandshakeFragment(
        msg_type=HandshakeType.hello_verify_request,
        msg_len=len(body),
        msg_seq=0,
        frag_offset=0,
        frag_len=len(body),
        frag=body,
    )
    encoded_fragment = encode_handshake_fragment(verify_request)

    record = Record(
        ContentType.handshake,
        ProtocolVersion.DTLS10,
        epoch_seqno,
        encoded_fragment,
    )
    return encode_record(record)
655
+
656
+
657
_T = TypeVar("_T")


class _Queue(Generic[_T]):
    """Bundle of both ends of a Trio memory channel.

    ``s`` is the send end and ``r`` the receive end; the channel buffers up to
    ``incoming_packets_buffer`` items before ``send_nowait`` would block.
    """

    def __init__(self, incoming_packets_buffer: int | float) -> None:  # noqa: PYI041
        send_end, receive_end = trio.open_memory_channel[_T](incoming_packets_buffer)
        self.s = send_end
        self.r = receive_end
663
+
664
+
665
def _read_loop(read_fn: Callable[[int], bytes]) -> bytes:
    """Call ``read_fn`` repeatedly until it raises ``SSL.WantReadError``.

    Each call requests up to 2**14 bytes (the maximum TLS record size). Every
    chunk read before the WantReadError is concatenated and returned. Any other
    exception propagates to the caller.
    """
    pieces: list[bytes] = []
    try:
        while True:
            pieces.append(read_fn(2**14))  # max TLS record size
    except SSL.WantReadError:
        pass
    return b"".join(pieces)
674
+
675
+
676
async def handle_client_hello_untrusted(
    endpoint: DTLSEndpoint,
    address: AddressFormat,
    packet: bytes,
) -> None:
    """Handle a datagram from ``address`` that looks like a ClientHello.

    If the packet has no valid cookie, reply with a HelloVerifyRequest challenge.
    If the cookie is valid, create a new DTLSChannel, feed it the ClientHello,
    and register it as the endpoint's stream for ``address``. Any older stream
    for that address is replaced, unless this packet just duplicates the
    ClientHello that created it. The new channel is then queued on
    ``endpoint._incoming_connections_q`` for ``DTLSEndpoint.serve`` to pick up.

    Undecodable packets (BadPacket), ClientHellos that OpenSSL rejects, and
    errors while sending the challenge are all silently ignored.
    """
    # it's trivial to write a simple function that directly calls this to
    # get code coverage, but it should maybe:
    # 1. be removed
    # 2. be asserted
    # 3. Write a complicated test case where this happens "organically"
    if endpoint._listening_context is None:  # pragma: no cover
        return

    try:
        epoch_seqno, cookie, bits = decode_client_hello_untrusted(packet)
    except BadPacket:
        return

    # The cookie secret is created lazily, the first time we see a ClientHello.
    if endpoint._listening_key is None:
        endpoint._listening_key = os.urandom(KEY_BYTES)

    if not valid_cookie(endpoint._listening_key, cookie, address, bits):
        challenge_packet = challenge_for(
            endpoint._listening_key,
            address,
            epoch_seqno,
            bits,
        )
        try:
            async with endpoint._send_lock:
                await endpoint.socket.sendto(challenge_packet, address)
        except (OSError, trio.ClosedResourceError):
            pass
    else:
        # We got a real, valid ClientHello!
        stream = DTLSChannel._create(endpoint, address, endpoint._listening_context)
        # Our HelloVerifyRequest had some sequence number. We need our future sequence
        # numbers to be larger than it, so our peer knows that our future records aren't
        # stale/duplicates. But, we don't know what this sequence number was. What we do
        # know is:
        # - the HelloVerifyRequest seqno was copied from the initial ClientHello
        # - the new ClientHello has a higher seqno than the initial ClientHello
        # So, if we copy the new ClientHello's seqno into our first real handshake
        # record and increment from there, that should work.
        stream._record_encoder.set_first_record_number(epoch_seqno)
        # Process the ClientHello
        try:
            stream._ssl.bio_write(packet)
            stream._ssl.DTLSv1_listen()
        except SSL.Error:  # pragma: no cover
            # ...OpenSSL didn't like it, so I guess we didn't have a valid ClientHello
            # after all.
            return

        # Check if we have an existing association
        old_stream = endpoint._streams.get(address)
        if old_stream is not None:
            if old_stream._client_hello == (cookie, bits):
                # ...This was just a duplicate of the last ClientHello, so never mind.
                return
            else:
                # Ok, this *really is* a new handshake; the old stream should go away.
                old_stream._set_replaced()
        # Remember which ClientHello created this stream so that retransmitted
        # copies of it can be recognized as duplicates (see above).
        stream._client_hello = (cookie, bits)
        endpoint._streams[address] = stream
        endpoint._incoming_connections_q.s.send_nowait(stream)
742
+
743
+
744
async def dtls_receive_loop(
    endpoint_ref: ReferenceType[DTLSEndpoint],
    sock: SocketType,
) -> None:
    """Background task that reads datagrams from ``sock`` and dispatches them.

    For each received packet:
    - ClientHello-looking packets go to ``handle_client_hello_untrusted``.
    - Packets from a peer with a registered DTLSChannel either trigger a resend
      of that channel's final handshake volley (for late handshake packets after
      the handshake finished) or go into the channel's queue. When the queue is
      full, the packet is counted as dropped.
    - Everything else is dropped.

    Only a weak reference to the endpoint is held between packets. The loop
    returns if, after a packet arrives, the endpoint has been garbage collected.
    It also returns if the socket turns out to be closed.
    """
    try:
        while True:
            try:
                packet, address = await sock.recvfrom(MAX_UDP_PACKET_SIZE)
            except OSError as exc:
                if exc.errno == errno.ECONNRESET:
                    # Windows only: "On a UDP-datagram socket [ECONNRESET]
                    # indicates a previous send operation resulted in an ICMP Port
                    # Unreachable message" -- https://docs.microsoft.com/en-us/windows/win32/api/winsock/nf-winsock-recvfrom
                    #
                    # This is totally useless -- there's nothing we can do with this
                    # information. So we just ignore it and retry the recv.
                    continue
                else:
                    raise
            # Take a strong reference only while handling this one packet; the
            # 'finally: del endpoint' below drops it again before the next
            # recvfrom, so this task never keeps the endpoint alive.
            endpoint = endpoint_ref()
            try:
                if endpoint is None:
                    return
                if is_client_hello_untrusted(packet):
                    await handle_client_hello_untrusted(endpoint, address, packet)
                elif address in endpoint._streams:
                    stream = endpoint._streams[address]
                    if stream._did_handshake and part_of_handshake_untrusted(packet):
                        # The peer just sent us more handshake messages, that aren't a
                        # ClientHello, and we thought the handshake was done. Some of
                        # the packets that we sent to finish the handshake must have
                        # gotten lost. So re-send them. We do this directly here instead
                        # of just putting it into the queue and letting the receiver do
                        # it, because there's no guarantee that anyone is reading from
                        # the queue, because we think the handshake is done!
                        await stream._resend_final_volley()
                    else:
                        try:
                            stream._q.s.send_nowait(packet)
                        except trio.WouldBlock:
                            # Queue full: drop the packet, but count it so it shows
                            # up in DTLSChannel.statistics().
                            stream._packets_dropped_in_trio += 1
                else:
                    # Drop packet
                    pass
            finally:
                del endpoint
    except trio.ClosedResourceError:
        # socket was closed
        return
    except OSError as exc:
        if exc.errno in (errno.EBADF, errno.ENOTSOCK):
            # socket was closed
            return
        else:  # pragma: no cover
            # ??? shouldn't happen
            raise
800
+
801
+
802
@attrs.frozen
class DTLSChannelStatistics:
    """Statistics about a `DTLSChannel`, as returned by `DTLSChannel.statistics`.

    Currently this has only one attribute:

    - ``incoming_packets_dropped_in_trio`` (``int``): Gives a count of the number of
      incoming packets from this peer that Trio successfully received from the
      network, but then got dropped because the internal channel buffer was full. If
      this is non-zero, then you might want to call ``receive`` more often, or use a
      larger ``incoming_packets_buffer``, or just not worry about it because your
      UDP-based protocol should be able to handle the occasional lost packet, right?

    """

    # Number of packets discarded because the channel's incoming queue was full.
    incoming_packets_dropped_in_trio: int
816
+
817
+
818
@final
class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor):
    """A DTLS connection.

    This class has no public constructor – you get instances by calling
    `DTLSEndpoint.serve` or `~DTLSEndpoint.connect`.

    .. attribute:: endpoint

       The `DTLSEndpoint` that this connection is using.

    .. attribute:: peer_address

       The IP/port of the remote peer that this connection is associated with.

    """

    def __init__(
        self,
        endpoint: DTLSEndpoint,
        peer_address: AddressFormat,
        ctx: SSL.Context,
    ) -> None:
        self.endpoint = endpoint
        self.peer_address = peer_address
        # Count of incoming packets dropped because self._q was full; incremented
        # by dtls_receive_loop and reported by statistics().
        self._packets_dropped_in_trio = 0
        # (cookie, bits) of the ClientHello that created this channel; assigned by
        # handle_client_hello_untrusted and used to spot duplicate ClientHellos.
        self._client_hello = None
        self._did_handshake = False
        self._ssl = SSL.Connection(ctx)
        # MTU used to fragment our own handshake volleys (see _send_volley).
        self._handshake_mtu = 0
        # This calls self._ssl.set_ciphertext_mtu, which is important, because if you
        # don't call it then openssl doesn't work.
        self.set_ciphertext_mtu(best_guess_mtu(self.endpoint.socket))
        self._replaced = False
        self._closed = False
        # Raw datagrams from this peer, pushed by the endpoint's receive loop.
        self._q = _Queue[bytes](endpoint.incoming_packets_buffer)
        self._handshake_lock = trio.Lock()
        self._record_encoder: RecordEncoder = RecordEncoder()

        # Our last handshake volley, kept so it can be re-sent if the peer keeps
        # sending handshake packets after we consider the handshake done.
        self._final_volley: list[_AnyHandshakeMessage] = []

    def _set_replaced(self) -> None:
        """Mark this channel as superseded by a newer one for the same peer."""
        self._replaced = True
        # Any packets we already received could maybe possibly still be processed, but
        # there are no more coming. So we close this on the sender side.
        self._q.s.close()

    def _check_replaced(self) -> None:
        """Raise `trio.BrokenResourceError` if this channel has been replaced."""
        if self._replaced:
            raise trio.BrokenResourceError(
                "peer tore down this connection to start a new one",
            )

    # XX on systems where we can (maybe just Linux?) take advantage of the kernel's PMTU
    # estimate

    # XX should we send close-notify when closing? It seems particularly pointless for
    # DTLS where packets are all independent and can be lost anyway. We do at least need
    # to handle receiving it properly though, which might be easier if we send it...

    def close(self) -> None:
        """Close this connection.

        `DTLSChannel`\\s don't actually own any OS-level resources – the
        socket is owned by the `DTLSEndpoint`, not the individual connections. So
        you don't really *have* to call this. But it will interrupt any other tasks
        calling `receive` with a `ClosedResourceError`, and cause future attempts to use
        this connection to fail.

        You can also use this object as a synchronous or asynchronous context manager.

        """
        if self._closed:
            return
        self._closed = True
        # Only unregister ourselves if a newer channel hasn't already taken over
        # this peer address.
        if self.endpoint._streams.get(self.peer_address) is self:
            del self.endpoint._streams[self.peer_address]
        # Will wake any tasks waiting on self._q.get with a
        # ClosedResourceError
        self._q.r.close()

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        return self.close()

    async def aclose(self) -> None:
        """Close this connection, but asynchronously.

        This is included to satisfy the `trio.abc.Channel` contract. It's
        identical to `close`, but async.

        """
        self.close()
        await trio.lowlevel.checkpoint()

    async def _send_volley(self, volley_messages: list[_AnyHandshakeMessage]) -> None:
        """Encode ``volley_messages`` into packets that fit _handshake_mtu and send them."""
        packets = self._record_encoder.encode_volley(
            volley_messages,
            self._handshake_mtu,
        )
        for packet in packets:
            async with self.endpoint._send_lock:
                await self.endpoint.socket.sendto(packet, self.peer_address)

    async def _resend_final_volley(self) -> None:
        """Re-send the last handshake volley (used when the peer retransmits)."""
        await self._send_volley(self._final_volley)

    async def do_handshake(self, *, initial_retransmit_timeout: float = 1.0) -> None:
        """Perform the handshake.

        Calling this is optional – if you don't, then it will be automatically called
        the first time you call `send` or `receive`. But calling it explicitly can be
        useful in case you want to control the retransmit timeout, use a cancel scope to
        place an overall timeout on the handshake, or catch errors from the handshake
        specifically.

        It's safe to call this multiple times, or call it simultaneously from multiple
        tasks – the first call will perform the handshake, and the rest will be no-ops.

        Args:

          initial_retransmit_timeout (float): Since UDP is an unreliable protocol, it's
              possible that some of the packets we send during the handshake will get
              lost. To handle this, DTLS uses a timer to automatically retransmit
              handshake packets that don't receive a response. This lets you set the
              timeout we use to detect packet loss. Ideally, it should be set to ~1.5
              times the round-trip time to your peer, but 1 second is a reasonable
              default. There's `some useful guidance here
              <https://tlswg.org/dtls13-spec/draft-ietf-tls-dtls13.html#name-timer-values>`__.

              This is the *initial* timeout, because if packets keep being lost then Trio
              will automatically back off to longer values, to avoid overloading the
              network.

        Raises:
          trio.BrokenResourceError: if the channel is replaced by a newer
              connection from the same peer while handshaking.

        """
        async with self._handshake_lock:
            if self._did_handshake:
                return

            timeout = initial_retransmit_timeout
            volley_messages: list[_AnyHandshakeMessage] = []
            volley_failed_sends = 0

            def read_volley() -> list[_AnyHandshakeMessage]:
                # Drain whatever OpenSSL wants to send and decode it into messages.
                volley_bytes = _read_loop(self._ssl.bio_read)
                new_volley_messages = decode_volley_trusted(volley_bytes)
                if (
                    new_volley_messages
                    and volley_messages
                    and isinstance(new_volley_messages[0], HandshakeMessage)
                    and isinstance(volley_messages[0], HandshakeMessage)
                    and new_volley_messages[0].msg_seq == volley_messages[0].msg_seq
                ):
                    # openssl decided to retransmit; discard because we handle
                    # retransmits ourselves
                    return []
                else:
                    return new_volley_messages

            # If we're a client, we send the initial volley. If we're a server, then
            # the initial ClientHello has already been inserted into self._ssl's
            # read BIO. So either way, we start by generating a new volley.
            with contextlib.suppress(SSL.WantReadError):
                self._ssl.do_handshake()
            volley_messages = read_volley()
            # If we don't have messages to send in our initial volley, then something
            # has gone very wrong. (I'm not sure this can actually happen without an
            # error from OpenSSL, but we check just in case.)
            if not volley_messages:  # pragma: no cover
                raise SSL.Error("something wrong with peer's ClientHello")

            while True:
                # -- at this point, we need to either send or re-send a volley --
                assert volley_messages
                self._check_replaced()
                await self._send_volley(volley_messages)
                # -- then this is where we wait for a reply --
                self.endpoint._ensure_receive_loop()
                with trio.move_on_after(timeout) as cscope:
                    async for packet in self._q.r:
                        self._ssl.bio_write(packet)
                        try:
                            self._ssl.do_handshake()
                        # We ignore generic SSL.Error here, because you can get those
                        # from random invalid packets
                        except (SSL.WantReadError, SSL.Error):
                            pass
                        else:
                            # No exception -> the handshake is done, and we can
                            # switch into data transfer mode.
                            self._did_handshake = True
                            # Might be empty, but that's ok -- we'll just send no
                            # packets.
                            self._final_volley = read_volley()
                            await self._send_volley(self._final_volley)
                            return
                        maybe_volley = read_volley()
                        if maybe_volley:
                            if (
                                isinstance(maybe_volley[0], PseudoHandshakeMessage)
                                and maybe_volley[0].content_type == ContentType.alert
                            ):  # TODO: test this line
                                # we're sending an alert (e.g. due to a corrupted
                                # packet). We want to send it once, but don't save it to
                                # retransmit -- keep the last volley as the current
                                # volley.
                                await self._send_volley(maybe_volley)
                            else:
                                # We managed to get all of the peer's volley and
                                # generate a new one ourselves! break out of the 'for'
                                # loop and restart the timer.
                                volley_messages = maybe_volley
                                # "Implementations SHOULD retain the current timer value
                                # until a transmission without loss occurs, at which
                                # time the value may be reset to the initial value."
                                if volley_failed_sends == 0:
                                    timeout = initial_retransmit_timeout
                                volley_failed_sends = 0
                                break
                    else:
                        # The queue only ends when _set_replaced() closes its send
                        # side, so this raises BrokenResourceError.
                        assert self._replaced
                        self._check_replaced()
                if cscope.cancelled_caught:
                    # Timeout expired. Double timeout for backoff, with a limit of 60
                    # seconds (this matches what openssl does, and also the
                    # recommendation in draft-ietf-tls-dtls13).
                    timeout = min(2 * timeout, 60.0)
                    volley_failed_sends += 1
                    if volley_failed_sends == 2:
                        # We tried sending this twice and they both failed. Maybe our
                        # PMTU estimate is wrong? Let's try dropping it to the minimum
                        # and hope that helps.
                        self._handshake_mtu = min(
                            self._handshake_mtu,
                            worst_case_mtu(self.endpoint.socket),
                        )

    async def send(self, data: bytes) -> None:
        """Send a packet of data, securely.

        Performs the handshake first if it hasn't happened yet.

        Raises:
          trio.ClosedResourceError: if this channel was closed.
          ValueError: if ``data`` is empty.
          trio.BrokenResourceError: if the peer replaced this connection.

        """

        if self._closed:
            raise trio.ClosedResourceError
        if not data:
            raise ValueError("openssl doesn't support sending empty DTLS packets")
        if not self._did_handshake:
            await self.do_handshake()
        self._check_replaced()
        self._ssl.write(data)
        # Everything OpenSSL produced for this write goes out in a single datagram.
        async with self.endpoint._send_lock:
            await self.endpoint.socket.sendto(
                _read_loop(self._ssl.bio_read),
                self.peer_address,
            )

    async def receive(self) -> bytes:
        """Fetch the next packet of data from this connection's peer, waiting if
        necessary.

        This is safe to call from multiple tasks simultaneously, in case you have some
        reason to do that. And more importantly, it's cancellation-safe, meaning that
        cancelling a call to `receive` will never cause a packet to be lost or corrupt
        the underlying connection.

        Raises:
          trio.BrokenResourceError: if the peer replaced this connection.

        """
        if not self._did_handshake:
            await self.do_handshake()
        # If the packet isn't really valid, then openssl can decode it to the empty
        # string (e.g. b/c it's a late-arriving handshake packet, or a duplicate copy of
        # a data packet). Skip over these instead of returning them.
        while True:
            try:
                packet = await self._q.r.receive()
            except trio.EndOfChannel:
                assert self._replaced
                self._check_replaced()
            self._ssl.bio_write(packet)
            cleartext = _read_loop(self._ssl.read)
            if cleartext:
                return cleartext

    def set_ciphertext_mtu(self, new_mtu: int) -> None:
        """Tells Trio the `largest amount of data that can be sent in a single packet to
        this peer <https://en.wikipedia.org/wiki/Maximum_transmission_unit>`__.

        Trio doesn't actually enforce this limit – if you pass a huge packet to `send`,
        then we'll dutifully encrypt it and attempt to send it. But calling this method
        does have two useful effects:

        - If called before the handshake is performed, then Trio will automatically
          fragment handshake messages to fit within the given MTU. It also might
          fragment them even smaller, if it detects signs of packet loss, so setting
          this should never be necessary to make a successful connection. But, the
          packet loss detection only happens after multiple timeouts have expired, so if
          you have reason to believe that a smaller MTU is required, then you can set
          this to skip those timeouts and establish the connection more quickly.

        - It changes the value returned from `get_cleartext_mtu`. So if you have some
          kind of estimate of the network-level MTU, then you can use this to figure out
          how much overhead DTLS will need for hashes/padding/etc., and how much space
          you have left for your application data.

        The MTU here is measuring the largest UDP *payload* you think can be sent, the
        amount of encrypted data that can be handed to the operating system in a single
        call to `send`. It should *not* include IP/UDP headers. Note that OS estimates
        of the MTU often are link-layer MTUs, so you have to subtract off 28 bytes on
        IPv4 and 48 bytes on IPv6 to get the ciphertext MTU.

        By default, Trio assumes an MTU of 1472 bytes on IPv4, and 1452 bytes on IPv6,
        which correspond to the common Ethernet MTU of 1500 bytes after accounting for
        IP/UDP overhead.

        """
        self._handshake_mtu = new_mtu
        self._ssl.set_ciphertext_mtu(new_mtu)

    def get_cleartext_mtu(self) -> int:
        """Returns the largest number of bytes that you can pass in a single call to
        `send` while still fitting within the network-level MTU.

        See `set_ciphertext_mtu` for more details.

        Raises:
          trio.NeedHandshakeError: if the handshake has not completed yet.

        """
        if not self._did_handshake:
            raise trio.NeedHandshakeError
        return self._ssl.get_cleartext_mtu()  # type: ignore[no-any-return]

    def statistics(self) -> DTLSChannelStatistics:
        """Returns a `DTLSChannelStatistics` object with statistics about this connection."""
        return DTLSChannelStatistics(self._packets_dropped_in_trio)
1154
+
1155
+
1156
+ @final
1157
+ class DTLSEndpoint:
1158
+ """A DTLS endpoint.
1159
+
1160
+ A single UDP socket can handle arbitrarily many DTLS connections simultaneously,
1161
+ acting as a client or server as needed. A `DTLSEndpoint` object holds a UDP socket
1162
+ and manages these connections, which are represented as `DTLSChannel` objects.
1163
+
1164
+ Args:
1165
+ socket: (trio.socket.SocketType): A ``SOCK_DGRAM`` socket. If you want to accept
1166
+ incoming connections in server mode, then you should probably bind the socket to
1167
+ some known port.
1168
+ incoming_packets_buffer (int): Each `DTLSChannel` using this socket has its own
1169
+ buffer that holds incoming packets until you call `~DTLSChannel.receive` to read
1170
+ them. This lets you adjust the size of this buffer. `~DTLSChannel.statistics`
1171
+ lets you check if the buffer has overflowed.
1172
+
1173
+ .. attribute:: socket
1174
+ incoming_packets_buffer
1175
+
1176
+ Both constructor arguments are also exposed as attributes, in case you need to
1177
+ access them later.
1178
+
1179
+ """
1180
+
1181
+ def __init__(
1182
+ self,
1183
+ socket: SocketType,
1184
+ *,
1185
+ incoming_packets_buffer: int = 10,
1186
+ ) -> None:
1187
+ # We do this lazily on first construction, so only people who actually use DTLS
1188
+ # have to install PyOpenSSL.
1189
+ global SSL
1190
+ from OpenSSL import SSL
1191
+
1192
+ # for __del__, in case the next line raises
1193
+ self._initialized: bool = False
1194
+ if socket.type != trio.socket.SOCK_DGRAM:
1195
+ raise ValueError("DTLS requires a SOCK_DGRAM socket")
1196
+ self._initialized = True
1197
+ self.socket: SocketType = socket
1198
+
1199
+ self.incoming_packets_buffer = incoming_packets_buffer
1200
+ self._token = trio.lowlevel.current_trio_token()
1201
+ # We don't need to track handshaking vs non-handshake connections
1202
+ # separately. We only keep one connection per remote address; as soon
1203
+ # as a peer provides a valid cookie, we can immediately tear down the
1204
+ # old connection.
1205
+ # {remote address: DTLSChannel}
1206
+ self._streams: WeakValueDictionary[AddressFormat, DTLSChannel] = (
1207
+ WeakValueDictionary()
1208
+ )
1209
+ self._listening_context: SSL.Context | None = None
1210
+ self._listening_key: bytes | None = None
1211
+ self._incoming_connections_q = _Queue[DTLSChannel](float("inf"))
1212
+ self._send_lock = trio.Lock()
1213
+ self._closed = False
1214
+ self._receive_loop_spawned = False
1215
+
1216
+ def _ensure_receive_loop(self) -> None:
1217
+ # We have to spawn this lazily, because on Windows it will immediately error out
1218
+ # if the socket isn't already bound -- which for clients might not happen until
1219
+ # after we send our first packet.
1220
+ if not self._receive_loop_spawned:
1221
+ trio.lowlevel.spawn_system_task(
1222
+ dtls_receive_loop,
1223
+ weakref.ref(self),
1224
+ self.socket,
1225
+ )
1226
+ self._receive_loop_spawned = True
1227
+
1228
+ def __del__(self) -> None:
1229
+ # Do nothing if this object was never fully constructed
1230
+ if not self._initialized:
1231
+ return
1232
+ # Close the socket in Trio context (if our Trio context still exists), so that
1233
+ # the background task gets notified about the closure and can exit.
1234
+ if not self._closed:
1235
+ with contextlib.suppress(RuntimeError):
1236
+ self._token.run_sync_soon(self.close)
1237
+ # Do this last, because it might raise an exception
1238
+ warnings.warn(
1239
+ f"unclosed DTLS endpoint {self!r}",
1240
+ ResourceWarning,
1241
+ source=self,
1242
+ stacklevel=1,
1243
+ )
1244
+
1245
+ def close(self) -> None:
1246
+ """Close this socket, and all associated DTLS connections.
1247
+
1248
+ This object can also be used as a context manager.
1249
+
1250
+ """
1251
+ self._closed = True
1252
+ self.socket.close()
1253
+ for stream in list(self._streams.values()):
1254
+ stream.close()
1255
+ self._incoming_connections_q.s.close()
1256
+
1257
+ def __enter__(self) -> Self:
1258
+ return self
1259
+
1260
+ def __exit__(
1261
+ self,
1262
+ exc_type: type[BaseException] | None,
1263
+ exc_value: BaseException | None,
1264
+ traceback: TracebackType | None,
1265
+ ) -> None:
1266
+ return self.close()
1267
+
1268
+ def _check_closed(self) -> None:
1269
+ if self._closed:
1270
+ raise trio.ClosedResourceError
1271
+
1272
+ async def serve(
1273
+ self,
1274
+ ssl_context: SSL.Context,
1275
+ async_fn: Callable[[DTLSChannel, Unpack[PosArgsT]], Awaitable[object]],
1276
+ *args: Unpack[PosArgsT],
1277
+ task_status: trio.TaskStatus[None] = trio.TASK_STATUS_IGNORED,
1278
+ ) -> None:
1279
+ """Listen for incoming connections, and spawn a handler for each using an
1280
+ internal nursery.
1281
+
1282
+ Similar to `~trio.serve_tcp`, this function never returns until cancelled, or
1283
+ the `DTLSEndpoint` is closed and all handlers have exited.
1284
+
1285
+ Usage commonly looks like::
1286
+
1287
+ async def handler(dtls_channel):
1288
+ ...
1289
+
1290
+ async with trio.open_nursery() as nursery:
1291
+ await nursery.start(dtls_endpoint.serve, ssl_context, handler)
1292
+ # ... do other things here ...
1293
+
1294
+ The ``dtls_channel`` passed into the handler function has already performed the
1295
+ "cookie exchange" part of the DTLS handshake, so the peer address is
1296
+ trustworthy. But the actual cryptographic handshake doesn't happen until you
1297
+ start using it, giving you a chance for any last minute configuration, and the
1298
+ option to catch and handle handshake errors.
1299
+
1300
+ Args:
1301
+ ssl_context (OpenSSL.SSL.Context): The PyOpenSSL context object to use for
1302
+ incoming connections.
1303
+ async_fn: The handler function that will be invoked for each incoming
1304
+ connection.
1305
+ *args: Additional arguments to pass to the handler function.
1306
+
1307
+ """
1308
+ self._check_closed()
1309
+ if self._listening_context is not None:
1310
+ raise trio.BusyResourceError("another task is already listening")
1311
+ try:
1312
+ self.socket.getsockname()
1313
+ except OSError: # TODO: test this line
1314
+ raise RuntimeError(
1315
+ "DTLS socket must be bound before it can serve",
1316
+ ) from None
1317
+ self._ensure_receive_loop()
1318
+ # We do cookie verification ourselves, so tell OpenSSL not to worry about it.
1319
+ # (See also _inject_client_hello_untrusted.)
1320
+ ssl_context.set_cookie_verify_callback(lambda *_: True)
1321
+ set_ssl_context_options(ssl_context)
1322
+ try:
1323
+ self._listening_context = ssl_context
1324
+ task_status.started()
1325
+
1326
+ async def handler_wrapper(stream: DTLSChannel) -> None:
1327
+ with stream:
1328
+ await async_fn(stream, *args)
1329
+
1330
+ async with trio.open_nursery() as nursery:
1331
+ async for stream in self._incoming_connections_q.r: # pragma: no branch
1332
+ nursery.start_soon(handler_wrapper, stream)
1333
+ finally:
1334
+ self._listening_context = None
1335
+
1336
def connect(
    self,
    address: tuple[str, int],
    ssl_context: SSL.Context,
) -> DTLSChannel:
    """Initiate an outgoing DTLS connection.

    This is deliberately a synchronous method: it performs no I/O. It only
    builds a `DTLSChannel`, puts it into client ("connect") mode, and
    registers it on this endpoint. The handshake itself is deferred until
    the channel is first used. That leaves an opportunity for further
    configuration beforehand, such as setting the MTU.

    If a channel was already registered for ``address``, that older channel
    is marked as replaced and the new one takes its place.

    Args:
        address: The address to connect to. Usually a (host, port) tuple, like
            ``("127.0.0.1", 12345)``.
        ssl_context (OpenSSL.SSL.Context): The PyOpenSSL context object to use for
            this connection. The mandatory DTLS options are applied to it
            in place.

    Returns:
        DTLSChannel: the newly created, not-yet-handshaken channel.

    """
    # Ideally a loopback connection (``address`` being this very endpoint)
    # would be detected here, since it cannot work -- but there is no
    # reliable way to tell.
    self._check_closed()

    # Apply the mandatory DTLS options before the channel is built from the
    # context.
    set_ssl_context_options(ssl_context)
    new_channel = DTLSChannel._create(self, address, ssl_context)
    new_channel._ssl.set_connect_state()

    # Only one channel per peer address: supersede any existing one.
    previous_channel = self._streams.get(address)
    if previous_channel is not None:
        previous_channel._set_replaced()
    self._streams[address] = new_channel

    return new_channel
1370
+
1371
+
1372
def set_ssl_context_options(ctx: SSL.Context) -> None:
    """Enable the options every DTLS context must carry, mutating ``ctx``.

    * ``OP_NO_QUERY_MTU`` keeps OpenSSL from querying the memory BIO's MTU,
      which would otherwise break.
    * ``OP_NO_RENEGOTIATION`` turns renegotiation off. It is too complex to
      support here and of little value, especially for DTLS, where it amounts
      to simply performing a new handshake.
    """
    required_options = (
        SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION  # type: ignore[attr-defined]
    )
    ctx.set_options(required_options)