omlish 0.0.0.dev123__py3-none-any.whl → 0.0.0.dev124__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- omlish/__about__.py +3 -2
- omlish/__init__.py +3 -0
- omlish/antlr/_runtime/LICENSE.txt +28 -0
- omlish/c3.py +35 -37
- omlish/dataclasses/__init__.py +4 -0
- omlish/dataclasses/utils.py +22 -4
- omlish/http/consts.py +2 -0
- omlish/lite/check.py +5 -0
- omlish/lite/http/__init__.py +0 -0
- omlish/lite/http/coroserver.py +585 -0
- omlish/lite/http/handlers.py +36 -0
- omlish/lite/http/parsing.py +376 -0
- omlish/lite/http/versions.py +17 -0
- omlish/lite/inject.py +21 -5
- omlish/lite/journald.py +0 -1
- omlish/lite/runtime.py +1 -2
- omlish/lite/socket.py +77 -0
- omlish/lite/socketserver.py +66 -0
- omlish-0.0.0.dev124.dist-info/METADATA +100 -0
- {omlish-0.0.0.dev123.dist-info → omlish-0.0.0.dev124.dist-info}/RECORD +24 -16
- {omlish-0.0.0.dev123.dist-info → omlish-0.0.0.dev124.dist-info}/WHEEL +1 -1
- omlish-0.0.0.dev123.dist-info/METADATA +0 -94
- {omlish-0.0.0.dev123.dist-info → omlish-0.0.0.dev124.dist-info}/LICENSE +0 -0
- {omlish-0.0.0.dev123.dist-info → omlish-0.0.0.dev124.dist-info}/entry_points.txt +0 -0
- {omlish-0.0.0.dev123.dist-info → omlish-0.0.0.dev124.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,376 @@
|
|
1
|
+
# ruff: noqa: UP006 UP007
|
2
|
+
import abc
|
3
|
+
import http.client
|
4
|
+
import http.server
|
5
|
+
import io
|
6
|
+
import typing as ta
|
7
|
+
|
8
|
+
from .versions import HttpProtocolVersion
|
9
|
+
from .versions import HttpProtocolVersions
|
10
|
+
|
11
|
+
|
12
|
+
T = ta.TypeVar('T')
|
13
|
+
|
14
|
+
|
15
|
+
HttpHeaders = http.client.HTTPMessage # ta.TypeAlias
|
16
|
+
|
17
|
+
|
18
|
+
##
|
19
|
+
|
20
|
+
|
21
|
+
class ParseHttpRequestResult(abc.ABC):  # noqa
    """
    Common base of every outcome produced while parsing an HTTP request.

    Stores the protocol versions, the request line, the headers (which may be None) and the close-connection flag
    that were passed in. `__repr__` renders every attribute named in the instance's `__slots__`.
    """

    __slots__ = (
        'server_version',
        'request_line',
        'request_version',
        'version',
        'headers',
        'close_connection',
    )

    def __init__(
            self,
            *,
            server_version: HttpProtocolVersion,
            request_line: str,
            request_version: HttpProtocolVersion,
            version: HttpProtocolVersion,
            headers: ta.Optional[HttpHeaders],
            close_connection: bool,
    ) -> None:
        super().__init__()

        # Versions.
        self.server_version = server_version
        self.request_version = request_version
        self.version = version

        # Request contents and connection disposition.
        self.request_line = request_line
        self.headers = headers
        self.close_connection = close_connection

    def __repr__(self) -> str:
        # `self.__slots__` resolves to the most derived class's tuple, so subclass attributes are included.
        rendered = ', '.join(f'{name}={getattr(self, name)!r}' for name in self.__slots__)
        return f'{self.__class__.__name__}({rendered})'
|
52
|
+
|
53
|
+
|
54
|
+
class EmptyParsedHttpResult(ParseHttpRequestResult):
    """Result carrying only the common fields; the parser returns it when the request line is empty or blank."""
|
56
|
+
|
57
|
+
|
58
|
+
class ParseHttpRequestError(ParseHttpRequestResult):
    """
    Result describing a request the parser rejected.

    Adds the HTTP status `code` and a `message`, which is either a single string or a pair of strings, to the common
    result fields. All remaining keyword arguments are forwarded to the base initializer.
    """

    # The base names are repeated so that `__repr__`, which walks `__slots__`, shows them as well.
    __slots__ = ('code', 'message') + tuple(ParseHttpRequestResult.__slots__)

    def __init__(
            self,
            *,
            code: http.HTTPStatus,
            message: ta.Union[str, ta.Tuple[str, str]],

            **kwargs: ta.Any,
    ) -> None:
        super().__init__(**kwargs)

        self.code = code
        self.message = message
|
77
|
+
|
78
|
+
|
79
|
+
class ParsedHttpRequest(ParseHttpRequestResult):
    """
    Result for a request whose request line and headers were parsed successfully.

    Adds `method`, `path` and `expects_continue` to the common result fields, and requires `headers` to be present.
    """

    __slots__ = (
        'method',
        'path',
        'headers',
        'expects_continue',
        # 'headers' is already listed above, so it is dropped from the inherited names to avoid a duplicate.
        *(name for name in ParseHttpRequestResult.__slots__ if name != 'headers'),
    )

    # Narrows the base class's optional annotation: a parsed request always has headers.
    headers: HttpHeaders

    def __init__(
            self,
            *,
            method: str,
            path: str,
            headers: HttpHeaders,
            expects_continue: bool,

            **kwargs: ta.Any,
    ) -> None:
        super().__init__(headers=headers, **kwargs)

        self.method = method
        self.path = path
        self.expects_continue = expects_continue
|
108
|
+
|
109
|
+
|
110
|
+
#
|
111
|
+
|
112
|
+
|
113
|
+
class HttpRequestParser:
    """
    Parses the request line and header block of an HTTP request.

    The parsing logic lives in generators (`coro_parse`, `coro_read_raw_headers`) that yield the maximum number of
    bytes wanted for the next line and are sent that line back, leaving the actual reading to the caller. `parse` and
    `read_raw_headers` drive those generators with a plain `read_line(size)` callable.

    Problems with the request itself are reported as `ParseHttpRequestError` results rather than raised.
    """

    DEFAULT_SERVER_VERSION = HttpProtocolVersions.HTTP_1_0

    # The default request version. This only affects responses up until the point where the request line is parsed, so
    # it mainly decides what the client gets back when sending a malformed request line.
    # Most web servers default to HTTP 0.9, i.e. don't send a status line.
    DEFAULT_REQUEST_VERSION = HttpProtocolVersions.HTTP_0_9

    #

    DEFAULT_MAX_LINE: int = 0x10000
    DEFAULT_MAX_HEADERS: int = 100

    #

    def __init__(
            self,
            *,
            server_version: HttpProtocolVersion = DEFAULT_SERVER_VERSION,

            max_line: int = DEFAULT_MAX_LINE,
            max_headers: int = DEFAULT_MAX_HEADERS,
    ) -> None:
        """
        Args:
            server_version: Version the server speaks; must be below HTTP/2.0.
            max_line: Longest accepted request or header line, in bytes.
            max_headers: Largest accepted number of header lines.

        Raises:
            ValueError: If `server_version` is HTTP/2.0 or above.
        """

        super().__init__()

        if server_version >= HttpProtocolVersions.HTTP_2_0:
            raise ValueError(f'Unsupported protocol version: {server_version}')
        self._server_version = server_version

        self._max_line = max_line
        self._max_headers = max_headers

    #

    @property
    def server_version(self) -> HttpProtocolVersion:
        return self._server_version

    #

    def _run_read_line_coro(
            self,
            gen: ta.Generator[int, bytes, T],
            read_line: ta.Callable[[int], bytes],
    ) -> T:
        """Drive `gen` to completion, answering each yielded size with `read_line(size)`, and return its result."""

        wanted = next(gen)
        while True:
            try:
                line = read_line(wanted)
                wanted = gen.send(line)
            except StopIteration as stop:
                return stop.value

    #

    def parse_request_version(self, version_str: str) -> HttpProtocolVersion:
        """
        Parse a version token such as 'HTTP/1.1'.

        Raises:
            ValueError: If the token lacks the 'HTTP/' prefix, does not have exactly two dot-separated parts, or a
                part is not all digits or is longer than 10 characters.
        """

        prefix = 'HTTP/'
        if not version_str.startswith(prefix):
            raise ValueError(version_str)  # noqa

        version_number_parts = version_str[len(prefix):].split('.')

        # RFC 2145 section 3.1 says there can be only one "." and
        #   - major and minor numbers MUST be treated as separate integers;
        #   - HTTP/2.4 is a lower version than HTTP/2.13, which in turn is lower than HTTP/12.3;
        #   - Leading zeros MUST be ignored by recipients.
        if len(version_number_parts) != 2:
            raise ValueError(version_number_parts)  # noqa

        major, minor = version_number_parts
        if not (major.isdigit() and minor.isdigit()):
            raise ValueError('non digit in http version')  # noqa
        if len(major) > 10 or len(minor) > 10:
            raise ValueError('unreasonable length http version')  # noqa

        return HttpProtocolVersion(int(major), int(minor))

    #

    def coro_read_raw_headers(self) -> ta.Generator[int, bytes, ta.List[bytes]]:
        """
        Collect raw header lines up to and including the terminating blank line (or end of input).

        Yields the maximum line size plus one, so an over-long line can be detected, and expects the line sent back.

        Raises:
            http.client.LineTooLong: If a line exceeds the configured maximum length.
            http.client.HTTPException: If more lines than the configured maximum are received.
        """

        collected: ta.List[bytes] = []
        while True:
            line = yield self._max_line + 1
            if len(line) > self._max_line:
                raise http.client.LineTooLong('header line')

            collected.append(line)
            if len(collected) > self._max_headers:
                raise http.client.HTTPException(f'got more than {self._max_headers} headers')

            if line in (b'\r\n', b'\n', b''):
                return collected

    def read_raw_headers(self, read_line: ta.Callable[[int], bytes]) -> ta.List[bytes]:
        return self._run_read_line_coro(self.coro_read_raw_headers(), read_line)

    def parse_raw_headers(self, raw_headers: ta.Sequence[bytes]) -> HttpHeaders:
        joined = b''.join(raw_headers)
        return http.client.parse_headers(io.BytesIO(joined))

    #

    def coro_parse(self) -> ta.Generator[int, bytes, ParseHttpRequestResult]:
        """
        Parse one request: the request line followed by the headers.

        Yields the maximum number of bytes wanted for the next line and expects that line to be sent back. Returns an
        `EmptyParsedHttpResult` for an empty or blank request line, a `ParseHttpRequestError` for a rejected request,
        and a `ParsedHttpRequest` otherwise.
        """

        raw_request_line = yield self._max_line + 1

        # Common result state. The closures below read these names when called, so later reassignments are reflected
        # in whichever result ends up being built.

        request_line = '-'
        request_version = self.DEFAULT_REQUEST_VERSION

        # Set to min(server, request) when it gets that far, but if it fails before that the server authoritatively
        # responds with its own version.
        version = self._server_version

        headers: ta.Optional[HttpHeaders] = None

        close_connection = True

        def common():
            return {
                'server_version': self._server_version,
                'request_line': request_line,
                'request_version': request_version,
                'version': version,
                'headers': headers,
                'close_connection': close_connection,
            }

        def error(code, message):
            return ParseHttpRequestError(code=code, message=message, **common())

        # Decode line

        if len(raw_request_line) > self._max_line:
            return error(http.HTTPStatus.REQUEST_URI_TOO_LONG, 'Request line too long')

        if not raw_request_line:
            return EmptyParsedHttpResult(**common())

        request_line = raw_request_line.decode('iso-8859-1').rstrip('\r\n')

        # Split words

        words = request_line.split()
        if not words:
            return EmptyParsedHttpResult(**common())

        # Parse and set version

        if len(words) >= 3:  # Enough to determine protocol version
            version_str = words[-1]
            try:
                request_version = self.parse_request_version(version_str)
            except (ValueError, IndexError):
                return error(http.HTTPStatus.BAD_REQUEST, f'Bad request version ({version_str!r})')

            if not (HttpProtocolVersions.HTTP_0_9 <= request_version < HttpProtocolVersions.HTTP_2_0):
                return error(http.HTTPStatus.HTTP_VERSION_NOT_SUPPORTED, f'Invalid HTTP version ({version_str})')

            version = min(self._server_version, request_version)

            if version >= HttpProtocolVersions.HTTP_1_1:
                close_connection = False

        # Verify word count

        if len(words) not in (2, 3):
            return error(http.HTTPStatus.BAD_REQUEST, f'Bad request syntax ({request_line!r})')

        # Parse method and path

        method, path = words[0], words[1]
        if len(words) == 2:
            # No version token: only GET is accepted, and the connection is always closed.
            close_connection = True
            if method != 'GET':
                return error(http.HTTPStatus.BAD_REQUEST, f'Bad HTTP/0.9 request type ({method!r})')

        # gh-87389: The purpose of replacing '//' with '/' is to protect against open redirect attacks possibly
        # triggered if the path starts with '//' because http clients treat //path as an absolute URI without scheme
        # (similar to http://path) rather than a path.
        if path.startswith('//'):
            path = '/' + path.lstrip('/')  # Reduce to a single /

        # Parse headers

        try:
            # Delegate to the header reader: its size requests are yielded to our caller and the lines sent to us are
            # passed on to it.
            raw_headers = yield from self.coro_read_raw_headers()

            headers = self.parse_raw_headers(raw_headers)

        except http.client.LineTooLong as err:
            return error(http.HTTPStatus.REQUEST_HEADER_FIELDS_TOO_LARGE, ('Line too long', str(err)))

        except http.client.HTTPException as err:
            return error(http.HTTPStatus.REQUEST_HEADER_FIELDS_TOO_LARGE, ('Too many headers', str(err)))

        # Check for connection directive

        connection = headers.get('Connection', '').lower()
        if connection == 'close':
            close_connection = True
        elif connection == 'keep-alive' and version >= HttpProtocolVersions.HTTP_1_1:
            close_connection = False

        # Check for expect directive

        expects_continue = (
            headers.get('Expect', '').lower() == '100-continue' and
            version >= HttpProtocolVersions.HTTP_1_1
        )

        # Return

        return ParsedHttpRequest(
            method=method,
            path=path,
            expects_continue=expects_continue,
            **common(),
        )

    def parse(self, read_line: ta.Callable[[int], bytes]) -> ParseHttpRequestResult:
        return self._run_read_line_coro(self.coro_parse(), read_line)
|
@@ -0,0 +1,17 @@
|
|
1
|
+
# ruff: noqa: UP006 UP007
|
2
|
+
import typing as ta
|
3
|
+
|
4
|
+
|
5
|
+
class HttpProtocolVersion(ta.NamedTuple):
    """An HTTP version as a (major, minor) tuple; comparisons therefore order by major first, then minor."""

    major: int
    minor: int

    def __str__(self) -> str:
        major, minor = self
        return f'HTTP/{major}.{minor}'
|
11
|
+
|
12
|
+
|
13
|
+
class HttpProtocolVersions:
    """Namespace of named `HttpProtocolVersion` constants."""

    HTTP_0_9 = HttpProtocolVersion(major=0, minor=9)
    HTTP_1_0 = HttpProtocolVersion(major=1, minor=0)
    HTTP_1_1 = HttpProtocolVersion(major=1, minor=1)
    HTTP_2_0 = HttpProtocolVersion(major=2, minor=0)
|
omlish/lite/inject.py
CHANGED
@@ -30,12 +30,22 @@ InjectorBindingOrBindings = ta.Union['InjectorBinding', 'InjectorBindings']
|
|
30
30
|
|
31
31
|
|
32
32
|
@dc.dataclass(frozen=True)
class InjectorKey(ta.Generic[T]):
    """
    Frozen key made of a key class, an arbitrary tag and an array flag.

    NOTE(review): `T` is presumably the type provided for this key - `Injector.__getitem__` is annotated as returning
    `T` for an `InjectorKey[T]`. Confirm against the rest of the module.
    """

    cls: InjectorKeyCls
    tag: ta.Any = None  # no tag by default
    array: bool = False
|
37
37
|
|
38
38
|
|
39
|
+
def is_valid_injector_key_cls(cls: ta.Any) -> bool:
    """Report whether `cls` can be used as an injector key class: a real class, or anything `is_new_type` accepts."""

    if isinstance(cls, type):
        return True
    return is_new_type(cls)
|
41
|
+
|
42
|
+
|
43
|
+
def check_valid_injector_key_cls(cls: T) -> T:
    """Return `cls` unchanged if `is_valid_injector_key_cls` accepts it; otherwise raise `TypeError(cls)`."""

    if is_valid_injector_key_cls(cls):
        return cls
    raise TypeError(cls)
|
47
|
+
|
48
|
+
|
39
49
|
##
|
40
50
|
|
41
51
|
|
@@ -79,6 +89,12 @@ class Injector(abc.ABC):
|
|
79
89
|
def inject(self, obj: ta.Any) -> ta.Any:
|
80
90
|
raise NotImplementedError
|
81
91
|
|
92
|
+
def __getitem__(self, target: ta.Union[InjectorKey[T], ta.Type[T]]) -> T:
    # Subscript form of `provide`: `injector[target]` returns `self.provide(target)`.
    return self.provide(target)
|
97
|
+
|
82
98
|
|
83
99
|
###
|
84
100
|
# exceptions
|
@@ -111,7 +127,7 @@ def as_injector_key(o: ta.Any) -> InjectorKey:
|
|
111
127
|
raise TypeError(o)
|
112
128
|
if isinstance(o, InjectorKey):
|
113
129
|
return o
|
114
|
-
if
|
130
|
+
if is_valid_injector_key_cls(o):
|
115
131
|
return InjectorKey(o)
|
116
132
|
raise TypeError(o)
|
117
133
|
|
@@ -443,8 +459,8 @@ class InjectorBinder:
|
|
443
459
|
to_fn = obj
|
444
460
|
if key is None:
|
445
461
|
sig = _injection_signature(obj)
|
446
|
-
|
447
|
-
key = InjectorKey(
|
462
|
+
key_cls = check_valid_injector_key_cls(sig.return_annotation)
|
463
|
+
key = InjectorKey(key_cls)
|
448
464
|
else:
|
449
465
|
if to_const is not None:
|
450
466
|
raise TypeError('Cannot bind instance with to_const')
|
@@ -498,7 +514,7 @@ class InjectorBinder:
|
|
498
514
|
# injector
|
499
515
|
|
500
516
|
|
501
|
-
_INJECTOR_INJECTOR_KEY = InjectorKey(Injector)
|
517
|
+
_INJECTOR_INJECTOR_KEY: InjectorKey[Injector] = InjectorKey(Injector)
|
502
518
|
|
503
519
|
|
504
520
|
class _Injector(Injector):
|
omlish/lite/journald.py
CHANGED
@@ -29,7 +29,6 @@ sd_iovec._fields_ = [
|
|
29
29
|
def sd_libsystemd() -> ta.Any:
|
30
30
|
lib = ct.CDLL('libsystemd.so.0')
|
31
31
|
|
32
|
-
lib.sd_journal_sendv = lib['sd_journal_sendv'] # type: ignore
|
33
32
|
lib.sd_journal_sendv.restype = ct.c_int
|
34
33
|
lib.sd_journal_sendv.argtypes = [ct.POINTER(sd_iovec), ct.c_int]
|
35
34
|
|
omlish/lite/runtime.py
CHANGED
@@ -14,5 +14,4 @@ REQUIRED_PYTHON_VERSION = (3, 8)
|
|
14
14
|
|
15
15
|
def check_runtime_version() -> None:
    """Raise `OSError` if the running interpreter is older than `REQUIRED_PYTHON_VERSION`."""

    if sys.version_info >= REQUIRED_PYTHON_VERSION:
        return

    raise OSError(f'Requires python {REQUIRED_PYTHON_VERSION}, got {sys.version_info} from {sys.executable}')  # noqa
|
omlish/lite/socket.py
ADDED
@@ -0,0 +1,77 @@
|
|
1
|
+
# ruff: noqa: UP006 UP007
|
2
|
+
"""
|
3
|
+
TODO:
|
4
|
+
- SocketClientAddress family / tuple pairs
|
5
|
+
+ codification of https://docs.python.org/3/library/socket.html#socket-families
|
6
|
+
"""
|
7
|
+
import abc
|
8
|
+
import dataclasses as dc
|
9
|
+
import socket
|
10
|
+
import typing as ta
|
11
|
+
|
12
|
+
|
13
|
+
SocketAddress = ta.Any
|
14
|
+
|
15
|
+
|
16
|
+
SocketHandlerFactory = ta.Callable[[SocketAddress, ta.BinaryIO, ta.BinaryIO], 'SocketHandler']
|
17
|
+
|
18
|
+
|
19
|
+
##
|
20
|
+
|
21
|
+
|
22
|
+
@dc.dataclass(frozen=True)
class SocketAddressInfoArgs:
    """
    Frozen bundle of address-lookup arguments.

    The field names, order and defaults match the parameters of `socket.getaddrinfo`; presumably instances are meant
    to be passed to it - confirm against callers.
    """

    # Required.
    host: ta.Optional[str]
    port: ta.Union[str, int, None]

    # Optional, with defaults that place no restriction on the lookup.
    family: socket.AddressFamily = socket.AddressFamily.AF_UNSPEC
    type: int = 0
    proto: int = 0
    flags: socket.AddressInfo = socket.AddressInfo(0)
|
30
|
+
|
31
|
+
|
32
|
+
@dc.dataclass(frozen=True)
class SocketAddressInfo:
    """
    Frozen record of one address-lookup result.

    The fields are declared in the order of the 5-tuples returned by `socket.getaddrinfo`, so such a tuple can be
    unpacked straight into the constructor (as `get_best_socket_family` does).
    """

    family: socket.AddressFamily
    type: int
    proto: int
    canonname: ta.Optional[str]
    sockaddr: SocketAddress  # shape depends on the address family
|
39
|
+
|
40
|
+
|
41
|
+
def get_best_socket_family(
        host: ta.Optional[str],
        port: ta.Union[str, int, None],
        family: ta.Union[int, socket.AddressFamily] = socket.AddressFamily.AF_UNSPEC,
) -> ta.Tuple[socket.AddressFamily, SocketAddress]:
    """
    Resolve `host` / `port` for a passive stream socket and return the first candidate's family and address.

    https://github.com/python/cpython/commit/f289084c83190cc72db4a70c58f007ec62e75247
    """

    candidates = socket.getaddrinfo(host, port, family, type=socket.SOCK_STREAM, flags=socket.AI_PASSIVE)

    # Only the first entry is considered.
    first = SocketAddressInfo(*next(iter(candidates)))
    return first.family, first.sockaddr
|
57
|
+
|
58
|
+
|
59
|
+
##
|
60
|
+
|
61
|
+
|
62
|
+
class SocketHandler(abc.ABC):
    """
    Abstract handler for a single client connection.

    Holds the client address and the binary read / write streams it was constructed with; subclasses implement
    `handle`.
    """

    def __init__(self, client_address: SocketAddress, rfile: ta.BinaryIO, wfile: ta.BinaryIO) -> None:
        super().__init__()

        self._client_address = client_address

        # Input and output streams.
        self._rfile = rfile
        self._wfile = wfile

    @abc.abstractmethod
    def handle(self) -> None:
        raise NotImplementedError
|
@@ -0,0 +1,66 @@
|
|
1
|
+
# ruff: noqa: UP006 UP007
|
2
|
+
import socket
|
3
|
+
import socketserver
|
4
|
+
import typing as ta
|
5
|
+
|
6
|
+
from omlish.lite.check import check_not_none
|
7
|
+
|
8
|
+
from .socket import SocketAddress
|
9
|
+
from .socket import SocketHandlerFactory
|
10
|
+
|
11
|
+
|
12
|
+
##
|
13
|
+
|
14
|
+
|
15
|
+
class SocketServerBaseRequestHandler_:  # noqa
    """Annotation-only declarations of the attributes `socketserver.BaseRequestHandler` sets on its instances."""

    request: socket.socket
    client_address: SocketAddress
    server: socketserver.TCPServer
|
19
|
+
|
20
|
+
|
21
|
+
class SocketServerStreamRequestHandler_(SocketServerBaseRequestHandler_):  # noqa
    """Annotation-only declarations of the attributes `socketserver.StreamRequestHandler` adds."""

    # Configuration attributes.

    rbufsize: int
    wbufsize: int

    timeout: ta.Optional[float]

    disable_nagle_algorithm: bool

    # Per-connection attributes.

    connection: socket.socket
    rfile: ta.BinaryIO
    wfile: ta.BinaryIO
|
32
|
+
|
33
|
+
|
34
|
+
##
|
35
|
+
|
36
|
+
|
37
|
+
class SocketHandlerSocketServerStreamRequestHandler(  # type: ignore[misc]
    socketserver.StreamRequestHandler,
    SocketServerStreamRequestHandler_,
):
    """
    `socketserver` stream request handler that delegates each connection to a handler built by a factory.

    The factory comes from the `socket_handler_factory` keyword argument, or from the class attribute of the same
    name when the argument is omitted.
    """

    socket_handler_factory: ta.Optional[SocketHandlerFactory] = None

    def __init__(
            self,
            request: socket.socket,
            client_address: SocketAddress,
            server: socketserver.TCPServer,
            *,
            socket_handler_factory: ta.Optional[SocketHandlerFactory] = None,
    ) -> None:
        # The factory has to be in place before the base initializer runs, because
        # `socketserver.BaseRequestHandler.__init__` itself invokes `handle()`.
        if socket_handler_factory is not None:
            self.socket_handler_factory = socket_handler_factory

        super().__init__(request, client_address, server)

    def handle(self) -> None:
        factory = check_not_none(self.socket_handler_factory)
        handler = factory(
            self.client_address,
            self.rfile,  # type: ignore[arg-type]
            self.wfile,  # type: ignore[arg-type]
        )
        handler.handle()
|