omlish 0.0.0.dev123__py3-none-any.whl → 0.0.0.dev124__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- omlish/__about__.py +3 -2
- omlish/__init__.py +3 -0
- omlish/antlr/_runtime/LICENSE.txt +28 -0
- omlish/c3.py +35 -37
- omlish/dataclasses/__init__.py +4 -0
- omlish/dataclasses/utils.py +22 -4
- omlish/http/consts.py +2 -0
- omlish/lite/check.py +5 -0
- omlish/lite/http/__init__.py +0 -0
- omlish/lite/http/coroserver.py +585 -0
- omlish/lite/http/handlers.py +36 -0
- omlish/lite/http/parsing.py +376 -0
- omlish/lite/http/versions.py +17 -0
- omlish/lite/inject.py +21 -5
- omlish/lite/journald.py +0 -1
- omlish/lite/runtime.py +1 -2
- omlish/lite/socket.py +77 -0
- omlish/lite/socketserver.py +66 -0
- omlish-0.0.0.dev124.dist-info/METADATA +100 -0
- {omlish-0.0.0.dev123.dist-info → omlish-0.0.0.dev124.dist-info}/RECORD +24 -16
- {omlish-0.0.0.dev123.dist-info → omlish-0.0.0.dev124.dist-info}/WHEEL +1 -1
- omlish-0.0.0.dev123.dist-info/METADATA +0 -94
- {omlish-0.0.0.dev123.dist-info → omlish-0.0.0.dev124.dist-info}/LICENSE +0 -0
- {omlish-0.0.0.dev123.dist-info → omlish-0.0.0.dev124.dist-info}/entry_points.txt +0 -0
- {omlish-0.0.0.dev123.dist-info → omlish-0.0.0.dev124.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,376 @@
|
|
1
|
+
# ruff: noqa: UP006 UP007
|
2
|
+
import abc
import email.parser
import http.client
import http.server
import io
import typing as ta
|
7
|
+
|
8
|
+
from .versions import HttpProtocolVersion
|
9
|
+
from .versions import HttpProtocolVersions
|
10
|
+
|
11
|
+
|
12
|
+
# Result type of the line-reading coroutines driven by HttpRequestParser.
T = ta.TypeVar('T')


# Parsed request headers are represented with the stdlib's header message class.
HttpHeaders = http.client.HTTPMessage  # ta.TypeAlias
|
16
|
+
|
17
|
+
|
18
|
+
##
|
19
|
+
|
20
|
+
|
21
|
+
class ParseHttpRequestResult(abc.ABC):  # noqa
    """Common base of every outcome produced by parsing an HTTP request head.

    Holds the fields that are known (or defaulted) no matter how far parsing got. ``__repr__`` reports every name in
    ``self.__slots__``, so a subclass declaring its own ``__slots__`` repeats the inherited names there to have them
    shown.
    """

    __slots__ = (
        'server_version',
        'request_line',
        'request_version',
        'version',
        'headers',
        'close_connection',
    )

    def __init__(
            self,
            *,
            server_version: HttpProtocolVersion,
            request_line: str,
            request_version: HttpProtocolVersion,
            version: HttpProtocolVersion,
            headers: ta.Optional[HttpHeaders],
            close_connection: bool,
    ) -> None:
        super().__init__()

        self.close_connection = close_connection
        self.headers = headers
        self.version = version
        self.request_version = request_version
        self.request_line = request_line
        self.server_version = server_version

    def __repr__(self) -> str:
        rendered_fields = ', '.join([f'{name}={getattr(self, name)!r}' for name in self.__slots__])
        return f'{type(self).__name__}({rendered_fields})'
|
52
|
+
|
53
|
+
|
54
|
+
class EmptyParsedHttpResult(ParseHttpRequestResult):
    """Outcome when there was no request to parse: EOF before a request line, or a whitespace-only request line."""
|
56
|
+
|
57
|
+
|
58
|
+
class ParseHttpRequestError(ParseHttpRequestResult):
    """Outcome of a request head that cannot be served, carrying the HTTP error status to respond with.

    ``message`` is either a single string or a two-string tuple (a short reason plus detail text).
    """

    __slots__ = ('code', 'message') + ParseHttpRequestResult.__slots__

    def __init__(
            self,
            *,
            code: http.HTTPStatus,
            message: ta.Union[str, ta.Tuple[str, str]],

            **kwargs: ta.Any,
    ) -> None:
        # Every common field is forwarded untouched to the base class.
        super().__init__(**kwargs)

        self.message = message
        self.code = code
|
77
|
+
|
78
|
+
|
79
|
+
class ParsedHttpRequest(ParseHttpRequestResult):
    """Outcome of a successfully parsed request head.

    Unlike on the base class, ``headers`` is always present here. ``expects_continue`` tells whether the client sent
    ``Expect: 100-continue`` on a connection where that is honored.
    """

    __slots__ = (
        'method',
        'path',
        'headers',
        'expects_continue',
    ) + tuple(name for name in ParseHttpRequestResult.__slots__ if name != 'headers')

    def __init__(
            self,
            *,
            method: str,
            path: str,
            headers: HttpHeaders,
            expects_continue: bool,

            **kwargs: ta.Any,
    ) -> None:
        super().__init__(**kwargs, headers=headers)

        self.expects_continue = expects_continue
        self.path = path
        self.method = method

    # Narrows the base class's Optional headers annotation for type checkers.
    headers: HttpHeaders
|
108
|
+
|
109
|
+
|
110
|
+
#
|
111
|
+
|
112
|
+
|
113
|
+
class HttpRequestParser:
    """Parser for HTTP/0.9 - HTTP/1.1 request heads (request line plus header block).

    The request line and header handling closely follows the stdlib's ``http.server.BaseHTTPRequestHandler
    .parse_request``. Parsing is written as generator coroutines (the ``coro_*`` methods): each yields the maximum
    number of bytes it wants for the next line and expects that line to be sent back. ``read_raw_headers`` and
    ``parse`` drive those coroutines with a blocking ``read_line(limit)`` callable, which is expected to behave like
    ``readline(limit)`` on a binary stream (returning ``b''`` at EOF).
    """

    DEFAULT_SERVER_VERSION = HttpProtocolVersions.HTTP_1_0

    # The default request version. This only affects responses up until the point where the request line is parsed, so
    # it mainly decides what the client gets back when sending a malformed request line.
    # Most web servers default to HTTP 0.9, i.e. don't send a status line.
    DEFAULT_REQUEST_VERSION = HttpProtocolVersions.HTTP_0_9

    #

    DEFAULT_MAX_LINE: int = 0x10000
    DEFAULT_MAX_HEADERS: int = 100

    #

    def __init__(
            self,
            *,
            server_version: HttpProtocolVersion = DEFAULT_SERVER_VERSION,

            max_line: int = DEFAULT_MAX_LINE,
            max_headers: int = DEFAULT_MAX_HEADERS,
    ) -> None:
        """
        :param server_version: Highest protocol version this server speaks; must be below HTTP/2.0.
        :param max_line: Maximum accepted length in bytes of the request line and of each header line.
        :param max_headers: Maximum number of header lines; the terminating blank line counts towards it.
        :raises ValueError: If ``server_version`` is HTTP/2.0 or later.
        """
        super().__init__()

        if server_version >= HttpProtocolVersions.HTTP_2_0:
            raise ValueError(f'Unsupported protocol version: {server_version}')
        self._server_version = server_version

        self._max_line = max_line
        self._max_headers = max_headers

    #

    @property
    def server_version(self) -> HttpProtocolVersion:
        return self._server_version

    #

    def _run_read_line_coro(
            self,
            gen: ta.Generator[int, bytes, T],
            read_line: ta.Callable[[int], bytes],
    ) -> T:
        """Drive a line-reading coroutine to completion with a blocking ``read_line`` and return its result."""

        # Only the coroutine's own steps are guarded: a StopIteration escaping read_line must not be mistaken for the
        # coroutine finishing, and a coroutine that returns without ever yielding is handled too.
        try:
            sz = next(gen)
        except StopIteration as e:
            return e.value

        while True:
            line = read_line(sz)
            try:
                sz = gen.send(line)
            except StopIteration as e:
                return e.value

    #

    def parse_request_version(self, version_str: str) -> HttpProtocolVersion:
        """Parse a request line's version token, e.g. ``'HTTP/1.1'``.

        :raises ValueError: If the token is not of the form ``HTTP/<digits>.<digits>`` with components of at most 10
            ASCII digits each.
        """
        if not version_str.startswith('HTTP/'):
            raise ValueError(version_str)  # noqa

        base_version_number = version_str.split('/', 1)[1]
        version_number_parts = base_version_number.split('.')

        # RFC 2145 section 3.1 says there can be only one "." and
        #   - major and minor numbers MUST be treated as separate integers;
        #   - HTTP/2.4 is a lower version than HTTP/2.13, which in turn is lower than HTTP/12.3;
        #   - Leading zeros MUST be ignored by recipients.
        if len(version_number_parts) != 2:
            raise ValueError(version_number_parts)  # noqa
        # str.isdigit() alone also accepts non-ASCII digits (e.g. '²', '١'), which are not valid in a protocol token.
        if any(not (component.isascii() and component.isdigit()) for component in version_number_parts):
            raise ValueError('non digit in http version')  # noqa
        if any(len(component) > 10 for component in version_number_parts):
            raise ValueError('unreasonable length http version')  # noqa

        return HttpProtocolVersion(
            int(version_number_parts[0]),
            int(version_number_parts[1]),
        )

    #

    def coro_read_raw_headers(self) -> ta.Generator[int, bytes, ta.List[bytes]]:
        """Coroutine collecting raw header lines up to and including the terminating blank line (or EOF).

        :raises http.client.LineTooLong: If a line exceeds ``max_line``.
        :raises http.client.HTTPException: If more than ``max_headers`` lines are read.
        """
        raw_headers: ta.List[bytes] = []
        while True:
            line = yield self._max_line + 1
            if len(line) > self._max_line:
                raise http.client.LineTooLong('header line')
            raw_headers.append(line)
            if len(raw_headers) > self._max_headers:
                raise http.client.HTTPException(f'got more than {self._max_headers} headers')
            if line in (b'\r\n', b'\n', b''):
                break
        return raw_headers

    def read_raw_headers(self, read_line: ta.Callable[[int], bytes]) -> ta.List[bytes]:
        """Blocking counterpart of ``coro_read_raw_headers``."""
        return self._run_read_line_coro(self.coro_read_raw_headers(), read_line)

    def parse_raw_headers(self, raw_headers: ta.Sequence[bytes]) -> HttpHeaders:
        """Parse raw header lines (as collected by ``coro_read_raw_headers``) into a header message.

        The lines are decoded as ISO-8859-1 and handed straight to the email parser. Going through
        ``http.client.parse_headers`` instead would re-read them under the stdlib's own fixed line length and header
        count limits, silently overriding a larger ``max_line`` / ``max_headers`` configured on this parser.
        """
        header_text = b''.join(raw_headers).decode('iso-8859-1')
        return email.parser.Parser(_class=HttpHeaders).parsestr(header_text)

    #

    def coro_parse(self) -> ta.Generator[int, bytes, ParseHttpRequestResult]:
        """Coroutine parsing one request head: the request line followed by the header block.

        Returns an ``EmptyParsedHttpResult`` on EOF or a whitespace-only request line, a ``ParseHttpRequestError``
        carrying the status to respond with on malformed or unsupported input, or a ``ParsedHttpRequest`` on success.
        Malformed input is reported through the returned result rather than by raising.
        """
        raw_request_line = yield self._max_line + 1

        # Common result kwargs

        request_line = '-'
        request_version = self.DEFAULT_REQUEST_VERSION

        # Set to min(server, request) when it gets that far, but if it fails before that the server authoritatively
        # responds with its own version.
        version = self._server_version

        headers: ta.Optional[HttpHeaders] = None

        close_connection = True

        def result_kwargs():
            # Reads the enclosing locals at call time, so each result reflects how far parsing got.
            return dict(
                server_version=self._server_version,
                request_line=request_line,
                request_version=request_version,
                version=version,
                headers=headers,
                close_connection=close_connection,
            )

        # Decode line

        if len(raw_request_line) > self._max_line:
            return ParseHttpRequestError(
                code=http.HTTPStatus.REQUEST_URI_TOO_LONG,
                message='Request line too long',
                **result_kwargs(),
            )

        if not raw_request_line:
            return EmptyParsedHttpResult(**result_kwargs())

        request_line = raw_request_line.decode('iso-8859-1').rstrip('\r\n')

        # Split words

        words = request_line.split()
        if len(words) == 0:
            return EmptyParsedHttpResult(**result_kwargs())

        # Parse and set version

        if len(words) >= 3:  # Enough to determine protocol version
            version_str = words[-1]
            try:
                request_version = self.parse_request_version(version_str)

            except (ValueError, IndexError):
                return ParseHttpRequestError(
                    code=http.HTTPStatus.BAD_REQUEST,
                    message=f'Bad request version ({version_str!r})',
                    **result_kwargs(),
                )

            if (
                    request_version < HttpProtocolVersions.HTTP_0_9 or
                    request_version >= HttpProtocolVersions.HTTP_2_0
            ):
                return ParseHttpRequestError(
                    code=http.HTTPStatus.HTTP_VERSION_NOT_SUPPORTED,
                    message=f'Invalid HTTP version ({version_str})',
                    **result_kwargs(),
                )

            version = min(self._server_version, request_version)

            if version >= HttpProtocolVersions.HTTP_1_1:
                close_connection = False

        # Verify word count

        if not 2 <= len(words) <= 3:
            return ParseHttpRequestError(
                code=http.HTTPStatus.BAD_REQUEST,
                message=f'Bad request syntax ({request_line!r})',
                **result_kwargs(),
            )

        # Parse method and path

        method, path = words[:2]
        if len(words) == 2:
            # A two word request line is an HTTP/0.9 style request, for which only GET is accepted.
            close_connection = True
            if method != 'GET':
                return ParseHttpRequestError(
                    code=http.HTTPStatus.BAD_REQUEST,
                    message=f'Bad HTTP/0.9 request type ({method!r})',
                    **result_kwargs(),
                )

        # gh-87389: The purpose of replacing '//' with '/' is to protect against open redirect attacks possibly
        # triggered if the path starts with '//' because http clients treat //path as an absolute URI without scheme
        # (similar to http://path) rather than a path.
        if path.startswith('//'):
            path = '/' + path.lstrip('/')  # Reduce to a single /

        # Parse headers

        try:
            # Delegates line reads to the header coroutine; its exceptions propagate here.
            raw_headers = yield from self.coro_read_raw_headers()
            headers = self.parse_raw_headers(raw_headers)

        except http.client.LineTooLong as err:
            return ParseHttpRequestError(
                code=http.HTTPStatus.REQUEST_HEADER_FIELDS_TOO_LARGE,
                message=('Line too long', str(err)),
                **result_kwargs(),
            )

        except http.client.HTTPException as err:
            return ParseHttpRequestError(
                code=http.HTTPStatus.REQUEST_HEADER_FIELDS_TOO_LARGE,
                message=('Too many headers', str(err)),
                **result_kwargs(),
            )

        # The version actually in effect for this exchange. For a three word request line this equals `version`. For a
        # two word one, `version` is still the server's own while `request_version` is still DEFAULT_REQUEST_VERSION
        # (HTTP/0.9), and such a client gets neither a persistent connection nor a 100 Continue, whatever its headers
        # say.
        effective_version = min(version, request_version)

        # Check for connection directive

        conn_type = headers.get('Connection', '')
        if conn_type.lower() == 'close':
            close_connection = True
        elif (
                conn_type.lower() == 'keep-alive' and
                effective_version >= HttpProtocolVersions.HTTP_1_1
        ):
            close_connection = False

        # Check for expect directive

        expect = headers.get('Expect', '')
        expects_continue = (
            expect.lower() == '100-continue' and
            effective_version >= HttpProtocolVersions.HTTP_1_1
        )

        # Return

        return ParsedHttpRequest(
            method=method,
            path=path,
            expects_continue=expects_continue,
            **result_kwargs(),
        )

    def parse(self, read_line: ta.Callable[[int], bytes]) -> ParseHttpRequestResult:
        """Parse one request head, pulling lines from the blocking ``read_line(limit)`` callable."""
        return self._run_read_line_coro(self.coro_parse(), read_line)
|
@@ -0,0 +1,17 @@
|
|
1
|
+
# ruff: noqa: UP006 UP007
|
2
|
+
import typing as ta
|
3
|
+
|
4
|
+
|
5
|
+
class HttpProtocolVersion(ta.NamedTuple):
    """An HTTP protocol version as a ``(major, minor)`` pair.

    Being a tuple, instances compare component by component (so HTTP/1.1 < HTTP/2.0 < HTTP/12.3), and ``str()``
    renders the wire form, e.g. ``'HTTP/1.1'``.
    """

    major: int
    minor: int

    def __str__(self) -> str:
        major, minor = self
        return f'HTTP/{major}.{minor}'
|
11
|
+
|
12
|
+
|
13
|
+
class HttpProtocolVersions:
    """Namespace of well-known ``HttpProtocolVersion`` constants."""

    HTTP_0_9 = HttpProtocolVersion(major=0, minor=9)
    HTTP_1_0 = HttpProtocolVersion(major=1, minor=0)
    HTTP_1_1 = HttpProtocolVersion(major=1, minor=1)
    HTTP_2_0 = HttpProtocolVersion(major=2, minor=0)
|
omlish/lite/inject.py
CHANGED
@@ -30,12 +30,22 @@ InjectorBindingOrBindings = ta.Union['InjectorBinding', 'InjectorBindings']
|
|
30
30
|
|
31
31
|
|
32
32
|
@dc.dataclass(frozen=True)
class InjectorKey(ta.Generic[T]):
    """Immutable, hashable identifier of an injector binding.

    Parameterized by ``T`` so type checkers can relate a key to the type it provides (e.g. in ``Injector.__getitem__``).
    Equality covers all three fields, so keys differing only in ``tag`` or ``array`` are distinct.
    """

    # The keyed type: a class, or anything else is_valid_injector_key_cls accepts.
    cls: InjectorKeyCls
    # Extra discriminator; NOTE(review): presumably distinguishes several bindings of the same cls - confirm.
    tag: ta.Any = None
    # NOTE(review): appears to mark multi-value (array) keys - confirm against the binder.
    array: bool = False
|
37
37
|
|
38
38
|
|
39
|
+
def is_valid_injector_key_cls(cls: ta.Any) -> bool:
    """Tell whether ``cls`` may serve as ``InjectorKey.cls``: any class, or whatever ``is_new_type`` accepts."""
    if isinstance(cls, type):
        return True
    return is_new_type(cls)
|
41
|
+
|
42
|
+
|
43
|
+
def check_valid_injector_key_cls(cls: T) -> T:
    """Return ``cls`` unchanged if it is a valid injector key class.

    :raises TypeError: With ``cls`` as its argument, if it is not.
    """
    if is_valid_injector_key_cls(cls):
        return cls
    raise TypeError(cls)
|
47
|
+
|
48
|
+
|
39
49
|
##
|
40
50
|
|
41
51
|
|
@@ -79,6 +89,12 @@ class Injector(abc.ABC):
|
|
79
89
|
def inject(self, obj: ta.Any) -> ta.Any:
|
80
90
|
raise NotImplementedError
|
81
91
|
|
92
|
+
def __getitem__(
|
93
|
+
self,
|
94
|
+
target: ta.Union[InjectorKey[T], ta.Type[T]],
|
95
|
+
) -> T:
|
96
|
+
return self.provide(target)
|
97
|
+
|
82
98
|
|
83
99
|
###
|
84
100
|
# exceptions
|
@@ -111,7 +127,7 @@ def as_injector_key(o: ta.Any) -> InjectorKey:
|
|
111
127
|
raise TypeError(o)
|
112
128
|
if isinstance(o, InjectorKey):
|
113
129
|
return o
|
114
|
-
if
|
130
|
+
if is_valid_injector_key_cls(o):
|
115
131
|
return InjectorKey(o)
|
116
132
|
raise TypeError(o)
|
117
133
|
|
@@ -443,8 +459,8 @@ class InjectorBinder:
|
|
443
459
|
to_fn = obj
|
444
460
|
if key is None:
|
445
461
|
sig = _injection_signature(obj)
|
446
|
-
|
447
|
-
key = InjectorKey(
|
462
|
+
key_cls = check_valid_injector_key_cls(sig.return_annotation)
|
463
|
+
key = InjectorKey(key_cls)
|
448
464
|
else:
|
449
465
|
if to_const is not None:
|
450
466
|
raise TypeError('Cannot bind instance with to_const')
|
@@ -498,7 +514,7 @@ class InjectorBinder:
|
|
498
514
|
# injector
|
499
515
|
|
500
516
|
|
501
|
-
_INJECTOR_INJECTOR_KEY = InjectorKey(Injector)
|
517
|
+
_INJECTOR_INJECTOR_KEY: InjectorKey[Injector] = InjectorKey(Injector)
|
502
518
|
|
503
519
|
|
504
520
|
class _Injector(Injector):
|
omlish/lite/journald.py
CHANGED
@@ -29,7 +29,6 @@ sd_iovec._fields_ = [
|
|
29
29
|
def sd_libsystemd() -> ta.Any:
|
30
30
|
lib = ct.CDLL('libsystemd.so.0')
|
31
31
|
|
32
|
-
lib.sd_journal_sendv = lib['sd_journal_sendv'] # type: ignore
|
33
32
|
lib.sd_journal_sendv.restype = ct.c_int
|
34
33
|
lib.sd_journal_sendv.argtypes = [ct.POINTER(sd_iovec), ct.c_int]
|
35
34
|
|
omlish/lite/runtime.py
CHANGED
@@ -14,5 +14,4 @@ REQUIRED_PYTHON_VERSION = (3, 8)
|
|
14
14
|
|
15
15
|
def check_runtime_version() -> None:
    """Fail fast when the running interpreter is older than ``REQUIRED_PYTHON_VERSION``.

    :raises OSError: If ``sys.version_info`` is below the required version.
    """
    if sys.version_info >= REQUIRED_PYTHON_VERSION:
        return
    raise OSError(f'Requires python {REQUIRED_PYTHON_VERSION}, got {sys.version_info} from {sys.executable}')  # noqa
|
omlish/lite/socket.py
ADDED
@@ -0,0 +1,77 @@
|
|
1
|
+
# ruff: noqa: UP006 UP007
|
2
|
+
"""
|
3
|
+
TODO:
|
4
|
+
- SocketClientAddress family / tuple pairs
|
5
|
+
+ codification of https://docs.python.org/3/library/socket.html#socket-families
|
6
|
+
"""
|
7
|
+
import abc
|
8
|
+
import dataclasses as dc
|
9
|
+
import socket
|
10
|
+
import typing as ta
|
11
|
+
|
12
|
+
|
13
|
+
# Socket addresses are whatever the socket module uses for the family in play (e.g. a (host, port) tuple for AF_INET),
# so they are left untyped.
SocketAddress = ta.Any


# Builds the SocketHandler for one connection from its client address and its binary read / write streams.
SocketHandlerFactory = ta.Callable[[SocketAddress, ta.BinaryIO, ta.BinaryIO], 'SocketHandler']
|
17
|
+
|
18
|
+
|
19
|
+
##
|
20
|
+
|
21
|
+
|
22
|
+
@dc.dataclass(frozen=True)
class SocketAddressInfoArgs:
    """Immutable bundle of arguments for a ``socket.getaddrinfo`` call.

    Field names and defaults match ``socket.getaddrinfo(host, port, family=0, type=0, proto=0, flags=0)``.
    """

    host: ta.Optional[str]
    port: ta.Union[str, int, None]

    family: socket.AddressFamily = socket.AF_UNSPEC  # Same member as socket.AddressFamily.AF_UNSPEC.
    type: int = 0  # Socket type such as socket.SOCK_STREAM; 0 accepts any.
    proto: int = 0
    flags: socket.AddressInfo = socket.AddressInfo(0)
|
30
|
+
|
31
|
+
|
32
|
+
@dc.dataclass(frozen=True)
class SocketAddressInfo:
    """One entry of a ``socket.getaddrinfo`` result.

    Fields follow the order of the 5-tuples ``getaddrinfo`` returns, so ``SocketAddressInfo(*entry)`` converts one.
    """

    family: socket.AddressFamily
    type: int
    proto: int
    canonname: ta.Optional[str]  # Empty unless AI_CANONNAME was requested.
    sockaddr: SocketAddress
|
39
|
+
|
40
|
+
|
41
|
+
def get_best_socket_family(
        host: ta.Optional[str],
        port: ta.Union[str, int, None],
        family: ta.Union[int, socket.AddressFamily] = socket.AddressFamily.AF_UNSPEC,
) -> ta.Tuple[socket.AddressFamily, SocketAddress]:
    """Choose the address family and socket address for binding a stream server to ``host`` / ``port``.

    Takes the first ``getaddrinfo`` candidate for a passive (to-be-bound) SOCK_STREAM socket. Mirrors the stdlib
    approach from https://github.com/python/cpython/commit/f289084c83190cc72db4a70c58f007ec62e75247
    """
    candidates = socket.getaddrinfo(host, port, family, type=socket.SOCK_STREAM, flags=socket.AI_PASSIVE)
    best_family, _type, _proto, _canonname, best_sockaddr = next(iter(candidates))
    return best_family, best_sockaddr
|
57
|
+
|
58
|
+
|
59
|
+
##
|
60
|
+
|
61
|
+
|
62
|
+
class SocketHandler(abc.ABC):
    """Base class for per-connection handlers produced by a ``SocketHandlerFactory``.

    Keeps the peer address and the connection's binary read / write streams; subclasses implement ``handle``.
    """

    def __init__(
            self,
            client_address: SocketAddress,
            rfile: ta.BinaryIO,
            wfile: ta.BinaryIO,
    ) -> None:
        super().__init__()

        self._wfile = wfile
        self._rfile = rfile
        self._client_address = client_address

    @abc.abstractmethod
    def handle(self) -> None:
        """Serve the connection."""
        raise NotImplementedError
|
@@ -0,0 +1,66 @@
|
|
1
|
+
# ruff: noqa: UP006 UP007
|
2
|
+
import socket
|
3
|
+
import socketserver
|
4
|
+
import typing as ta
|
5
|
+
|
6
|
+
from omlish.lite.check import check_not_none
|
7
|
+
|
8
|
+
from .socket import SocketAddress
|
9
|
+
from .socket import SocketHandlerFactory
|
10
|
+
|
11
|
+
|
12
|
+
##
|
13
|
+
|
14
|
+
|
15
|
+
class SocketServerBaseRequestHandler_:  # noqa
    """Typing-only declaration of the attributes ``socketserver.BaseRequestHandler`` sets in its constructor.

    Contains annotations only; the values come from the real stdlib handler it is mixed in alongside.
    """

    # The per-request socket handed over by the server.
    request: socket.socket
    # The peer's address; its shape depends on the address family.
    client_address: SocketAddress
    # The server that dispatched the request.
    server: socketserver.TCPServer
|
19
|
+
|
20
|
+
|
21
|
+
class SocketServerStreamRequestHandler_(SocketServerBaseRequestHandler_):  # noqa
    """Typing-only declaration of the attributes of ``socketserver.StreamRequestHandler``.

    The first four are class-level settings of the stdlib handler; ``connection``, ``rfile`` and ``wfile`` are set per
    connection by its ``setup()``.
    """

    # Buffer sizes used when wrapping the connection in rfile / wfile.
    rbufsize: int
    wbufsize: int

    # Timeout applied to the connection socket; None leaves it untouched.
    timeout: ta.Optional[float]

    # Whether TCP_NODELAY gets set on the connection.
    disable_nagle_algorithm: bool

    connection: socket.socket
    rfile: ta.BinaryIO
    wfile: ta.BinaryIO
|
32
|
+
|
33
|
+
|
34
|
+
##
|
35
|
+
|
36
|
+
|
37
|
+
class SocketHandlerSocketServerStreamRequestHandler(  # type: ignore[misc]
        socketserver.StreamRequestHandler,
        SocketServerStreamRequestHandler_,
):
    """``socketserver`` stream request handler that delegates each connection to a ``SocketHandler``.

    The factory comes from the ``socket_handler_factory`` keyword argument when given (e.g. bound with
    ``functools.partial`` for use as a server's ``RequestHandlerClass``), else from the class attribute.
    """

    socket_handler_factory: ta.Optional[SocketHandlerFactory] = None

    def __init__(
            self,
            request: socket.socket,
            client_address: SocketAddress,
            server: socketserver.TCPServer,
            *,
            socket_handler_factory: ta.Optional[SocketHandlerFactory] = None,
    ) -> None:
        # Must be stored before delegating: the stdlib base constructor itself runs setup(), handle() and finish(), and
        # handle() reads this attribute.
        if socket_handler_factory is not None:
            self.socket_handler_factory = socket_handler_factory

        super().__init__(request, client_address, server)

    def handle(self) -> None:
        factory = check_not_none(self.socket_handler_factory)
        socket_handler = factory(
            self.client_address,
            self.rfile,  # type: ignore[arg-type]
            self.wfile,  # type: ignore[arg-type]
        )
        socket_handler.handle()
|