omserv 0.0.0.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- omserv/__about__.py +28 -0
- omserv/__init__.py +0 -0
- omserv/apps/__init__.py +0 -0
- omserv/apps/base.py +23 -0
- omserv/apps/inject.py +89 -0
- omserv/apps/markers.py +41 -0
- omserv/apps/routes.py +139 -0
- omserv/apps/sessions.py +57 -0
- omserv/apps/templates.py +90 -0
- omserv/dbs.py +24 -0
- omserv/node/__init__.py +0 -0
- omserv/node/models.py +53 -0
- omserv/node/registry.py +124 -0
- omserv/node/sql.py +131 -0
- omserv/secrets.py +12 -0
- omserv/server/__init__.py +18 -0
- omserv/server/config.py +51 -0
- omserv/server/debug.py +14 -0
- omserv/server/events.py +83 -0
- omserv/server/headers.py +36 -0
- omserv/server/lifespans.py +132 -0
- omserv/server/multiprocess.py +157 -0
- omserv/server/protocols/__init__.py +1 -0
- omserv/server/protocols/h11.py +334 -0
- omserv/server/protocols/h2.py +407 -0
- omserv/server/protocols/protocols.py +91 -0
- omserv/server/protocols/types.py +18 -0
- omserv/server/resources/__init__.py +8 -0
- omserv/server/sockets.py +111 -0
- omserv/server/ssl.py +47 -0
- omserv/server/streams/__init__.py +0 -0
- omserv/server/streams/httpstream.py +237 -0
- omserv/server/streams/utils.py +53 -0
- omserv/server/streams/wsstream.py +447 -0
- omserv/server/taskspawner.py +111 -0
- omserv/server/tcpserver.py +173 -0
- omserv/server/types.py +94 -0
- omserv/server/workercontext.py +52 -0
- omserv/server/workers.py +193 -0
- omserv-0.0.0.dev7.dist-info/LICENSE +21 -0
- omserv-0.0.0.dev7.dist-info/METADATA +21 -0
- omserv-0.0.0.dev7.dist-info/RECORD +44 -0
- omserv-0.0.0.dev7.dist-info/WHEEL +5 -0
- omserv-0.0.0.dev7.dist-info/top_level.txt +1 -0
@@ -0,0 +1,407 @@
|
|
1
|
+
import typing as ta
|
2
|
+
|
3
|
+
import h2
|
4
|
+
import h2.config
|
5
|
+
import h2.connection
|
6
|
+
import h2.events
|
7
|
+
import h2.exceptions
|
8
|
+
import h2.settings
|
9
|
+
import priority
|
10
|
+
|
11
|
+
from ..config import Config
|
12
|
+
from ..events import Body
|
13
|
+
from ..events import Closed
|
14
|
+
from ..events import Data
|
15
|
+
from ..events import EndBody
|
16
|
+
from ..events import EndData
|
17
|
+
from ..events import InformationalResponse
|
18
|
+
from ..events import ProtocolEvent
|
19
|
+
from ..events import RawData
|
20
|
+
from ..events import Request
|
21
|
+
from ..events import Response
|
22
|
+
from ..events import ServerEvent
|
23
|
+
from ..events import StreamClosed
|
24
|
+
from ..events import Updated
|
25
|
+
from ..headers import filter_pseudo_headers
|
26
|
+
from ..headers import response_headers
|
27
|
+
from ..streams.httpstream import HttpStream
|
28
|
+
from ..streams.wsstream import WsStream
|
29
|
+
from ..taskspawner import TaskSpawner
|
30
|
+
from ..types import AppWrapper
|
31
|
+
from ..types import WaitableEvent
|
32
|
+
from ..workercontext import WorkerContext
|
33
|
+
from .types import Protocol
|
34
|
+
|
35
|
+
|
36
|
+
BUFFER_HIGH_WATER = 2 * 2**14  # Twice the default max frame size (two frames worth)
BUFFER_LOW_WATER = BUFFER_HIGH_WATER / 2


class BufferCompleteError(Exception):
    """Raised when data is pushed into a StreamBuffer already marked complete."""


class StreamBuffer:
    """Outbound per-stream byte buffer with high/low-water backpressure.

    ``push`` blocks (via ``_paused``) once the buffer reaches
    ``BUFFER_HIGH_WATER`` and is released after ``pop`` drains the buffer
    below ``BUFFER_LOW_WATER``.  ``drain`` blocks until the buffer has been
    fully emptied by the sender.
    """

    def __init__(self, event_class: 'type[WaitableEvent]') -> None:
        super().__init__()
        self.buffer = bytearray()
        self._complete = False          # producer has finished (EndBody/EndData seen)
        self._is_empty = event_class()  # set whenever the buffer is fully drained
        self._paused = event_class()    # set when the producer may resume pushing

    async def drain(self) -> None:
        """Wait until every buffered byte has been popped."""
        await self._is_empty.wait()

    def set_complete(self) -> None:
        """Mark the body as finished; further pushes raise BufferCompleteError."""
        self._complete = True

    async def close(self) -> None:
        """Force-close: discard buffered data and release any waiters."""
        self._complete = True
        self.buffer = bytearray()
        await self._is_empty.set()
        await self._paused.set()

    @property
    def complete(self) -> bool:
        # Complete only once the producer finished AND all data has been sent.
        return self._complete and len(self.buffer) == 0

    async def push(self, data: bytes) -> None:
        """Append *data*, blocking while the buffer is above the high-water mark."""
        if self._complete:
            raise BufferCompleteError
        self.buffer.extend(data)
        await self._is_empty.clear()
        if len(self.buffer) >= BUFFER_HIGH_WATER:
            await self._paused.wait()
            await self._paused.clear()

    async def pop(self, max_length: int) -> bytes:
        """Remove and return up to *max_length* leading bytes."""
        length = min(len(self.buffer), max_length)
        data = bytes(self.buffer[:length])
        del self.buffer[:length]
        # Bug fix: resume the producer based on how much data REMAINS buffered,
        # not on the size of the chunk just popped -- the original checked
        # ``len(data)``, so any small pop (e.g. a tiny flow-control window)
        # would defeat the backpressure while the buffer was still full.
        if len(self.buffer) < BUFFER_LOW_WATER:
            await self._paused.set()
        if len(self.buffer) == 0:
            await self._is_empty.set()
        return data
|
86
|
+
|
87
|
+
|
88
|
+
class H2Protocol(Protocol):
    """HTTP/2 server-side protocol driver.

    Feeds raw bytes into an ``h2`` connection state machine, fans incoming
    requests out to ``HttpStream``/``WsStream`` instances, and runs a
    dedicated ``send_task`` that drains per-stream ``StreamBuffer``s in
    priority order, respecting HTTP/2 flow control.
    """

    def __init__(
        self,
        app: AppWrapper,
        config: Config,
        context: WorkerContext,
        task_spawner: TaskSpawner,
        client: tuple[str, int] | None,
        server: tuple[str, int] | None,
        send: ta.Callable[[ServerEvent], ta.Awaitable[None]],
    ) -> None:
        super().__init__()

        self.app = app
        self.client = client
        self.closed = False
        self.config = config
        self.context = context
        self.task_spawner = task_spawner

        # Server-side h2 state machine; header_encoding=None keeps headers as bytes.
        self.connection = h2.connection.H2Connection(
            config=h2.config.H2Configuration(client_side=False, header_encoding=None),
        )
        self.connection.DEFAULT_MAX_INBOUND_FRAME_SIZE = config.h2_max_inbound_frame_size
        self.connection.local_settings = h2.settings.Settings(
            client=False,
            initial_values={
                h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: config.h2_max_concurrent_streams,
                h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: config.h2_max_header_list_size,
                # Enables extended CONNECT (RFC 8441), used for WebSockets over h2.
                h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL: 1,
            },
        )

        self.keep_alive_requests = 0
        self.send = send
        self.server = server
        self.streams: dict[int, HttpStream | WsStream] = {}
        # The below are used by the sending task
        self.has_data = self.context.event_class()
        self.priority = priority.PriorityTree()
        self.stream_buffers: dict[int, StreamBuffer] = {}

    @property
    def idle(self) -> bool:
        """True when no stream is actively processing a request."""
        return len(self.streams) == 0 or all(stream.idle for stream in self.streams.values())

    async def initiate(
        self, headers: list[tuple[bytes, bytes]] | None = None, settings: str | None = None,
    ) -> None:
        """Start the connection.

        *settings* is the base64 HTTP2-Settings value of an h2c upgrade; when
        *headers* are given, a synthetic stream 1 is created for the request
        that triggered the upgrade.
        """
        if settings is not None:
            self.connection.initiate_upgrade_connection(settings)
        else:
            self.connection.initiate_connection()
        await self._flush()
        if headers is not None:
            # Replay the upgrading HTTP/1.1 request as stream 1 (no body).
            event = h2.events.RequestReceived()
            event.stream_id = 1
            event.headers = headers
            await self._create_stream(event)
            await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id))
        self.task_spawner.spawn(self.send_task)

    async def send_task(self) -> None:
        # This should be run in a separate task to the rest of this class. This allows it
        # separately choose when to send, crucially in what order.
        while not self.closed:
            try:
                stream_id = next(self.priority)
            except priority.DeadlockError:
                # Every stream is blocked: sleep until new data is signalled.
                await self.has_data.wait()
                await self.has_data.clear()
            else:
                await self._send_data(stream_id)

    async def _send_data(self, stream_id: int) -> None:
        """Send one flow-control-window-limited chunk for *stream_id*."""
        try:
            chunk_size = min(
                self.connection.local_flow_control_window(stream_id),
                self.connection.max_outbound_frame_size,
            )
            chunk_size = max(0, chunk_size)
            data = await self.stream_buffers[stream_id].pop(chunk_size)
            if data:
                self.connection.send_data(stream_id, data)
                await self._flush()
            else:
                # Nothing sendable right now: park the stream in the priority tree.
                self.priority.block(stream_id)

            if self.stream_buffers[stream_id].complete:
                self.connection.end_stream(stream_id)
                await self._flush()
                del self.stream_buffers[stream_id]
                self.priority.remove_stream(stream_id)
        except (h2.exceptions.StreamClosedError, KeyError, h2.exceptions.ProtocolError):
            # Stream or connection has closed whilst waiting to send data, not a problem -
            # just force close it.
            await self.stream_buffers[stream_id].close()
            del self.stream_buffers[stream_id]
            self.priority.remove_stream(stream_id)

    async def handle(self, event: ServerEvent) -> None:
        """Dispatch a transport-level event into the h2 state machine."""
        if isinstance(event, RawData):
            try:
                events = self.connection.receive_data(event.data)
            except h2.exceptions.ProtocolError:
                # Unrecoverable framing error: flush whatever h2 queued
                # (e.g. a GOAWAY) and close the connection.
                await self._flush()
                await self.send(Closed())
            else:
                await self._handle_events(events)

        elif isinstance(event, Closed):
            self.closed = True
            stream_ids = list(self.streams.keys())
            for stream_id in stream_ids:
                await self._close_stream(stream_id)
            # Wake send_task so it can observe self.closed and exit.
            await self.has_data.set()

    async def stream_send(self, event: ProtocolEvent) -> None:
        """Entry point for streams to emit response events onto the connection."""
        try:
            if isinstance(event, (InformationalResponse, Response)):
                self.connection.send_headers(
                    event.stream_id,
                    [
                        (b':status', b'%d' % event.status_code),
                        *event.headers,
                        *response_headers(self.config, 'h2'),
                    ],
                )
                await self._flush()

            elif isinstance(event, (Body, Data)):
                self.priority.unblock(event.stream_id)
                await self.has_data.set()
                # May block here on backpressure (see StreamBuffer.push).
                await self.stream_buffers[event.stream_id].push(event.data)

            elif isinstance(event, (EndBody, EndData)):
                self.stream_buffers[event.stream_id].set_complete()
                self.priority.unblock(event.stream_id)
                await self.has_data.set()
                await self.stream_buffers[event.stream_id].drain()

            elif isinstance(event, StreamClosed):
                await self._close_stream(event.stream_id)
                # Same computation as the `idle` property.
                idle = len(self.streams) == 0 or all(
                    stream.idle for stream in self.streams.values()
                )
                if idle and self.context.terminated.is_set():
                    self.connection.close_connection()
                    await self._flush()
                await self.send(Updated(idle=idle))

            elif isinstance(event, Request):
                # Streams emit Request events to ask for a server push.
                await self._create_server_push(event.stream_id, event.raw_path, event.headers)

        except (
            BufferCompleteError,
            KeyError,
            priority.MissingStreamError,
            h2.exceptions.ProtocolError,
        ):
            # Connection has closed whilst blocked on flow control or connection has advanced
            # ahead of the last emitted event.
            return

    async def _handle_events(self, events: list[h2.events.Event]) -> None:
        """Process the events returned by ``connection.receive_data``."""
        for event in events:
            if isinstance(event, h2.events.RequestReceived):
                if self.context.terminated.is_set():
                    # Worker is shutting down: refuse the stream and advertise
                    # zero concurrency so the client stops opening new ones.
                    self.connection.reset_stream(event.stream_id)
                    self.connection.update_settings(
                        {h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 0},
                    )
                else:
                    await self._create_stream(event)
                    await self.send(Updated(idle=False))

                if self.keep_alive_requests > self.config.keep_alive_max_requests:
                    self.connection.close_connection()

            elif isinstance(event, h2.events.DataReceived):
                await self.streams[event.stream_id].handle(Body(
                    stream_id=event.stream_id,
                    data=event.data,
                ))
                # Return flow-control credit to the peer for the consumed bytes.
                self.connection.acknowledge_received_data(
                    event.flow_controlled_length, event.stream_id,
                )

            elif isinstance(event, h2.events.StreamEnded):
                await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id))

            elif isinstance(event, h2.events.StreamReset):
                await self._close_stream(event.stream_id)
                await self._window_updated(event.stream_id)

            elif isinstance(event, h2.events.WindowUpdated):
                await self._window_updated(event.stream_id)

            elif isinstance(event, h2.events.PriorityUpdated):
                await self._priority_updated(event)

            elif isinstance(event, h2.events.RemoteSettingsChanged):
                if h2.settings.SettingCodes.INITIAL_WINDOW_SIZE in event.changed_settings:
                    # A changed initial window may unblock every stream.
                    await self._window_updated(None)

            elif isinstance(event, h2.events.ConnectionTerminated):
                await self.send(Closed())

        await self._flush()

    async def _flush(self) -> None:
        """Forward any bytes h2 has queued to the transport."""
        data = self.connection.data_to_send()
        if data != b'':
            await self.send(RawData(data=data))

    async def _window_updated(self, stream_id: int | None) -> None:
        """Unblock the stream(s) whose send window grew; None/0 means all."""
        if stream_id is None or stream_id == 0:
            # Unblock all streams
            for buf_stream_id in list(self.stream_buffers.keys()):
                self.priority.unblock(buf_stream_id)
        elif stream_id is not None and stream_id in self.stream_buffers:
            self.priority.unblock(stream_id)
        await self.has_data.set()

    async def _priority_updated(self, event: h2.events.PriorityUpdated) -> None:
        """Apply a client PRIORITY frame to the local priority tree."""
        try:
            self.priority.reprioritize(
                stream_id=event.stream_id,
                depends_on=event.depends_on or None,
                weight=event.weight,
                exclusive=event.exclusive,
            )
        except priority.MissingStreamError:
            # Received PRIORITY frame before HEADERS frame
            self.priority.insert_stream(
                stream_id=event.stream_id,
                depends_on=event.depends_on or None,
                weight=event.weight,
                exclusive=event.exclusive,
            )
            self.priority.block(event.stream_id)
        await self.has_data.set()

    async def _create_stream(self, request: h2.events.RequestReceived) -> None:
        """Create the stream handler and buffer for an incoming request.

        NOTE(review): assumes the :method and :path pseudo-headers are always
        present (h2 validates this for real requests); ``method``/``raw_path``
        would be unbound otherwise -- confirm for synthetic callers.
        """
        for name, value in request.headers:
            if name == b':method':
                method = value.decode('ascii').upper()
            elif name == b':path':
                raw_path = value
        # Extended CONNECT (RFC 8441) marks a WebSocket handshake.
        if method == 'CONNECT':
            self.streams[request.stream_id] = WsStream(
                self.app,
                self.config,
                self.context,
                self.task_spawner,
                self.client,
                self.server,
                self.stream_send,
                request.stream_id,
            )
        else:
            self.streams[request.stream_id] = HttpStream(
                self.app,
                self.config,
                self.context,
                self.task_spawner,
                self.client,
                self.server,
                self.stream_send,
                request.stream_id,
            )
        self.stream_buffers[request.stream_id] = StreamBuffer(self.context.event_class)
        try:
            self.priority.insert_stream(request.stream_id)
        except priority.DuplicateStreamError:
            # Received PRIORITY frame before HEADERS frame
            pass
        else:
            self.priority.block(request.stream_id)

        await self.streams[request.stream_id].handle(Request(
            stream_id=request.stream_id,
            headers=filter_pseudo_headers(request.headers),
            http_version='2',
            method=method,
            raw_path=raw_path,
        ))
        self.keep_alive_requests += 1
        await self.context.mark_request()

    async def _create_server_push(
        self, stream_id: int, path: bytes, headers: list[tuple[bytes, bytes]],
    ) -> None:
        """Issue a PUSH_PROMISE for *path* and run it as a synthetic GET stream."""
        push_stream_id = self.connection.get_next_available_stream_id()
        request_headers = [(b':method', b'GET'), (b':path', path)]
        request_headers.extend(headers)
        request_headers.extend(response_headers(self.config, 'h2'))
        try:
            self.connection.push_stream(
                stream_id=stream_id,
                promised_stream_id=push_stream_id,
                request_headers=request_headers,
            )
            await self._flush()
        except h2.exceptions.ProtocolError:
            # Client does not accept push promises or we are trying to push on a push
            # promises request.
            pass
        else:
            event = h2.events.RequestReceived()
            event.stream_id = push_stream_id
            event.headers = request_headers
            await self._create_stream(event)
            await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id))
            self.keep_alive_requests += 1

    async def _close_stream(self, stream_id: int) -> None:
        """Notify and discard the handler for *stream_id*, waking the send task."""
        if stream_id in self.streams:
            stream = self.streams.pop(stream_id)
            await stream.handle(StreamClosed(stream_id=stream_id))
            await self.has_data.set()
|
@@ -0,0 +1,91 @@
|
|
1
|
+
import typing as ta
|
2
|
+
|
3
|
+
from ..config import Config
|
4
|
+
from ..events import RawData
|
5
|
+
from ..events import ServerEvent
|
6
|
+
from ..taskspawner import TaskSpawner
|
7
|
+
from ..types import AppWrapper
|
8
|
+
from ..workercontext import WorkerContext
|
9
|
+
from .h2 import H2Protocol
|
10
|
+
from .h11 import H2CProtocolRequiredError
|
11
|
+
from .h11 import H2ProtocolAssumedError
|
12
|
+
from .h11 import H11Protocol
|
13
|
+
from .types import Protocol
|
14
|
+
|
15
|
+
|
16
|
+
class ProtocolWrapper:
    """Selects and hosts the wire protocol for a single connection.

    Starts as HTTP/2 when negotiated via ALPN, otherwise HTTP/1.1, and
    transparently upgrades to HTTP/2 mid-connection when the H11 protocol
    raises either upgrade error (prior-knowledge preface or ``Upgrade: h2c``).
    """

    def __init__(
        self,
        app: AppWrapper,
        config: Config,
        context: WorkerContext,
        task_spawner: TaskSpawner,
        client: tuple[str, int] | None,
        server: tuple[str, int] | None,
        send: ta.Callable[[ServerEvent], ta.Awaitable[None]],
        alpn_protocol: str | None = None,
    ) -> None:
        super().__init__()
        self.app = app
        self.config = config
        self.context = context
        self.task_spawner = task_spawner
        self.client = client
        self.server = server
        self.send = send
        self.protocol: Protocol
        if alpn_protocol == 'h2':
            self.protocol = self._new_h2_protocol()
        else:
            self.protocol = H11Protocol(
                self.app,
                self.config,
                self.context,
                self.task_spawner,
                self.client,
                self.server,
                self.send,
            )

    def _new_h2_protocol(self) -> H2Protocol:
        # All three construction sites (ALPN negotiation, prior-knowledge
        # upgrade, h2c upgrade) build H2Protocol with identical arguments.
        return H2Protocol(
            self.app,
            self.config,
            self.context,
            self.task_spawner,
            self.client,
            self.server,
            self.send,
        )

    async def initiate(self) -> None:
        """Start the currently selected protocol."""
        return await self.protocol.initiate()

    async def handle(self, event: ServerEvent) -> None:
        """Forward *event* to the active protocol, upgrading to h2 on demand."""
        try:
            return await self.protocol.handle(event)

        except H2ProtocolAssumedError as error:
            # Client sent the HTTP/2 preface over cleartext: restart as h2 and
            # replay the bytes that triggered the switch.
            self.protocol = self._new_h2_protocol()
            await self.protocol.initiate()
            if error.data != b'':
                return await self.protocol.handle(RawData(data=error.data))

        except H2CProtocolRequiredError as error:
            # HTTP/1.1 ``Upgrade: h2c`` request: restart as h2 using the
            # upgrade request's headers and HTTP2-Settings, then replay any
            # remaining bytes.
            self.protocol = self._new_h2_protocol()
            await self.protocol.initiate(error.headers, error.settings)
            if error.data != b'':
                return await self.protocol.handle(RawData(data=error.data))
|
@@ -0,0 +1,18 @@
|
|
1
|
+
import abc
|
2
|
+
|
3
|
+
from ..events import ProtocolEvent
|
4
|
+
from ..events import ServerEvent
|
5
|
+
|
6
|
+
|
7
|
+
class Protocol(abc.ABC):
    """Abstract interface shared by the HTTP/1.1 and HTTP/2 protocol drivers."""

    @abc.abstractmethod
    async def initiate(self) -> None:
        """Perform any startup/handshake work before events can be handled."""
        raise NotImplementedError

    @abc.abstractmethod
    async def handle(self, event: ServerEvent) -> None:
        """Process an event arriving from the transport (raw data, close, ...)."""
        raise NotImplementedError

    @abc.abstractmethod
    async def stream_send(self, event: ProtocolEvent) -> None:
        """Send a stream-level event produced by an individual request stream."""
        raise NotImplementedError
|
omserv/server/sockets.py
ADDED
@@ -0,0 +1,111 @@
|
|
1
|
+
import contextlib
|
2
|
+
import dataclasses as dc
|
3
|
+
import os
|
4
|
+
import socket
|
5
|
+
import stat
|
6
|
+
import typing as ta
|
7
|
+
|
8
|
+
from .config import Config
|
9
|
+
|
10
|
+
|
11
|
+
@dc.dataclass()
class Sockets:
    """Bundle of listening sockets handed to each server worker."""

    # Plain (non-TLS) listening sockets, one per configured bind.
    insecure_sockets: list[socket.socket]
|
14
|
+
|
15
|
+
|
16
|
+
SocketKind: ta.TypeAlias = int | socket.SocketKind
|
17
|
+
|
18
|
+
|
19
|
+
class SocketTypeError(Exception):
|
20
|
+
def __init__(self, expected: SocketKind, actual: SocketKind) -> None:
|
21
|
+
super().__init__(
|
22
|
+
f'Unexpected socket type, wanted "{socket.SocketKind(expected)}" got "{socket.SocketKind(actual)}"',
|
23
|
+
)
|
24
|
+
|
25
|
+
|
26
|
+
def _create_sockets(
        config: Config,
        binds: ta.Sequence[str],
        type_: int = socket.SOCK_STREAM,
) -> list[socket.socket]:
    """Create one bound socket per bind spec in *binds*.

    Supported spec forms:
      * ``unix:<path>``        -- unix domain socket (a stale socket file is removed)
      * ``fd://<num>``         -- pre-opened socket inherited from a parent process
      * ``host:port`` / ``host`` -- TCP socket (port defaults to 8000)

    Sockets are bound, non-blocking, and inheritable, but NOT yet listening.

    NOTE(review): TCP_NODELAY is set unconditionally on the TCP path --
    presumably only SOCK_STREAM binds reach it; confirm for datagram use.
    """
    sockets: list[socket.socket] = []
    for bind in binds:
        binding: ta.Any = None

        if bind.startswith('unix:'):
            sock = socket.socket(socket.AF_UNIX, type_)
            binding = bind[5:]
            try:
                # Remove a stale socket file from a previous run, but never an
                # unrelated regular file at the same path.
                if stat.S_ISSOCK(os.stat(binding).st_mode):
                    os.remove(binding)
            except FileNotFoundError:
                pass

        elif bind.startswith('fd://'):
            sock = socket.socket(fileno=int(bind[5:]))
            actual_type = sock.getsockopt(socket.SOL_SOCKET, socket.SO_TYPE)
            if actual_type != type_:
                raise SocketTypeError(type_, actual_type)

        else:
            # Strip IPv6 brackets, then split off an optional trailing port.
            bind = bind.replace('[', '').replace(']', '')
            try:
                value = bind.rsplit(':', 1)
                host, port = value[0], int(value[1])
            except (ValueError, IndexError):
                host, port = bind, 8000
            sock = socket.socket(socket.AF_INET6 if ':' in host else socket.AF_INET, type_)
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            if config.workers:
                # SO_REUSEPORT lets multiple worker processes share the port;
                # skipped on platforms where the constant doesn't exist.
                with contextlib.suppress(AttributeError):
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
            binding = (host, port)

        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

        if bind.startswith('unix:'):
            # Apply the configured umask/ownership just for the bind, then restore.
            if config.umask is not None:
                current_umask = os.umask(config.umask)
            sock.bind(binding)
            if config.user is not None and config.group is not None:
                os.chown(binding, config.user, config.group)
            if config.umask is not None:
                os.umask(current_umask)

        elif bind.startswith('fd://'):
            # Already bound by whichever process opened the fd.
            pass

        else:
            sock.bind(binding)

        sock.setblocking(False)
        with contextlib.suppress(AttributeError):
            sock.set_inheritable(True)
        sockets.append(sock)

    return sockets
|
87
|
+
|
88
|
+
|
89
|
+
def create_sockets(config: Config) -> Sockets:
    """Create all listening sockets described by ``config.bind``."""
    insecure_sockets = _create_sockets(config, config.bind)
    return Sockets(insecure_sockets)
|
92
|
+
|
93
|
+
|
94
|
+
def repr_socket_addr(family: int, address: tuple) -> str:
|
95
|
+
if family == socket.AF_INET:
|
96
|
+
return f'{address[0]}:{address[1]}'
|
97
|
+
elif family == socket.AF_INET6:
|
98
|
+
return f'[{address[0]}]:{address[1]}'
|
99
|
+
elif family == socket.AF_UNIX:
|
100
|
+
return f'unix:{address}'
|
101
|
+
else:
|
102
|
+
return f'{address}'
|
103
|
+
|
104
|
+
|
105
|
+
def parse_socket_addr(family: int, address: tuple) -> tuple[str, int] | None:
|
106
|
+
if family == socket.AF_INET:
|
107
|
+
return address
|
108
|
+
elif family == socket.AF_INET6:
|
109
|
+
return (address[0], address[1])
|
110
|
+
else:
|
111
|
+
return None
|
omserv/server/ssl.py
ADDED
@@ -0,0 +1,47 @@
|
|
1
|
+
import dataclasses as dc
|
2
|
+
import ssl
|
3
|
+
import typing as ta
|
4
|
+
|
5
|
+
from .config import SECONDS
|
6
|
+
|
7
|
+
|
8
|
+
@dc.dataclass(frozen=True, kw_only=True)
class SslConfig:
    """TLS settings used to build the server's ``ssl.SSLContext``."""

    # CA bundle used to verify client certificates (mutual TLS).
    ca_certs: str | None = None

    # Server certificate chain and private key, plus optional key password.
    certfile: str | None = None
    keyfile: str | None = None
    keyfile_password: str | None = None

    # OpenSSL cipher selection string.
    ciphers: str = 'ECDHE+AESGCM'

    # ALPN protocols offered during the handshake, most preferred first.
    alpn_protocols: ta.Sequence[str] = ('h2', 'http/1.1')

    # Certificate-verification tweaks; library defaults are kept when None.
    verify_flags: ssl.VerifyFlags | None = None
    verify_mode: ssl.VerifyMode | None = None

    # Handshake timeout in seconds.
    ssl_handshake_timeout: int | float = 60 * SECONDS
|
24
|
+
|
25
|
+
|
26
|
+
def create_ssl_context(ssl_cfg: SslConfig) -> ssl.SSLContext | None:
|
27
|
+
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
28
|
+
context.set_ciphers(ssl_cfg.ciphers)
|
29
|
+
context.minimum_version = ssl.TLSVersion.TLSv1_2 # RFC 7540 Section 9.2: MUST be TLS >=1.2
|
30
|
+
context.options = ssl.OP_NO_COMPRESSION # RFC 7540 Section 9.2.1: MUST disable compression
|
31
|
+
context.set_alpn_protocols(ssl_cfg.alpn_protocols)
|
32
|
+
|
33
|
+
if ssl_cfg.certfile is not None and ssl_cfg.keyfile is not None:
|
34
|
+
context.load_cert_chain(
|
35
|
+
certfile=ssl_cfg.certfile,
|
36
|
+
keyfile=ssl_cfg.keyfile,
|
37
|
+
password=ssl_cfg.keyfile_password,
|
38
|
+
)
|
39
|
+
|
40
|
+
if ssl_cfg.ca_certs is not None:
|
41
|
+
context.load_verify_locations(ssl_cfg.ca_certs)
|
42
|
+
if ssl_cfg.verify_mode is not None:
|
43
|
+
context.verify_mode = ssl_cfg.verify_mode
|
44
|
+
if ssl_cfg.verify_flags is not None:
|
45
|
+
context.verify_flags = ssl_cfg.verify_flags
|
46
|
+
|
47
|
+
return context
|
File without changes
|