omserv 0.0.0.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- omserv/__about__.py +28 -0
- omserv/__init__.py +0 -0
- omserv/apps/__init__.py +0 -0
- omserv/apps/base.py +23 -0
- omserv/apps/inject.py +89 -0
- omserv/apps/markers.py +41 -0
- omserv/apps/routes.py +139 -0
- omserv/apps/sessions.py +57 -0
- omserv/apps/templates.py +90 -0
- omserv/dbs.py +24 -0
- omserv/node/__init__.py +0 -0
- omserv/node/models.py +53 -0
- omserv/node/registry.py +124 -0
- omserv/node/sql.py +131 -0
- omserv/secrets.py +12 -0
- omserv/server/__init__.py +18 -0
- omserv/server/config.py +51 -0
- omserv/server/debug.py +14 -0
- omserv/server/events.py +83 -0
- omserv/server/headers.py +36 -0
- omserv/server/lifespans.py +132 -0
- omserv/server/multiprocess.py +157 -0
- omserv/server/protocols/__init__.py +1 -0
- omserv/server/protocols/h11.py +334 -0
- omserv/server/protocols/h2.py +407 -0
- omserv/server/protocols/protocols.py +91 -0
- omserv/server/protocols/types.py +18 -0
- omserv/server/resources/__init__.py +8 -0
- omserv/server/sockets.py +111 -0
- omserv/server/ssl.py +47 -0
- omserv/server/streams/__init__.py +0 -0
- omserv/server/streams/httpstream.py +237 -0
- omserv/server/streams/utils.py +53 -0
- omserv/server/streams/wsstream.py +447 -0
- omserv/server/taskspawner.py +111 -0
- omserv/server/tcpserver.py +173 -0
- omserv/server/types.py +94 -0
- omserv/server/workercontext.py +52 -0
- omserv/server/workers.py +193 -0
- omserv-0.0.0.dev7.dist-info/LICENSE +21 -0
- omserv-0.0.0.dev7.dist-info/METADATA +21 -0
- omserv-0.0.0.dev7.dist-info/RECORD +44 -0
- omserv-0.0.0.dev7.dist-info/WHEEL +5 -0
- omserv-0.0.0.dev7.dist-info/top_level.txt +1 -0
@@ -0,0 +1,447 @@
|
|
1
|
+
import encodings.idna # prevents `LookupError: unknown encoding: idna` # noqa
|
2
|
+
import enum
|
3
|
+
import io
|
4
|
+
import logging
|
5
|
+
import time
|
6
|
+
import typing as ta
|
7
|
+
import urllib.parse
|
8
|
+
|
9
|
+
from omlish import check
|
10
|
+
import wsproto as wsp
|
11
|
+
import wsproto.events as wse
|
12
|
+
import wsproto.extensions
|
13
|
+
import wsproto.frame_protocol
|
14
|
+
import wsproto.utilities
|
15
|
+
|
16
|
+
from ..config import Config
|
17
|
+
from ..events import Body
|
18
|
+
from ..events import Data
|
19
|
+
from ..events import EndBody
|
20
|
+
from ..events import EndData
|
21
|
+
from ..events import ProtocolEvent
|
22
|
+
from ..events import Request
|
23
|
+
from ..events import Response
|
24
|
+
from ..events import StreamClosed
|
25
|
+
from ..taskspawner import TaskSpawner
|
26
|
+
from ..types import AsgiSendEvent
|
27
|
+
from ..types import AppWrapper
|
28
|
+
from ..types import WebsocketAcceptEvent
|
29
|
+
from ..types import WebsocketResponseBodyEvent
|
30
|
+
from ..types import WebsocketResponseStartEvent
|
31
|
+
from ..types import WebsocketScope
|
32
|
+
from ..workercontext import WorkerContext
|
33
|
+
from .httpstream import UnexpectedMessageError
|
34
|
+
from .utils import build_and_validate_headers
|
35
|
+
from .utils import log_access
|
36
|
+
from .utils import suppress_body
|
37
|
+
from .utils import valid_server_name
|
38
|
+
|
39
|
+
|
40
|
+
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
41
|
+
|
42
|
+
|
43
|
+
class AsgiWebsocketState(enum.Enum):
    """Lifecycle states of an ASGI websocket stream.

    Besides the accept / close path this also models the ASGI websocket HTTP response extension
    (``websocket.http.response``), which lets the application answer the handshake with a plain HTTP response instead
    of accepting the websocket.
    """

    # Request received; the application has neither accepted nor rejected it yet.
    HANDSHAKE = enum.auto()

    # The application accepted the handshake; websocket frames are being exchanged.
    CONNECTED = enum.auto()

    # An HTTP rejection response has been started and more body chunks are expected.
    RESPONSE = enum.auto()

    # The application sent ``websocket.close`` outside of the handshake.
    CLOSED = enum.auto()

    # The exchange ended as an HTTP response (rejection body finished, or close before accept).
    HTTPCLOSED = enum.auto()
|
50
|
+
|
51
|
+
|
52
|
+
class FrameTooLargeError(Exception):
    """Raised by ``WebsocketBuffer.extend`` once a buffered incoming message grows past its ``max_length``."""
|
54
|
+
|
55
|
+
|
56
|
+
class Handshake:
    """Server-side view of a websocket opening handshake.

    Collects the handshake-related request headers, validates them (:meth:`is_valid`) and, once the application
    accepts, builds the response status and headers plus a server-side wsproto ``Connection`` (:meth:`accept`).
    """

    def __init__(
        self,
        headers: list[tuple[bytes, bytes]],
        http_version: str,
    ) -> None:
        """
        :param headers: Raw request headers as (name, value) byte pairs; names are matched case-insensitively.
        :param http_version: HTTP version string of the request (compared as a string against '1.1').
        """
        super().__init__()

        self.http_version = http_version
        self.connection_tokens: list[str] | None = None
        self.extensions: list[str] | None = None
        self.key: bytes | None = None
        self.subprotocols: list[str] | None = None
        self.upgrade: bytes | None = None
        self.version: bytes | None = None

        for name, value in headers:
            name = name.lower()

            # Connection, Sec-WebSocket-Extensions and Sec-WebSocket-Protocol are comma-separated list headers that
            # may legitimately appear more than once (RFC 6455 section 11.3). Repeated occurrences are equivalent to a
            # single combined header, so their tokens are accumulated rather than keeping only the last occurrence.
            if name == b'connection':
                self.connection_tokens = (self.connection_tokens or []) + wsp.utilities.split_comma_header(value)

            elif name == b'sec-websocket-extensions':
                self.extensions = (self.extensions or []) + wsp.utilities.split_comma_header(value)

            elif name == b'sec-websocket-key':
                self.key = value

            elif name == b'sec-websocket-protocol':
                self.subprotocols = (self.subprotocols or []) + wsp.utilities.split_comma_header(value)

            elif name == b'sec-websocket-version':
                self.version = value

            elif name == b'upgrade':
                self.upgrade = value

    def is_valid(self) -> bool:
        """Return whether the collected headers form an acceptable websocket handshake.

        Versions below '1.1' are rejected. For '1.1' a Sec-WebSocket-Key, an ``upgrade`` Connection token and an
        ``Upgrade: websocket`` header are required. For every version not rejected outright, Sec-WebSocket-Version
        must equal wsproto's ``WEBSOCKET_VERSION``.
        """
        # String comparison: '1.0' < '1.1' < '2'.
        if self.http_version < '1.1':
            return False

        elif self.http_version == '1.1':
            if self.key is None:
                return False

            if self.connection_tokens is None or not any(
                token.lower() == 'upgrade' for token in self.connection_tokens
            ):
                return False

            if (self.upgrade or b'').lower() != b'websocket':
                return False

        if self.version != wsp.handshake.WEBSOCKET_VERSION:
            return False

        return True

    def accept(
        self,
        subprotocol: str | None,
        additional_headers: ta.Iterable[tuple[bytes, bytes]],
    ) -> tuple[int, list[tuple[bytes, bytes]], wsp.Connection]:
        """Build the response for an accepted handshake.

        :param subprotocol: Subprotocol chosen by the application; it must be one the client offered.
        :param additional_headers: Extra response headers supplied by the application.
        :returns: ``(status_code, headers, connection)``: 101 with upgrade headers for HTTP/1.1, 200 otherwise, plus a
            server-side wsproto ``Connection``.
        :raises Exception: If the subprotocol was not offered by the client, or an additional header is
            Sec-WebSocket-Protocol (any case) or a pseudo-header (name starting with ':').
        """
        headers = []
        if subprotocol is not None:
            if self.subprotocols is None or subprotocol not in self.subprotocols:
                raise Exception('Invalid Subprotocol')

            headers.append((b'sec-websocket-protocol', subprotocol.encode()))

        # Only permessage-deflate is offered. The same extension list is used for negotiation against the client's
        # offer and handed to the Connection below.
        extensions: list[wsp.extensions.Extension] = [wsp.extensions.PerMessageDeflate()]
        accepts = None
        if self.extensions is not None:
            accepts = wsp.handshake.server_extensions_handshake(self.extensions, extensions)

        if accepts:
            headers.append((b'sec-websocket-extensions', accepts))

        if self.key is not None:
            headers.append((b'sec-websocket-accept', wsp.utilities.generate_accept_token(self.key)))

        status_code = 200
        if self.http_version == '1.1':
            headers.extend([(b'upgrade', b'WebSocket'), (b'connection', b'Upgrade')])
            status_code = 101

        for name, value in additional_headers:
            # Header names are case-insensitive. Compare lowercased so that e.g. b'Sec-WebSocket-Protocol' cannot
            # smuggle a second, conflicting subprotocol header past this check.
            if name.lower() == b'sec-websocket-protocol' or name.startswith(b':'):
                raise Exception(f'Invalid additional header, {name.decode()}')

            headers.append((name, value))

        return status_code, headers, wsp.Connection(wsp.ConnectionType.SERVER, extensions)
|
149
|
+
|
150
|
+
|
151
|
+
class WebsocketBuffer:
    """Accumulates the fragments of one incoming websocket message, enforcing a maximum length.

    The first fragment picks the storage: text messages go into a ``StringIO``, anything else into a ``BytesIO``.
    """

    def __init__(self, max_length: int) -> None:
        super().__init__()

        self.value: io.BytesIO | io.StringIO | None = None  # None until the first fragment arrives
        self.length = 0  # running total of what write() reported for the current message
        self.max_length = max_length

    def extend(self, event: wse.Message) -> None:
        """Append one message fragment.

        :raises FrameTooLargeError: Once the accumulated length exceeds ``max_length``; the fragment has already been
            written at that point.
        """
        buf = self.value
        if buf is None:
            buf = io.StringIO() if isinstance(event, wse.TextMessage) else io.BytesIO()
            self.value = buf

        # NOTE(review): for text, StringIO.write() returns a character count, so for text messages max_length bounds
        # characters rather than encoded bytes. Confirm this is the intended unit.
        self.length += buf.write(event.data)

        if self.length > self.max_length:
            raise FrameTooLargeError

    def clear(self) -> None:
        """Drop the buffered message so the next fragment starts a new one."""
        self.length = 0
        self.value = None

    def to_message(self) -> dict:
        """Render the buffered message as an ASGI ``websocket.receive`` event dict.

        Only the key matching the buffer type ('bytes' or 'text') is populated. Both are None if nothing was buffered.
        """
        buf = self.value
        payload_bytes = buf.getvalue() if isinstance(buf, io.BytesIO) else None
        payload_text = buf.getvalue() if isinstance(buf, io.StringIO) else None
        return {
            'type': 'websocket.receive',
            'bytes': payload_bytes,
            'text': payload_text,
        }
|
181
|
+
|
182
|
+
|
183
|
+
class WsStream:
    """ASGI websocket stream bound to a single HTTP request stream.

    Protocol events from the connection are fed to :meth:`handle`. Messages from the ASGI application arrive through
    :meth:`app_send`. Websocket framing is delegated to the wsproto ``Connection`` created when the application
    accepts; the accept response is a 101 upgrade for HTTP/1.1 and a 200 otherwise (see ``Handshake.accept``).
    """

    def __init__(
        self,
        app: AppWrapper,
        config: Config,
        context: WorkerContext,
        task_spawner: TaskSpawner,
        client: tuple[str, int] | None,
        server: tuple[str, int] | None,
        send: ta.Callable[[ProtocolEvent], ta.Awaitable[None]],
        stream_id: int,
    ) -> None:
        super().__init__()

        self.app = app
        self.app_put: ta.Callable | None = None  # set once the application task has been spawned
        self.buffer = WebsocketBuffer(config.websocket_max_message_size)
        self.client = client
        self.closed = False
        self.config = config
        self.context = context
        self.task_spawner = task_spawner
        self.response: WebsocketResponseStartEvent  # set by a websocket.http.response.start message
        self.scope: WebsocketScope  # set when the Request event arrives
        self.send = send
        # RFC 8441 for HTTP/2 says use http or https, Asgi says ws or wss
        self.scheme = 'ws'  # TODO: 'wss' if ssl else 'ws'
        self.server = server
        self.start_time: float  # set when the Request event arrives; used for access-log durations
        self.state = AsgiWebsocketState.HANDSHAKE
        self.stream_id = stream_id

        self.connection: wsp.Connection  # set on accept
        self.handshake: Handshake  # set when the Request event arrives

    @property
    def idle(self) -> bool:
        """Whether the stream has reached a terminal state (websocket closed or HTTP response finished)."""
        return self.state in {AsgiWebsocketState.CLOSED, AsgiWebsocketState.HTTPCLOSED}

    async def handle(self, event: ProtocolEvent) -> None:
        """Process one protocol event for this stream; events are ignored once the stream is closed.

        - ``Request``: parse the handshake and build the scope, then reject with 404 (``valid_server_name`` fails) or
          400 (invalid handshake), or spawn the application and deliver ``websocket.connect``.
        - ``Body`` / ``Data``: feed the bytes to wsproto and dispatch the resulting websocket events.
        - ``StreamClosed``: mark the stream closed and send ``websocket.disconnect`` to a running application.
        """
        if self.closed:
            return

        elif isinstance(event, Request):
            self.start_time = time.time()
            self.handshake = Handshake(event.headers, event.http_version)
            path, _, query_string = event.raw_path.partition(b'?')
            self.scope = {
                'type': 'websocket',
                'asgi': {'spec_version': '2.3', 'version': '3.0'},
                'scheme': self.scheme,
                'http_version': event.http_version,
                'path': urllib.parse.unquote(path.decode('ascii')),
                'raw_path': path,
                'query_string': query_string,
                # 'root_path': self.config.root_path,
                'client': self.client,
                'server': self.server,
                'subprotocols': self.handshake.subprotocols or [],
                # Advertises support for HTTP rejection responses (websocket.http.response.*).
                'extensions': {'websocket.http.response': {}},
            }

            if not valid_server_name(self.config, event):
                await self._send_error_response(404)
                self.closed = True

            elif not self.handshake.is_valid():
                await self._send_error_response(400)
                self.closed = True

            else:
                self.app_put = await self.task_spawner.spawn_app(self.app, self.config, self.scope, self.app_send)
                await self.app_put({'type': 'websocket.connect'})

        elif isinstance(event, (Body, Data)):
            # NOTE(review): assumes data only arrives after accept; self.connection is unset before that.
            self.connection.receive_data(event.data)
            await self._handle_events()

        elif isinstance(event, StreamClosed):
            self.closed = True

            if self.app_put is not None:
                # A close after a terminal state counts as normal, anything else as abnormal.
                if self.state in {AsgiWebsocketState.HTTPCLOSED, AsgiWebsocketState.CLOSED}:
                    code = wsp.frame_protocol.CloseReason.NORMAL_CLOSURE.value
                else:
                    code = wsp.frame_protocol.CloseReason.ABNORMAL_CLOSURE.value

                await self.app_put({'type': 'websocket.disconnect', 'code': code})

    async def app_send(self, message: AsgiSendEvent | None) -> None:
        """Handle a message from the ASGI application; ``None`` signals that the application has finished.

        :raises TypeError: For a ``websocket.send`` without ``bytes`` whose ``text`` is not a str.
        :raises UnexpectedMessageError: For a message type that is not valid in the current state.
        """
        if self.closed:
            # Allow app to finish after close
            return

        if message is None:  # Asgi App has finished sending messages
            # Cleanup if required
            if self.state == AsgiWebsocketState.HANDSHAKE:
                # The app finished without accepting or rejecting. _send_error_response writes the access-log entry
                # for this 500 itself, so it is not logged a second time here.
                await self._send_error_response(500)

            elif self.state == AsgiWebsocketState.CONNECTED:
                await self._send_wsproto_event(wse.CloseConnection(code=wsp.frame_protocol.CloseReason.INTERNAL_ERROR))

            await self.send(StreamClosed(stream_id=self.stream_id))

        elif message['type'] == 'websocket.accept' and self.state == AsgiWebsocketState.HANDSHAKE:
            await self._accept(message)

        elif (
            message['type'] == 'websocket.http.response.start'
            and self.state == AsgiWebsocketState.HANDSHAKE
        ):
            # Only remembered here; status and headers go out together with the first body chunk.
            self.response = message

        elif message['type'] == 'websocket.http.response.body' and self.state in {
            AsgiWebsocketState.HANDSHAKE,
            AsgiWebsocketState.RESPONSE,
        }:
            await self._send_rejection(message)

        elif message['type'] == 'websocket.send' and self.state == AsgiWebsocketState.CONNECTED:
            event: wse.Event
            if message.get('bytes') is not None:
                event = wse.BytesMessage(data=bytes(message['bytes']))

            else:
                # ASGI allows either payload key to be absent. Reading 'text' with .get() makes a message carrying
                # neither payload raise the intended TypeError rather than a KeyError.
                text = message.get('text')
                if not isinstance(text, str):
                    raise TypeError(f'{text} should be a str')
                event = wse.TextMessage(data=text)

            await self._send_wsproto_event(event)

        elif (
            message['type'] == 'websocket.close' and self.state == AsgiWebsocketState.HANDSHAKE
        ):
            # Closing before accepting rejects the handshake with a 403.
            self.state = AsgiWebsocketState.HTTPCLOSED
            await self._send_error_response(403)

        elif message['type'] == 'websocket.close':
            self.state = AsgiWebsocketState.CLOSED

            await self._send_wsproto_event(wse.CloseConnection(
                code=int(message.get('code', wsp.frame_protocol.CloseReason.NORMAL_CLOSURE)),
                reason=message.get('reason'),
            ))

            await self.send(EndData(stream_id=self.stream_id))

        else:
            raise UnexpectedMessageError(self.state, message['type'])

    async def _handle_events(self) -> None:
        """Dispatch the websocket events wsproto parsed from received data.

        Message fragments are buffered and delivered to the application as one ``websocket.receive`` once complete.
        Pings are answered. A close from the peer is acknowledged (when wsproto is in REMOTE_CLOSING) and the stream
        is closed. A message exceeding the buffer limit triggers a MESSAGE_TOO_BIG close and stops processing this
        batch of events.
        """
        for event in self.connection.events():
            if isinstance(event, wse.Message):
                try:
                    self.buffer.extend(event)
                except FrameTooLargeError:
                    await self._send_wsproto_event(wse.CloseConnection(
                        code=wsp.frame_protocol.CloseReason.MESSAGE_TOO_BIG,
                    ))
                    break

                if event.message_finished:
                    await check.not_none(self.app_put)(self.buffer.to_message())
                    self.buffer.clear()

            elif isinstance(event, wse.Ping):
                await self._send_wsproto_event(event.response())

            elif isinstance(event, wse.CloseConnection):
                if self.connection.state == wsp.ConnectionState.REMOTE_CLOSING:
                    await self._send_wsproto_event(event.response())

                await self.send(StreamClosed(stream_id=self.stream_id))

    async def _send_error_response(self, status_code: int) -> None:
        """Send a body-less error response (``content-length: 0``, ``connection: close``), end it, and log it."""
        await self.send(
            Response(
                stream_id=self.stream_id,
                status_code=status_code,
                headers=[(b'content-length', b'0'), (b'connection', b'close')],
            ),
        )
        await self.send(EndBody(stream_id=self.stream_id))
        log_access(
            self.config,
            self.scope,
            {
                'status': status_code,
                'headers': [],
            },
            time.time() - self.start_time,
        )

    async def _send_wsproto_event(self, event: wse.Event) -> None:
        """Serialise ``event`` with wsproto and forward the bytes as a ``Data`` protocol event.

        This is best-effort: an event wsproto rejects with ``LocalProtocolError`` in the connection's current state is
        dropped without sending anything.
        """
        try:
            data = self.connection.send(event)
        except wsp.utilities.LocalProtocolError:
            # Deliberately ignored; see docstring.
            pass
        else:
            await self.send(Data(stream_id=self.stream_id, data=data))

    async def _accept(self, message: WebsocketAcceptEvent) -> None:
        """Accept the handshake: send the response, log it, and start pinging if configured.

        :raises Exception: Propagated from ``Handshake.accept`` for an unoffered subprotocol or a forbidden header.
        """
        # Build the response (which may raise) BEFORE leaving HANDSHAKE. If it raises, the stream stays in HANDSHAKE,
        # so the cleanup in app_send(None) answers with a 500. Otherwise that cleanup would try to close a wsproto
        # connection that was never created.
        status_code, headers, self.connection = self.handshake.accept(
            message.get('subprotocol'), message.get('headers', []),
        )
        self.state = AsgiWebsocketState.CONNECTED

        await self.send(Response(stream_id=self.stream_id, status_code=status_code, headers=headers))

        log_access(
            self.config,
            self.scope,
            {
                'status': status_code,
                'headers': [],
            },
            time.time() - self.start_time,
        )

        if self.config.websocket_ping_interval is not None:
            self.task_spawner.spawn(self._send_pings)

    async def _send_rejection(self, message: WebsocketResponseBodyEvent) -> None:
        """Send one chunk of an HTTP rejection response (ASGI ``websocket.http.response`` extension).

        On the first chunk (state HANDSHAKE) the stored response start is sent and the state moves to RESPONSE. The
        chunk's body is skipped when ``suppress_body('GET', status)`` is true. The final chunk (``more_body`` false)
        moves to HTTPCLOSED, ends the body and writes the access-log entry.
        """
        # NOTE(review): assumes websocket.http.response.start arrived first; otherwise self.response is unset here.
        body_suppressed = suppress_body('GET', self.response['status'])

        if self.state == AsgiWebsocketState.HANDSHAKE:
            headers = build_and_validate_headers(self.response['headers'])

            await self.send(
                Response(
                    stream_id=self.stream_id,
                    status_code=int(self.response['status']),
                    headers=headers,
                ),
            )

            self.state = AsgiWebsocketState.RESPONSE

        if not body_suppressed:
            await self.send(Body(stream_id=self.stream_id, data=bytes(message.get('body', b''))))

        if not message.get('more_body', False):
            self.state = AsgiWebsocketState.HTTPCLOSED

            await self.send(EndBody(stream_id=self.stream_id))

            log_access(
                self.config,
                self.scope,
                self.response,  # type: ignore
                time.time() - self.start_time,
            )

    async def _send_pings(self) -> None:
        """Send a websocket Ping every ``config.websocket_ping_interval`` until the stream is closed."""
        while not self.closed:
            await self._send_wsproto_event(wse.Ping())
            await self.context.sleep(check.not_none(self.config.websocket_ping_interval))
|
@@ -0,0 +1,111 @@
|
|
1
|
+
import logging
|
2
|
+
import types
|
3
|
+
import typing as ta
|
4
|
+
|
5
|
+
import anyio
|
6
|
+
import anyio.abc
|
7
|
+
import anyio.from_thread
|
8
|
+
import anyio.to_thread
|
9
|
+
|
10
|
+
from omlish import check
|
11
|
+
|
12
|
+
from .config import Config
|
13
|
+
from .debug import handle_error_debug
|
14
|
+
from .types import AppWrapper
|
15
|
+
from .types import AsgiReceiveCallable
|
16
|
+
from .types import AsgiReceiveEvent
|
17
|
+
from .types import AsgiSendEvent
|
18
|
+
from .types import Scope
|
19
|
+
|
20
|
+
|
21
|
+
# Module-level logger; _handle reports ASGI application errors through it.
log = logging.getLogger(__name__)
|
22
|
+
|
23
|
+
|
24
|
+
async def _handle(
    app: AppWrapper,
    config: Config,
    scope: Scope,
    receive: AsgiReceiveCallable,
    send: ta.Callable[[AsgiSendEvent | None], ta.Awaitable[None]],
    sync_spawn: ta.Callable,
    call_soon: ta.Callable,
) -> None:
    """Run one ASGI application invocation and signal its completion with exactly one ``send(None)``.

    Cancellation is re-raised, as is an exception group that contains only cancellations. Any other application
    error is passed to ``handle_error_debug`` and logged rather than propagated.

    :param config: Not used by this function; accepted so the spawner can pass its full context.
    """
    try:
        await app(
            scope,
            receive,
            send,
            sync_spawn,
            call_soon,
        )

    except anyio.get_cancelled_exc_class():
        raise

    except BaseExceptionGroup as error:
        handle_error_debug(error)

        _, other_errors = error.split(anyio.get_cancelled_exc_class())
        if other_errors is None:
            # Only cancellations: let them propagate.
            raise

        # The end-of-app signal is sent only by the ``finally`` clause below; sending it here too would deliver it
        # twice to the stream.
        log.exception('Error in Asgi Framework')

    except Exception as error:
        handle_error_debug(error)

        log.exception('Error in Asgi Framework')

    finally:
        await send(None)
|
62
|
+
|
63
|
+
|
64
|
+
class TaskSpawner:
    """Async context manager owning an anyio task group used to run ASGI applications and helper tasks.

    It must be entered with ``async with`` before any spawning method is used; otherwise those methods fail their
    ``check.not_none`` on the task group.
    """

    def __init__(self) -> None:
        super().__init__()
        self._task_group: anyio.abc.TaskGroup | None = None  # only set while the context manager is entered

    async def start(
        self,
        func: ta.Callable[..., ta.Awaitable[ta.Any]],
        *args: ta.Any,
    ) -> ta.Any:
        """Start ``func(*args)`` in the task group and wait until it reports readiness.

        Delegates to ``TaskGroup.start``. The result is whatever the task passed to ``task_status.started()``.
        """
        return await check.not_none(self._task_group).start(func, *args)

    async def spawn_app(
        self,
        app: AppWrapper,
        config: Config,
        scope: Scope,
        send: ta.Callable[[AsgiSendEvent | None], ta.Awaitable[None]],
    ) -> ta.Callable[[AsgiReceiveEvent], ta.Awaitable[None]]:
        """Run ``app`` for ``scope`` in the task group and return the callable that feeds it receive events.

        The application reads from an in-memory stream whose buffer size is ``config.max_app_queue_size``. Its
        ``sync_spawn`` / ``call_soon`` hooks are anyio's ``to_thread.run_sync`` / ``from_thread.run``.
        """
        app_send_channel, app_receive_channel = anyio.create_memory_object_stream[ta.Any](config.max_app_queue_size)
        check.not_none(self._task_group).start_soon(
            _handle,
            app,
            config,
            scope,
            app_receive_channel.receive,
            send,
            anyio.to_thread.run_sync,
            anyio.from_thread.run,
        )
        return app_send_channel.send

    def spawn(self, func: ta.Callable, *args: ta.Any) -> None:
        """Schedule ``func(*args)`` in the task group without waiting for it."""
        check.not_none(self._task_group).start_soon(func, *args)

    async def __aenter__(self) -> ta.Self:
        task_group = anyio.create_task_group()
        await task_group.__aenter__()
        # Publish the task group only once it has been entered successfully.
        self._task_group = task_group
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        tb: types.TracebackType | None,
    ) -> bool | None:
        """Exit the underlying task group and forward its result.

        Returning the task group's ``__aexit__`` result keeps suppressed exceptions suppressed, as a wrapping context
        manager should. The task group reference is cleared even if exiting raises.
        """
        try:
            return await check.not_none(self._task_group).__aexit__(exc_type, exc_value, tb)
        finally:
            self._task_group = None
|