omlish 0.0.0.dev123__py3-none-any.whl → 0.0.0.dev124__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,376 @@
1
+ # ruff: noqa: UP006 UP007
2
+ import abc
3
+ import http.client
4
+ import http.server
5
+ import io
6
+ import typing as ta
7
+
8
+ from .versions import HttpProtocolVersion
9
+ from .versions import HttpProtocolVersions
10
+
11
+
12
# Generic return-type variable, used to type the line-reading generator driver.
T = ta.TypeVar('T')


# Parsed request headers; parse_raw_headers produces instances of this class. Kept as a plain assignment rather than
# an explicit ta.TypeAlias annotation — presumably so the module stays importable on pre-3.10 Pythons.
HttpHeaders = http.client.HTTPMessage  # ta.TypeAlias
16
+
17
+
18
+ ##
19
+
20
+
21
class ParseHttpRequestResult(abc.ABC):  # noqa
    """
    Base class for every outcome produced by HttpRequestParser.

    Holds the fields that are known (or defaulted) however far parsing got: the server's protocol version, the decoded
    request line, the version the client claimed, the version to respond with, any parsed headers, and whether the
    connection should be closed after responding.
    """

    __slots__ = (
        'server_version',
        'request_line',
        'request_version',
        'version',
        'headers',
        'close_connection',
    )

    def __init__(
            self,
            *,
            server_version: HttpProtocolVersion,
            request_line: str,
            request_version: HttpProtocolVersion,
            version: HttpProtocolVersion,
            headers: ta.Optional[HttpHeaders],
            close_connection: bool,
    ) -> None:
        super().__init__()

        self.server_version = server_version
        self.request_line = request_line
        self.request_version = request_version
        self.version = version
        self.headers = headers
        self.close_connection = close_connection

    def __repr__(self) -> str:
        # Collect slot names from the whole MRO, most-derived class first and de-duplicated. Reading self.__slots__
        # alone only sees the nearest class that declares __slots__, so a subclass declaring just its own new slots
        # would silently lose every base-class field from its repr.
        attrs: ta.List[str] = []
        for klass in type(self).__mro__:
            slots = klass.__dict__.get('__slots__', ())
            if isinstance(slots, str):  # __slots__ = 'name' is legal shorthand for a single slot
                slots = (slots,)
            for a in slots:
                if a not in attrs and a not in ('__dict__', '__weakref__'):
                    attrs.append(a)
        return f'{self.__class__.__name__}({", ".join(f"{a}={getattr(self, a)!r}" for a in attrs)})'
52
+
53
+
54
class EmptyParsedHttpResult(ParseHttpRequestResult):
    """
    Result returned when no request was received: the first line read was empty (end of input) or contained only
    whitespace.
    """
56
+
57
+
58
class ParseHttpRequestError(ParseHttpRequestResult):
    """
    Result describing a request that could not be parsed.

    ``code`` is the HTTP status to respond with. ``message`` is either a short reason string or a
    ``(short reason, detail)`` pair.
    """

    # The base slots are re-declared so that self.__slots__ on an instance lists every field of this class.
    __slots__ = (
        'code',
        'message',
        *ParseHttpRequestResult.__slots__,
    )

    def __init__(
            self,
            *,
            code: http.HTTPStatus,
            message: ta.Union[str, ta.Tuple[str, str]],

            **kwargs: ta.Any,
    ) -> None:
        # All remaining keyword arguments are the common result fields, handled by the base class.
        super().__init__(**kwargs)

        self.code, self.message = code, message
77
+
78
+
79
class ParsedHttpRequest(ParseHttpRequestResult):
    """
    Result for a successfully parsed request line and header block.

    Adds the request method, the (normalized) path, and whether the client asked for a ``100 Continue`` before sending
    a body. Unlike the base class, ``headers`` is always present.
    """

    # 'headers' is listed with this class's own fields; the remaining base slots follow it.
    __slots__ = (
        'method',
        'path',
        'headers',
        'expects_continue',
        *(a for a in ParseHttpRequestResult.__slots__ if a != 'headers'),
    )

    def __init__(
            self,
            *,
            method: str,
            path: str,
            headers: HttpHeaders,
            expects_continue: bool,

            **kwargs: ta.Any,
    ) -> None:
        super().__init__(headers=headers, **kwargs)

        self.method = method
        self.path = path
        self.expects_continue = expects_continue

    # Narrows the base class's Optional headers for type checkers: a successful parse always has headers.
    headers: HttpHeaders
108
+
109
+
110
+ #
111
+
112
+
113
class HttpRequestParser:
    """
    Parser for the request line and header block of an HTTP/0.9 - HTTP/1.x request.

    Parsing is written as generators (the ``coro_*`` methods). Each generator yields the maximum number of bytes it
    wants for the next line and expects to be sent that line. A zero-length line is treated as end of input.
    ``parse`` and ``read_raw_headers`` drive those generators with a blocking ``read_line(max_size)`` callable. The
    checks appear to be modelled on the standard library's ``http.server.BaseHTTPRequestHandler.parse_request``.
    Malformed input is reported as a ParseHttpRequestError carrying the status to respond with, rather than being
    written out.
    """

    DEFAULT_SERVER_VERSION = HttpProtocolVersions.HTTP_1_0

    # The default request version. This only affects responses up until the point where the request line is parsed, so
    # it mainly decides what the client gets back when sending a malformed request line.
    # Most web servers default to HTTP 0.9, i.e. don't send a status line.
    DEFAULT_REQUEST_VERSION = HttpProtocolVersions.HTTP_0_9

    #

    # NOTE(review): these appear to match http.client's own _MAXLINE / _MAXHEADERS defaults.
    DEFAULT_MAX_LINE: int = 0x10000
    DEFAULT_MAX_HEADERS: int = 100

    #

    def __init__(
            self,
            *,
            server_version: HttpProtocolVersion = DEFAULT_SERVER_VERSION,

            max_line: int = DEFAULT_MAX_LINE,
            max_headers: int = DEFAULT_MAX_HEADERS,
    ) -> None:
        """
        :param server_version: Highest protocol version this server speaks; must be below HTTP/2.0.
        :param max_line: Maximum accepted length in bytes of the request line and of each header line.
        :param max_headers: Maximum number of header lines; the terminating blank line counts towards it.
        :raises ValueError: If server_version is HTTP/2.0 or later.
        """

        super().__init__()

        if server_version >= HttpProtocolVersions.HTTP_2_0:
            raise ValueError(f'Unsupported protocol version: {server_version}')
        self._server_version = server_version

        self._max_line = max_line
        self._max_headers = max_headers

    #

    @property
    def server_version(self) -> HttpProtocolVersion:
        return self._server_version

    #

    def _run_read_line_coro(
            self,
            gen: ta.Generator[int, bytes, T],
            read_line: ta.Callable[[int], bytes],
    ) -> T:
        """
        Drive a line-reading generator to completion and return its return value.

        Each size the generator yields is passed to ``read_line``, and the bytes read are sent back in.
        """

        # The first next() sits inside the try as well, so a generator that returns before ever yielding is handled
        # instead of leaking StopIteration to the caller.
        try:
            sz = next(gen)
            while True:
                sz = gen.send(read_line(sz))
        except StopIteration as e:
            return e.value

    #

    def parse_request_version(self, version_str: str) -> HttpProtocolVersion:
        """
        Parse an ``HTTP/<major>.<minor>`` token taken from a request line.

        :raises ValueError: If the token is not of that form, a component is not made of ASCII digits, or a component
            is longer than 10 characters.
        """

        if not version_str.startswith('HTTP/'):
            raise ValueError(version_str)  # noqa

        base_version_number = version_str.split('/', 1)[1]
        version_number_parts = base_version_number.split('.')

        # RFC 2145 section 3.1 says there can be only one "." and
        #  - major and minor numbers MUST be treated as separate integers;
        #  - HTTP/2.4 is a lower version than HTTP/2.13, which in turn is lower than HTTP/12.3;
        #  - Leading zeros MUST be ignored by recipients.
        if len(version_number_parts) != 2:
            raise ValueError(version_number_parts)  # noqa
        # str.isdigit() alone also accepts characters such as '²' (reachable here, since request lines are decoded as
        # ISO-8859-1), which int() then rejects with a less useful message. Require plain ASCII digits.
        if any(not (component.isascii() and component.isdigit()) for component in version_number_parts):
            raise ValueError('non digit in http version')  # noqa
        if any(len(component) > 10 for component in version_number_parts):
            raise ValueError('unreasonable length http version')  # noqa

        return HttpProtocolVersion(
            int(version_number_parts[0]),
            int(version_number_parts[1]),
        )

    #

    def coro_read_raw_headers(self) -> ta.Generator[int, bytes, ta.List[bytes]]:
        """
        Generator reading raw header lines up to and including the terminating blank line (or end of input).

        :raises http.client.LineTooLong: If a line exceeds max_line.
        :raises http.client.HTTPException: If more than max_headers lines (blank terminator included) are read.
        """

        raw_headers: ta.List[bytes] = []
        while True:
            line = yield self._max_line + 1
            if len(line) > self._max_line:
                raise http.client.LineTooLong('header line')
            raw_headers.append(line)
            if len(raw_headers) > self._max_headers:
                raise http.client.HTTPException(f'got more than {self._max_headers} headers')
            if line in (b'\r\n', b'\n', b''):
                break
        return raw_headers

    def read_raw_headers(self, read_line: ta.Callable[[int], bytes]) -> ta.List[bytes]:
        """Blocking form of coro_read_raw_headers, reading lines through ``read_line``."""

        return self._run_read_line_coro(self.coro_read_raw_headers(), read_line)

    def parse_raw_headers(self, raw_headers: ta.Sequence[bytes]) -> HttpHeaders:
        """
        Parse raw header lines, such as those returned by read_raw_headers, into an http.client.HTTPMessage.

        The lines are decoded as ISO-8859-1 and handed to email.parser, which is what http.client.parse_headers does
        once it has read them. Calling parse_headers itself would re-read the lines under http.client's own fixed
        line-length and header-count limits. That would silently cap this parser's configurable max_line /
        max_headers whenever they are set above those limits.
        """

        import email.parser  # noqa

        text = b''.join(raw_headers).decode('iso-8859-1')
        return email.parser.Parser(_class=http.client.HTTPMessage).parse(io.StringIO(text))

    #

    def coro_parse(self) -> ta.Generator[int, bytes, ParseHttpRequestResult]:
        """
        Generator form of parse(): yields maximum line sizes, is sent lines, and returns the parse result.

        The result is one of three kinds:
        - EmptyParsedHttpResult when the first line is empty or blank.
        - ParseHttpRequestError, carrying the status to respond with, when the request line or headers are rejected.
        - ParsedHttpRequest otherwise.
        """

        raw_request_line = yield self._max_line + 1

        # Common result kwargs

        request_line = '-'
        request_version = self.DEFAULT_REQUEST_VERSION

        # Set to min(server, request) when it gets that far, but if it fails before that the server authoritatively
        # responds with its own version.
        version = self._server_version

        headers: ta.Optional[HttpHeaders] = None

        close_connection = True

        def result_kwargs():
            # Reads the enclosing locals at call time, so each result reflects how far parsing got.
            return dict(
                server_version=self._server_version,
                request_line=request_line,
                request_version=request_version,
                version=version,
                headers=headers,
                close_connection=close_connection,
            )

        # Decode line

        if len(raw_request_line) > self._max_line:
            return ParseHttpRequestError(
                code=http.HTTPStatus.REQUEST_URI_TOO_LONG,
                message='Request line too long',
                **result_kwargs(),
            )

        if not raw_request_line:
            return EmptyParsedHttpResult(**result_kwargs())

        request_line = raw_request_line.decode('iso-8859-1').rstrip('\r\n')

        # Split words

        words = request_line.split()
        if not words:
            return EmptyParsedHttpResult(**result_kwargs())

        # Parse and set version

        if len(words) >= 3:  # Enough to determine protocol version
            version_str = words[-1]
            try:
                request_version = self.parse_request_version(version_str)

            except (ValueError, IndexError):
                return ParseHttpRequestError(
                    code=http.HTTPStatus.BAD_REQUEST,
                    message=f'Bad request version ({version_str!r})',
                    **result_kwargs(),
                )

            if (
                    request_version < HttpProtocolVersions.HTTP_0_9 or
                    request_version >= HttpProtocolVersions.HTTP_2_0
            ):
                return ParseHttpRequestError(
                    code=http.HTTPStatus.HTTP_VERSION_NOT_SUPPORTED,
                    message=f'Invalid HTTP version ({version_str})',
                    **result_kwargs(),
                )

            version = min([self._server_version, request_version])

            if version >= HttpProtocolVersions.HTTP_1_1:
                close_connection = False

        # Verify word count

        if not 2 <= len(words) <= 3:
            return ParseHttpRequestError(
                code=http.HTTPStatus.BAD_REQUEST,
                message=f'Bad request syntax ({request_line!r})',
                **result_kwargs(),
            )

        # Parse method and path

        method, path = words[:2]
        if len(words) == 2:
            # NOTE(review): for this two-word (HTTP/0.9-style) form `version` is still the server's version. The
            # HTTP/1.1-only Connection / Expect checks below are therefore evaluated against the server version rather
            # than 0.9. Confirm this is intended.
            close_connection = True
            if method != 'GET':
                return ParseHttpRequestError(
                    code=http.HTTPStatus.BAD_REQUEST,
                    message=f'Bad HTTP/0.9 request type ({method!r})',
                    **result_kwargs(),
                )

        # gh-87389: The purpose of replacing '//' with '/' is to protect against open redirect attacks possibly
        # triggered if the path starts with '//' because http clients treat //path as an absolute URI without scheme
        # (similar to http://path) rather than a path.
        if path.startswith('//'):
            path = '/' + path.lstrip('/')  # Reduce to a single /

        # Parse headers

        try:
            # Delegate line reads to the header generator; its yields and our sends pass straight through.
            raw_headers = yield from self.coro_read_raw_headers()
            headers = self.parse_raw_headers(raw_headers)

        # LineTooLong subclasses HTTPException, so it must be caught first.
        except http.client.LineTooLong as err:
            return ParseHttpRequestError(
                code=http.HTTPStatus.REQUEST_HEADER_FIELDS_TOO_LARGE,
                message=('Line too long', str(err)),
                **result_kwargs(),
            )

        except http.client.HTTPException as err:
            return ParseHttpRequestError(
                code=http.HTTPStatus.REQUEST_HEADER_FIELDS_TOO_LARGE,
                message=('Too many headers', str(err)),
                **result_kwargs(),
            )

        # Check for connection directive

        conn_type = headers.get('Connection', '')
        if conn_type.lower() == 'close':
            close_connection = True
        elif (
                conn_type.lower() == 'keep-alive' and
                version >= HttpProtocolVersions.HTTP_1_1
        ):
            close_connection = False

        # Check for expect directive

        expect = headers.get('Expect', '')
        expects_continue = (
            expect.lower() == '100-continue' and
            version >= HttpProtocolVersions.HTTP_1_1
        )

        # Return

        return ParsedHttpRequest(
            method=method,
            path=path,
            expects_continue=expects_continue,
            **result_kwargs(),
        )

    def parse(self, read_line: ta.Callable[[int], bytes]) -> ParseHttpRequestResult:
        """Blocking form of coro_parse, reading lines through ``read_line(max_size)``."""

        return self._run_read_line_coro(self.coro_parse(), read_line)
@@ -0,0 +1,17 @@
1
+ # ruff: noqa: UP006 UP007
2
+ import typing as ta
3
+
4
+
5
class HttpProtocolVersion(ta.NamedTuple):
    """
    An HTTP protocol version as a ``(major, minor)`` pair.

    Being a tuple, instances order component-wise, so ``HTTP/1.1 > HTTP/1.0`` and ``HTTP/2.0 > HTTP/1.13``.
    """

    major: int
    minor: int

    def __str__(self) -> str:
        # Rendered in request-line form, e.g. 'HTTP/1.1'.
        major, minor = self
        return f'HTTP/{major}.{minor}'
11
+
12
+
13
class HttpProtocolVersions:
    """Namespace of the well-known HttpProtocolVersion constants."""

    HTTP_0_9 = HttpProtocolVersion(major=0, minor=9)
    HTTP_1_0 = HttpProtocolVersion(major=1, minor=0)
    HTTP_1_1 = HttpProtocolVersion(major=1, minor=1)
    HTTP_2_0 = HttpProtocolVersion(major=2, minor=0)
omlish/lite/inject.py CHANGED
@@ -30,12 +30,22 @@ InjectorBindingOrBindings = ta.Union['InjectorBinding', 'InjectorBindings']
30
30
 
31
31
 
32
32
@dc.dataclass(frozen=True)
class InjectorKey(ta.Generic[T]):
    """
    Immutable key identifying a binding.

    The key is generic in the type it resolves to, so annotations such as ``InjectorKey[Injector]`` record what the
    key provides.
    """

    # The keyed class; as_injector_key only accepts values passing is_valid_injector_key_cls here.
    cls: InjectorKeyCls
    # Extra discriminator, presumably to allow several bindings of the same cls — confirm against binder usage.
    tag: ta.Any = dc.field(default=None)
    # NOTE(review): marks an "array" key; exact semantics are defined elsewhere in this module.
    array: bool = dc.field(default=False)
37
37
 
38
38
 
39
def is_valid_injector_key_cls(cls: ta.Any) -> bool:
    """Return whether ``cls`` can be the ``cls`` of an InjectorKey: a real class, or something is_new_type accepts."""

    if isinstance(cls, type):
        return True
    return is_new_type(cls)
41
+
42
+
43
def check_valid_injector_key_cls(cls: T) -> T:
    """
    Return ``cls`` unchanged if is_valid_injector_key_cls accepts it.

    :raises TypeError: Otherwise, with ``cls`` as the exception argument.
    """

    if is_valid_injector_key_cls(cls):
        return cls
    raise TypeError(cls)
47
+
48
+
39
49
  ##
40
50
 
41
51
 
@@ -79,6 +89,12 @@ class Injector(abc.ABC):
79
89
  def inject(self, obj: ta.Any) -> ta.Any:
80
90
  raise NotImplementedError
81
91
 
92
+ def __getitem__(
93
+ self,
94
+ target: ta.Union[InjectorKey[T], ta.Type[T]],
95
+ ) -> T:
96
+ return self.provide(target)
97
+
82
98
 
83
99
  ###
84
100
  # exceptions
@@ -111,7 +127,7 @@ def as_injector_key(o: ta.Any) -> InjectorKey:
111
127
  raise TypeError(o)
112
128
  if isinstance(o, InjectorKey):
113
129
  return o
114
- if isinstance(o, type) or is_new_type(o):
130
+ if is_valid_injector_key_cls(o):
115
131
  return InjectorKey(o)
116
132
  raise TypeError(o)
117
133
 
@@ -443,8 +459,8 @@ class InjectorBinder:
443
459
  to_fn = obj
444
460
  if key is None:
445
461
  sig = _injection_signature(obj)
446
- ty = check_isinstance(sig.return_annotation, type)
447
- key = InjectorKey(ty)
462
+ key_cls = check_valid_injector_key_cls(sig.return_annotation)
463
+ key = InjectorKey(key_cls)
448
464
  else:
449
465
  if to_const is not None:
450
466
  raise TypeError('Cannot bind instance with to_const')
@@ -498,7 +514,7 @@ class InjectorBinder:
498
514
  # injector
499
515
 
500
516
 
501
# Key for the Injector type itself (presumably how an injector makes itself available to providers — confirm in
# _Injector).
_INJECTOR_INJECTOR_KEY: InjectorKey[Injector] = InjectorKey(cls=Injector)
502
518
 
503
519
 
504
520
  class _Injector(Injector):
omlish/lite/journald.py CHANGED
@@ -29,7 +29,6 @@ sd_iovec._fields_ = [
29
29
  def sd_libsystemd() -> ta.Any:
30
30
  lib = ct.CDLL('libsystemd.so.0')
31
31
 
32
- lib.sd_journal_sendv = lib['sd_journal_sendv'] # type: ignore
33
32
  lib.sd_journal_sendv.restype = ct.c_int
34
33
  lib.sd_journal_sendv.argtypes = [ct.POINTER(sd_iovec), ct.c_int]
35
34
 
omlish/lite/runtime.py CHANGED
@@ -14,5 +14,4 @@ REQUIRED_PYTHON_VERSION = (3, 8)
14
14
 
15
15
def check_runtime_version() -> None:
    """
    Fail fast on interpreters older than REQUIRED_PYTHON_VERSION.

    :raises OSError: If sys.version_info is below REQUIRED_PYTHON_VERSION.
    """

    if sys.version_info >= REQUIRED_PYTHON_VERSION:
        return
    raise OSError(f'Requires python {REQUIRED_PYTHON_VERSION}, got {sys.version_info} from {sys.executable}')  # noqa
omlish/lite/socket.py ADDED
@@ -0,0 +1,77 @@
1
+ # ruff: noqa: UP006 UP007
2
+ """
3
+ TODO:
4
+ - SocketClientAddress family / tuple pairs
5
+ + codification of https://docs.python.org/3/library/socket.html#socket-families
6
+ """
7
+ import abc
8
+ import dataclasses as dc
9
+ import socket
10
+ import typing as ta
11
+
12
+
13
# Socket addresses differ by address family (e.g. a (host, port) tuple for AF_INET, a path for AF_UNIX), so they are
# left untyped for now.
SocketAddress = ta.Any


# Callable building a SocketHandler from (client address, binary read stream, binary write stream).
SocketHandlerFactory = ta.Callable[[SocketAddress, ta.BinaryIO, ta.BinaryIO], 'SocketHandler']
17
+
18
+
19
+ ##
20
+
21
+
22
@dc.dataclass(frozen=True)
class SocketAddressInfoArgs:
    """
    Immutable bundle mirroring the parameters of ``socket.getaddrinfo``.

    The defaults are the unrestricted values: any family, any socket type, any protocol, and no flags.
    """

    host: ta.Optional[str]
    port: ta.Union[str, int, None]
    family: socket.AddressFamily = dc.field(default=socket.AddressFamily.AF_UNSPEC)
    type: int = dc.field(default=0)  # socket type, e.g. socket.SOCK_STREAM; 0 means any
    proto: int = dc.field(default=0)  # 0 means any
    flags: socket.AddressInfo = dc.field(default=socket.AddressInfo(0))
30
+
31
+
32
@dc.dataclass(frozen=True)
class SocketAddressInfo:
    """
    One ``socket.getaddrinfo`` result.

    Fields follow the order of the 5-tuples getaddrinfo returns, so ``SocketAddressInfo(*entry)`` builds one directly.
    """

    family: socket.AddressFamily
    type: int
    proto: int
    canonname: ta.Optional[str]
    sockaddr: SocketAddress  # family-specific, e.g. (host, port) for AF_INET
39
+
40
+
41
def get_best_socket_family(
        host: ta.Optional[str],
        port: ta.Union[str, int, None],
        family: ta.Union[int, socket.AddressFamily] = socket.AddressFamily.AF_UNSPEC,
) -> ta.Tuple[socket.AddressFamily, SocketAddress]:
    """
    Choose the address family and socket address for binding a listening stream socket to ``(host, port)``.

    Resolves with ``socket.getaddrinfo`` (SOCK_STREAM, AI_PASSIVE) and takes the first result. See
    https://github.com/python/cpython/commit/f289084c83190cc72db4a70c58f007ec62e75247 .
    """

    first = next(iter(socket.getaddrinfo(
        host,
        port,
        family,
        type=socket.SOCK_STREAM,
        flags=socket.AI_PASSIVE,
    )))
    best = SocketAddressInfo(*first)
    return best.family, best.sockaddr
57
+
58
+
59
+ ##
60
+
61
+
62
class SocketHandler(abc.ABC):
    """
    Abstract per-connection handler.

    An instance is constructed with the peer's address and the connection's binary read and write streams, then
    driven by a single call to handle().
    """

    def __init__(
            self,
            client_address: SocketAddress,
            rfile: ta.BinaryIO,
            wfile: ta.BinaryIO,
    ) -> None:
        super().__init__()

        self._client_address, self._rfile, self._wfile = client_address, rfile, wfile

    @abc.abstractmethod
    def handle(self) -> None:
        """Serve the connection; subclasses must implement this."""

        raise NotImplementedError
@@ -0,0 +1,66 @@
1
+ # ruff: noqa: UP006 UP007
2
+ import socket
3
+ import socketserver
4
+ import typing as ta
5
+
6
+ from omlish.lite.check import check_not_none
7
+
8
+ from .socket import SocketAddress
9
+ from .socket import SocketHandlerFactory
10
+
11
+
12
+ ##
13
+
14
+
15
class SocketServerBaseRequestHandler_:  # noqa
    """
    Annotation-only mixin declaring the instance attributes socketserver.BaseRequestHandler sets.

    Classes mixing it in get typed access to those attributes. It has no behavior of its own.
    """

    request: socket.socket  # the connected socket for this request
    client_address: SocketAddress
    server: socketserver.TCPServer
19
+
20
+
21
class SocketServerStreamRequestHandler_(SocketServerBaseRequestHandler_):  # noqa
    """Annotation-only mixin adding the attributes of socketserver.StreamRequestHandler to the base request ones."""

    # Buffer sizes used when creating rfile / wfile.
    rbufsize: int
    wbufsize: int

    timeout: ta.Optional[float]

    disable_nagle_algorithm: bool

    # The connected socket and its binary file wrappers.
    connection: socket.socket
    rfile: ta.BinaryIO
    wfile: ta.BinaryIO
32
+
33
+
34
+ ##
35
+
36
+
37
class SocketHandlerSocketServerStreamRequestHandler(  # type: ignore[misc]
    socketserver.StreamRequestHandler,
    SocketServerStreamRequestHandler_,
):
    """
    socketserver.StreamRequestHandler that hands each connection to a SocketHandler.

    The handler factory comes from the ``socket_handler_factory`` constructor argument when given. Otherwise it comes
    from the class attribute of the same name, e.g. one set on a subclass. Since socketserver constructs handlers with
    only (request, client_address, server), supplying the argument requires wrapping the class, e.g. in
    functools.partial.

    NOTE: a plain function stored as the class attribute is looked up through ``self`` and is therefore bound as a
    method (it receives ``self`` first). Wrap it in staticmethod or pass it via the constructor if it should not be.
    """

    socket_handler_factory: ta.Optional[SocketHandlerFactory] = None

    def __init__(
            self,
            request: socket.socket,
            client_address: SocketAddress,
            server: socketserver.TCPServer,
            *,
            socket_handler_factory: ta.Optional[SocketHandlerFactory] = None,
    ) -> None:
        # Must be stored before super().__init__(), which immediately runs setup(), handle() and finish().
        if socket_handler_factory is not None:
            self.socket_handler_factory = socket_handler_factory

        super().__init__(request, client_address, server)

    def handle(self) -> None:
        factory = check_not_none(self.socket_handler_factory)
        handler = factory(self.client_address, self.rfile, self.wfile)  # type: ignore[arg-type]
        handler.handle()