omlish 0.0.0.dev122__py3-none-any.whl → 0.0.0.dev124__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,376 @@
1
+ # ruff: noqa: UP006 UP007
2
+ import abc
3
+ import http.client
4
+ import http.server
5
+ import io
6
+ import typing as ta
7
+
8
+ from .versions import HttpProtocolVersion
9
+ from .versions import HttpProtocolVersions
10
+
11
+
T = ta.TypeVar('T')


HttpHeaders = http.client.HTTPMessage  # ta.TypeAlias


##


class ParseHttpRequestResult(abc.ABC):  # noqa
    """
    Common base of every result produced by HttpRequestParser.

    Holds the fields that are known (or defaulted) whether or not parsing succeeded. Subclasses must list only the
    slots they add: redeclaring a slot a base class already defines shadows the base's descriptor, doubles its
    per-instance storage, and is documented by Python as making program behavior undefined.
    """

    __slots__ = (
        'server_version',
        'request_line',
        'request_version',
        'version',
        'headers',
        'close_connection',
    )

    def __init__(
            self,
            *,
            server_version: HttpProtocolVersion,
            request_line: str,
            request_version: HttpProtocolVersion,
            version: HttpProtocolVersion,
            headers: ta.Optional[HttpHeaders],
            close_connection: bool,
    ) -> None:
        """
        :param server_version: The parser's configured server protocol version.
        :param request_line: The decoded request line with its line terminator stripped (the parser uses ``'-'``
            before one has been read).
        :param request_version: The version the client requested, or the parser's default if not yet determined.
        :param version: The protocol version to respond with.
        :param headers: The parsed headers, or None if parsing did not get that far.
        :param close_connection: Whether the connection should be closed after this request.
        """

        super().__init__()

        self.server_version = server_version
        self.request_line = request_line
        self.request_version = request_version
        self.version = version
        self.headers = headers
        self.close_connection = close_connection

    def __repr__(self) -> str:
        # Gather slot names from the most-derived class down to the root, so subclass-specific fields come first and
        # each name appears exactly once (subclasses no longer repeat the base slots).
        attrs: ta.List[str] = []
        for cls in type(self).__mro__:
            slots = cls.__dict__.get('__slots__', ())
            if isinstance(slots, str):
                slots = (slots,)
            for a in slots:
                if a not in attrs:
                    attrs.append(a)
        return f'{self.__class__.__name__}({", ".join(f"{a}={getattr(self, a)!r}" for a in attrs)})'


class EmptyParsedHttpResult(ParseHttpRequestResult):
    """Result for an empty read or a whitespace-only request line; carries only the common fields."""

    # Without an explicit empty __slots__, instances would silently regain a per-instance __dict__.
    __slots__ = ()


class ParseHttpRequestError(ParseHttpRequestResult):
    """
    Result describing a request that could not be parsed.

    ``code`` is the HTTP status to respond with; ``message`` is either a single message string or a
    ``(message, detail)`` pair.
    """

    # Only the slots this class adds - the common ones are inherited from ParseHttpRequestResult.
    __slots__ = (
        'code',
        'message',
    )

    def __init__(
            self,
            *,
            code: http.HTTPStatus,
            message: ta.Union[str, ta.Tuple[str, str]],

            **kwargs: ta.Any,
    ) -> None:
        """
        :param code: HTTP status describing the failure.
        :param message: Message string, or a ``(message, detail)`` pair.
        :param kwargs: The common ParseHttpRequestResult fields.
        """

        super().__init__(**kwargs)

        self.code = code
        self.message = message


class ParsedHttpRequest(ParseHttpRequestResult):
    """Result for a successfully parsed request line and header block."""

    # Only the slots this class adds - 'headers' and the other common fields live in the base class's slots.
    __slots__ = (
        'method',
        'path',
        'expects_continue',
    )

    def __init__(
            self,
            *,
            method: str,
            path: str,
            headers: HttpHeaders,
            expects_continue: bool,

            **kwargs: ta.Any,
    ) -> None:
        """
        :param method: The request method token, e.g. ``'GET'``.
        :param path: The request target, with a leading run of slashes reduced to one.
        :param headers: The parsed request headers (never None for a parsed request).
        :param expects_continue: Whether the client sent ``Expect: 100-continue`` on HTTP/1.1 or later.
        :param kwargs: The remaining common ParseHttpRequestResult fields.
        """

        super().__init__(
            headers=headers,
            **kwargs,
        )

        self.method = method
        self.path = path
        self.expects_continue = expects_continue

    # Narrows the base's Optional annotation: a parsed request always has headers. Annotation only - the storage is
    # the base class's 'headers' slot.
    headers: HttpHeaders
108
+
109
+
110
+ #
111
+
112
+
class HttpRequestParser:
    """
    Parser for HTTP/0.9 - HTTP/1.x request heads: the request line followed by the header block.

    Parsing is written as a generator coroutine (``coro_parse``). It yields the maximum number of bytes it wants for
    the next line and expects that line to be sent back, so the same logic can be driven by any line source.
    ``parse`` drives it with a plain ``read_line(max_bytes) -> bytes`` callable. Protocol errors are reported by
    returning a ``ParseHttpRequestError`` rather than by raising.
    """

    # Highest protocol version the server speaks unless one is passed to __init__.
    DEFAULT_SERVER_VERSION = HttpProtocolVersions.HTTP_1_0

    # The default request version. This only affects responses up until the point where the request line is parsed, so
    # it mainly decides what the client gets back when sending a malformed request line.
    # Most web servers default to HTTP 0.9, i.e. don't send a status line.
    DEFAULT_REQUEST_VERSION = HttpProtocolVersions.HTTP_0_9

    #

    # Limits on the length of any single line (request line or header line) and on the number of header lines.
    DEFAULT_MAX_LINE: int = 0x10000
    DEFAULT_MAX_HEADERS: int = 100

    #

    def __init__(
            self,
            *,
            server_version: HttpProtocolVersion = DEFAULT_SERVER_VERSION,

            max_line: int = DEFAULT_MAX_LINE,
            max_headers: int = DEFAULT_MAX_HEADERS,
    ) -> None:
        """
        :param server_version: Highest protocol version this server speaks; must be below HTTP/2.0.
        :param max_line: Maximum accepted length in bytes of the request line and of each header line.
        :param max_headers: Maximum number of header lines accepted, counting the terminating blank line.
        :raises ValueError: If ``server_version`` is HTTP/2.0 or later.
        """

        super().__init__()

        if server_version >= HttpProtocolVersions.HTTP_2_0:
            raise ValueError(f'Unsupported protocol version: {server_version}')
        self._server_version = server_version

        self._max_line = max_line
        self._max_headers = max_headers

    #

    @property
    def server_version(self) -> HttpProtocolVersion:
        """The highest protocol version this parser was configured to speak."""

        return self._server_version

    #

    def _run_read_line_coro(
            self,
            gen: ta.Generator[int, bytes, T],
            read_line: ta.Callable[[int], bytes],
    ) -> T:
        """
        Drive a line-reading coroutine to completion with a blocking ``read_line``.

        Each value the coroutine yields is passed to ``read_line`` as the maximum number of bytes to read, and the
        returned line is sent back in. The coroutine's return value is returned.
        """

        # The initial next() is inside the try as well, so a coroutine that returns without ever yielding has its value
        # returned instead of leaking a bare StopIteration to the caller.
        try:
            sz = next(gen)
            while True:
                sz = gen.send(read_line(sz))
        except StopIteration as e:
            return e.value

    #

    def parse_request_version(self, version_str: str) -> HttpProtocolVersion:
        """
        Parse a version token such as ``'HTTP/1.1'`` into an HttpProtocolVersion.

        :raises ValueError: If the token is not ``HTTP/<digits>.<digits>`` with each component at most 10 ASCII digits.
        """

        if not version_str.startswith('HTTP/'):
            raise ValueError(version_str)  # noqa

        base_version_number = version_str.split('/', 1)[1]
        version_number_parts = base_version_number.split('.')

        # RFC 2145 section 3.1 says there can be only one "." and
        #  - major and minor numbers MUST be treated as separate integers;
        #  - HTTP/2.4 is a lower version than HTTP/2.13, which in turn is lower than HTTP/12.3;
        #  - Leading zeros MUST be ignored by recipients.
        if len(version_number_parts) != 2:
            raise ValueError(version_number_parts)  # noqa
        # str.isdigit() alone also accepts non-ASCII digits (e.g. '²' or fullwidth '１'), which the HTTP grammar does not
        # allow and which int() then either rejects or silently accepts - so require ASCII explicitly.
        if any(not (component.isascii() and component.isdigit()) for component in version_number_parts):
            raise ValueError('non digit in http version')  # noqa
        if any(len(component) > 10 for component in version_number_parts):
            raise ValueError('unreasonable length http version')  # noqa

        return HttpProtocolVersion(
            int(version_number_parts[0]),
            int(version_number_parts[1]),
        )

    #

    def coro_read_raw_headers(self) -> ta.Generator[int, bytes, ta.List[bytes]]:
        """
        Coroutine reading raw header lines up to and including the terminating blank line (or an empty read).

        :raises http.client.LineTooLong: If a line is longer than ``max_line`` bytes.
        :raises http.client.HTTPException: If more than ``max_headers`` lines (terminator included) are read.
        """

        raw_headers: ta.List[bytes] = []
        while True:
            # Ask for one byte more than allowed so that an over-long line is detectable by its length.
            line = yield self._max_line + 1
            if len(line) > self._max_line:
                raise http.client.LineTooLong('header line')
            raw_headers.append(line)
            if len(raw_headers) > self._max_headers:
                raise http.client.HTTPException(f'got more than {self._max_headers} headers')
            if line in (b'\r\n', b'\n', b''):
                break
        return raw_headers

    def read_raw_headers(self, read_line: ta.Callable[[int], bytes]) -> ta.List[bytes]:
        """Blocking wrapper around ``coro_read_raw_headers`` driven by ``read_line(max_bytes)``."""

        return self._run_read_line_coro(self.coro_read_raw_headers(), read_line)

    def parse_raw_headers(self, raw_headers: ta.Sequence[bytes]) -> HttpHeaders:
        """
        Parse raw header lines (as produced by ``coro_read_raw_headers``) into an HttpHeaders message.

        This performs the same decode-and-parse step as ``http.client.parse_headers``. It does not re-read the lines
        through http.client's module-level line-length / header-count limits, which would otherwise silently cap any
        ``max_line`` / ``max_headers`` configured above the stdlib defaults.
        """

        import email.parser

        header_str = b''.join(raw_headers).decode('iso-8859-1')
        return email.parser.Parser(_class=HttpHeaders).parse(io.StringIO(header_str))

    #

    def coro_parse(self) -> ta.Generator[int, bytes, ParseHttpRequestResult]:
        """
        Coroutine parsing a full request head: the request line followed by the header block.

        Yields the maximum number of bytes wanted for each line and expects the line read to be sent back. Returns
        ``EmptyParsedHttpResult`` for an empty read or whitespace-only request line, ``ParseHttpRequestError`` for a
        protocol error, and ``ParsedHttpRequest`` on success.
        """

        raw_request_line = yield self._max_line + 1

        # Common result kwargs - result_kwargs() reads these at call time, so every result reports whatever has been
        # determined up to the point it is returned.

        request_line = '-'
        request_version = self.DEFAULT_REQUEST_VERSION

        # Set to min(server, request) when it gets that far, but if it fails before that the server authoritatively
        # responds with its own version.
        version = self._server_version

        headers: ta.Optional[HttpHeaders] = None

        close_connection = True

        def result_kwargs():
            return dict(
                server_version=self._server_version,
                request_line=request_line,
                request_version=request_version,
                version=version,
                headers=headers,
                close_connection=close_connection,
            )

        # Decode line

        if len(raw_request_line) > self._max_line:
            return ParseHttpRequestError(
                code=http.HTTPStatus.REQUEST_URI_TOO_LONG,
                message='Request line too long',
                **result_kwargs(),
            )

        if not raw_request_line:
            return EmptyParsedHttpResult(**result_kwargs())

        # ISO-8859-1 maps every byte to a code point, so this decode cannot fail.
        request_line = raw_request_line.decode('iso-8859-1').rstrip('\r\n')

        # Split words

        words = request_line.split()
        if not words:
            return EmptyParsedHttpResult(**result_kwargs())

        # Parse and set version

        if len(words) >= 3:  # Enough to determine protocol version
            version_str = words[-1]
            try:
                request_version = self.parse_request_version(version_str)

            except (ValueError, IndexError):
                return ParseHttpRequestError(
                    code=http.HTTPStatus.BAD_REQUEST,
                    message=f'Bad request version ({version_str!r})',
                    **result_kwargs(),
                )

            if (
                    request_version < HttpProtocolVersions.HTTP_0_9 or
                    request_version >= HttpProtocolVersions.HTTP_2_0
            ):
                return ParseHttpRequestError(
                    code=http.HTTPStatus.HTTP_VERSION_NOT_SUPPORTED,
                    message=f'Invalid HTTP version ({version_str})',
                    **result_kwargs(),
                )

            version = min(self._server_version, request_version)

            if version >= HttpProtocolVersions.HTTP_1_1:
                close_connection = False

        # NOTE(review): a two-word (HTTP/0.9-style) request line skips the block above, so `version` stays at the
        # server version while `request_version` stays at DEFAULT_REQUEST_VERSION - confirm that callers decide
        # whether to send a status line from `request_version`.

        # Verify word count

        if not 2 <= len(words) <= 3:
            return ParseHttpRequestError(
                code=http.HTTPStatus.BAD_REQUEST,
                message=f'Bad request syntax ({request_line!r})',
                **result_kwargs(),
            )

        # Parse method and path

        method, path = words[:2]
        if len(words) == 2:
            close_connection = True
            if method != 'GET':
                return ParseHttpRequestError(
                    code=http.HTTPStatus.BAD_REQUEST,
                    message=f'Bad HTTP/0.9 request type ({method!r})',
                    **result_kwargs(),
                )

        # gh-87389: The purpose of replacing '//' with '/' is to protect against open redirect attacks possibly
        # triggered if the path starts with '//' because http clients treat //path as an absolute URI without scheme
        # (similar to http://path) rather than a path.
        if path.startswith('//'):
            path = '/' + path.lstrip('/')  # Reduce to a single /

        # Parse headers

        try:
            # Delegate to the header-reading coroutine: its yields and sends pass straight through to our driver, and
            # its return value (the raw lines) comes back as the value of the expression.
            raw_headers = yield from self.coro_read_raw_headers()
            headers = self.parse_raw_headers(raw_headers)

        # LineTooLong subclasses HTTPException, so it must be handled first.
        except http.client.LineTooLong as err:
            return ParseHttpRequestError(
                code=http.HTTPStatus.REQUEST_HEADER_FIELDS_TOO_LARGE,
                message=('Line too long', str(err)),
                **result_kwargs(),
            )

        except http.client.HTTPException as err:
            return ParseHttpRequestError(
                code=http.HTTPStatus.REQUEST_HEADER_FIELDS_TOO_LARGE,
                message=('Too many headers', str(err)),
                **result_kwargs(),
            )

        # Check for connection directive

        conn_type = headers.get('Connection', '').lower()
        if conn_type == 'close':
            close_connection = True
        elif (
                conn_type == 'keep-alive' and
                version >= HttpProtocolVersions.HTTP_1_1
        ):
            close_connection = False

        # Check for expect directive

        expects_continue = (
            headers.get('Expect', '').lower() == '100-continue' and
            version >= HttpProtocolVersions.HTTP_1_1
        )

        # Return

        return ParsedHttpRequest(
            method=method,
            path=path,
            expects_continue=expects_continue,
            **result_kwargs(),
        )

    def parse(self, read_line: ta.Callable[[int], bytes]) -> ParseHttpRequestResult:
        """Parse a request head by driving ``coro_parse`` with a blocking ``read_line(max_bytes)``."""

        return self._run_read_line_coro(self.coro_parse(), read_line)
@@ -0,0 +1,17 @@
1
+ # ruff: noqa: UP006 UP007
2
+ import typing as ta
3
+
4
+
class HttpProtocolVersion(ta.NamedTuple):
    """
    An HTTP protocol version as a ``(major, minor)`` pair.

    Being a tuple, instances compare component-wise as integers, so HTTP/2.4 < HTTP/2.13 < HTTP/12.3.
    """

    major: int
    minor: int

    def __str__(self) -> str:
        # Render in request-line form, e.g. 'HTTP/1.1'.
        major, minor = self
        return f'HTTP/{major}.{minor}'


class HttpProtocolVersions:
    """Named constants for the protocol versions the HTTP code refers to."""

    HTTP_0_9 = HttpProtocolVersion(major=0, minor=9)
    HTTP_1_0 = HttpProtocolVersion(major=1, minor=0)
    HTTP_1_1 = HttpProtocolVersion(major=1, minor=1)
    HTTP_2_0 = HttpProtocolVersion(major=2, minor=0)
omlish/lite/inject.py CHANGED
@@ -30,12 +30,22 @@ InjectorBindingOrBindings = ta.Union['InjectorBinding', 'InjectorBindings']
30
30
 
31
31
 
@dc.dataclass(frozen=True)
class InjectorKey(ta.Generic[T]):
    """Frozen (hence hashable) binding key: a key class plus an optional ``tag`` and an ``array`` flag."""

    # The class (or NewType) being keyed on.
    cls: InjectorKeyCls
    # Optional discriminator between keys sharing the same class.
    tag: ta.Any = None
    array: bool = False
37
37
 
38
38
 
def is_valid_injector_key_cls(cls: ta.Any) -> bool:
    """Return whether ``cls`` may be used as an InjectorKey class: a real type, or anything ``is_new_type`` accepts."""

    if isinstance(cls, type):
        return True
    return is_new_type(cls)
41
+
42
+
def check_valid_injector_key_cls(cls: T) -> T:
    """Return ``cls`` unchanged if ``is_valid_injector_key_cls`` accepts it; otherwise raise ``TypeError(cls)``."""

    if is_valid_injector_key_cls(cls):
        return cls
    raise TypeError(cls)
47
+
48
+
39
49
  ##
40
50
 
41
51
 
@@ -79,6 +89,12 @@ class Injector(abc.ABC):
79
89
  def inject(self, obj: ta.Any) -> ta.Any:
80
90
  raise NotImplementedError
81
91
 
def __getitem__(self, target: ta.Union[InjectorKey[T], ta.Type[T]]) -> T:
    """Subscript shorthand for ``self.provide(target)``, e.g. ``injector[SomeType]``."""

    return self.provide(target)
97
+
82
98
 
83
99
  ###
84
100
  # exceptions
@@ -111,7 +127,7 @@ def as_injector_key(o: ta.Any) -> InjectorKey:
111
127
  raise TypeError(o)
112
128
  if isinstance(o, InjectorKey):
113
129
  return o
114
- if isinstance(o, type) or is_new_type(o):
130
+ if is_valid_injector_key_cls(o):
115
131
  return InjectorKey(o)
116
132
  raise TypeError(o)
117
133
 
@@ -443,8 +459,8 @@ class InjectorBinder:
443
459
  to_fn = obj
444
460
  if key is None:
445
461
  sig = _injection_signature(obj)
446
- ty = check_isinstance(sig.return_annotation, type)
447
- key = InjectorKey(ty)
462
+ key_cls = check_valid_injector_key_cls(sig.return_annotation)
463
+ key = InjectorKey(key_cls)
448
464
  else:
449
465
  if to_const is not None:
450
466
  raise TypeError('Cannot bind instance with to_const')
@@ -498,7 +514,7 @@ class InjectorBinder:
498
514
  # injector
499
515
 
500
516
 
# Pre-built key for the Injector type itself.
_INJECTOR_INJECTOR_KEY: InjectorKey[Injector] = InjectorKey(cls=Injector)
502
518
 
503
519
 
504
520
  class _Injector(Injector):
omlish/lite/journald.py CHANGED
@@ -29,7 +29,6 @@ sd_iovec._fields_ = [
29
29
  def sd_libsystemd() -> ta.Any:
30
30
  lib = ct.CDLL('libsystemd.so.0')
31
31
 
32
- lib.sd_journal_sendv = lib['sd_journal_sendv'] # type: ignore
33
32
  lib.sd_journal_sendv.restype = ct.c_int
34
33
  lib.sd_journal_sendv.argtypes = [ct.POINTER(sd_iovec), ct.c_int]
35
34
 
omlish/lite/runtime.py CHANGED
@@ -14,5 +14,4 @@ REQUIRED_PYTHON_VERSION = (3, 8)
14
14
 
def check_runtime_version() -> None:
    """Raise OSError if the running interpreter is older than REQUIRED_PYTHON_VERSION."""

    if sys.version_info >= REQUIRED_PYTHON_VERSION:
        return
    raise OSError(f'Requires python {REQUIRED_PYTHON_VERSION}, got {sys.version_info} from {sys.executable}')  # noqa
omlish/lite/socket.py ADDED
@@ -0,0 +1,77 @@
1
+ # ruff: noqa: UP006 UP007
2
+ """
3
+ TODO:
4
+ - SocketClientAddress family / tuple pairs
5
+ + codification of https://docs.python.org/3/library/socket.html#socket-families
6
+ """
7
+ import abc
8
+ import dataclasses as dc
9
+ import socket
10
+ import typing as ta
11
+
12
+
# Left untyped: the shape of a socket address depends on the socket family (see the module TODO).
SocketAddress = ta.Any


# Builds a SocketHandler from (client address, binary input stream, binary output stream).
SocketHandlerFactory = ta.Callable[[SocketAddress, ta.BinaryIO, ta.BinaryIO], 'SocketHandler']
17
+
18
+
19
+ ##
20
+
21
+
@dc.dataclass(frozen=True)
class SocketAddressInfoArgs:
    """Arguments for a ``socket.getaddrinfo`` call; the optional ones default to getaddrinfo's own defaults."""

    host: ta.Optional[str]
    port: ta.Union[str, int, None]
    # AF_UNSPEC: any address family.
    family: socket.AddressFamily = socket.AF_UNSPEC
    # 0: any socket type / protocol.
    type: int = 0
    proto: int = 0
    # No AI_* flags.
    flags: socket.AddressInfo = socket.AddressInfo(0)
30
+
31
+
@dc.dataclass(frozen=True)
class SocketAddressInfo:
    """
    One entry of a ``socket.getaddrinfo`` result.

    The field order matches getaddrinfo's 5-tuples, so an entry can be splatted in positionally.
    """

    family: socket.AddressFamily
    type: int
    proto: int
    canonname: ta.Optional[str]
    sockaddr: SocketAddress
39
+
40
+
def get_best_socket_family(
        host: ta.Optional[str],
        port: ta.Union[str, int, None],
        family: ta.Union[int, socket.AddressFamily] = socket.AddressFamily.AF_UNSPEC,
) -> ta.Tuple[socket.AddressFamily, SocketAddress]:
    """
    Resolve ``host``/``port`` for a passive (bindable) stream socket and return the family and address of the first
    result getaddrinfo reports.

    https://github.com/python/cpython/commit/f289084c83190cc72db4a70c58f007ec62e75247
    """

    candidates = socket.getaddrinfo(
        host,
        port,
        family,
        type=socket.SOCK_STREAM,
        flags=socket.AI_PASSIVE,
    )
    best = SocketAddressInfo(*next(iter(candidates)))
    return best.family, best.sockaddr
57
+
58
+
59
+ ##
60
+
61
+
class SocketHandler(abc.ABC):
    """
    Abstract handler for one socket connection.

    Constructed with the peer address and the connection's binary input / output streams; subclasses implement
    ``handle`` to service the connection.
    """

    def __init__(
            self,
            client_address: SocketAddress,
            rfile: ta.BinaryIO,
            wfile: ta.BinaryIO,
    ) -> None:
        super().__init__()

        # Stored under private names for subclasses to use.
        self._client_address = client_address
        self._rfile = rfile
        self._wfile = wfile

    @abc.abstractmethod
    def handle(self) -> None:
        """Service the connection; must be overridden."""

        raise NotImplementedError
@@ -0,0 +1,66 @@
1
+ # ruff: noqa: UP006 UP007
2
+ import socket
3
+ import socketserver
4
+ import typing as ta
5
+
6
+ from omlish.lite.check import check_not_none
7
+
8
+ from .socket import SocketAddress
9
+ from .socket import SocketHandlerFactory
10
+
11
+
12
+ ##
13
+
14
+
class SocketServerBaseRequestHandler_:  # noqa
    """Annotation-only typing mixin naming the instance attributes ``socketserver.BaseRequestHandler`` sets."""

    request: socket.socket
    client_address: SocketAddress
    server: socketserver.TCPServer
19
+
20
+
class SocketServerStreamRequestHandler_(SocketServerBaseRequestHandler_):  # noqa
    """Annotation-only typing mixin naming the extra attributes of ``socketserver.StreamRequestHandler``."""

    # Buffer-size settings.
    rbufsize: int
    wbufsize: int

    timeout: ta.Optional[float]

    disable_nagle_algorithm: bool

    # The connection socket and its binary read / write streams.
    connection: socket.socket
    rfile: ta.BinaryIO
    wfile: ta.BinaryIO
32
+
33
+
34
+ ##
35
+
36
+
class SocketHandlerSocketServerStreamRequestHandler(  # type: ignore[misc]
    socketserver.StreamRequestHandler,
    SocketServerStreamRequestHandler_,
):
    """
    ``socketserver`` stream request handler that delegates each connection to a SocketHandler built by
    ``socket_handler_factory``.

    The factory comes from the keyword-only constructor argument or from a (sub)class attribute.

    NOTE(review): a plain function stored as the class attribute is bound as a method when read through ``self``, so
    it would receive the handler instance as an extra first argument - such values presumably need ``staticmethod``
    (or must be classes / partials); confirm intended usage.
    """

    socket_handler_factory: ta.Optional[SocketHandlerFactory] = None

    def __init__(
            self,
            request: socket.socket,
            client_address: SocketAddress,
            server: socketserver.TCPServer,
            *,
            socket_handler_factory: ta.Optional[SocketHandlerFactory] = None,
    ) -> None:
        # The base initializer runs setup()/handle()/finish() immediately, so an explicit factory has to be installed
        # before delegating to it.
        if socket_handler_factory is not None:
            self.socket_handler_factory = socket_handler_factory

        super().__init__(request, client_address, server)

    def handle(self) -> None:
        """Build a SocketHandler for this connection and run it."""

        factory = check_not_none(self.socket_handler_factory)
        socket_handler = factory(
            self.client_address,
            self.rfile,  # type: ignore[arg-type]
            self.wfile,  # type: ignore[arg-type]
        )
        socket_handler.handle()