omserv 0.0.0.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- omserv/__about__.py +28 -0
- omserv/__init__.py +0 -0
- omserv/apps/__init__.py +0 -0
- omserv/apps/base.py +23 -0
- omserv/apps/inject.py +89 -0
- omserv/apps/markers.py +41 -0
- omserv/apps/routes.py +139 -0
- omserv/apps/sessions.py +57 -0
- omserv/apps/templates.py +90 -0
- omserv/dbs.py +24 -0
- omserv/node/__init__.py +0 -0
- omserv/node/models.py +53 -0
- omserv/node/registry.py +124 -0
- omserv/node/sql.py +131 -0
- omserv/secrets.py +12 -0
- omserv/server/__init__.py +18 -0
- omserv/server/config.py +51 -0
- omserv/server/debug.py +14 -0
- omserv/server/events.py +83 -0
- omserv/server/headers.py +36 -0
- omserv/server/lifespans.py +132 -0
- omserv/server/multiprocess.py +157 -0
- omserv/server/protocols/__init__.py +1 -0
- omserv/server/protocols/h11.py +334 -0
- omserv/server/protocols/h2.py +407 -0
- omserv/server/protocols/protocols.py +91 -0
- omserv/server/protocols/types.py +18 -0
- omserv/server/resources/__init__.py +8 -0
- omserv/server/sockets.py +111 -0
- omserv/server/ssl.py +47 -0
- omserv/server/streams/__init__.py +0 -0
- omserv/server/streams/httpstream.py +237 -0
- omserv/server/streams/utils.py +53 -0
- omserv/server/streams/wsstream.py +447 -0
- omserv/server/taskspawner.py +111 -0
- omserv/server/tcpserver.py +173 -0
- omserv/server/types.py +94 -0
- omserv/server/workercontext.py +52 -0
- omserv/server/workers.py +193 -0
- omserv-0.0.0.dev7.dist-info/LICENSE +21 -0
- omserv-0.0.0.dev7.dist-info/METADATA +21 -0
- omserv-0.0.0.dev7.dist-info/RECORD +44 -0
- omserv-0.0.0.dev7.dist-info/WHEEL +5 -0
- omserv-0.0.0.dev7.dist-info/top_level.txt +1 -0
omserv/node/sql.py
ADDED
@@ -0,0 +1,131 @@
|
|
1
|
+
"""
|
2
|
+
TODO:
|
3
|
+
- move to omlish.sql, pg-specific trigger compiler - https://docs.sqlalchemy.org/en/20/core/compiler.html
|
4
|
+
"""
|
5
|
+
import textwrap
|
6
|
+
import typing as ta
|
7
|
+
|
8
|
+
import sqlalchemy as sa
|
9
|
+
import sqlalchemy.ext.compiler
|
10
|
+
|
11
|
+
|
12
|
+
##
|
13
|
+
|
14
|
+
|
15
|
+
class IdMixin:
    """Declarative mixin contributing an auto-incrementing integer primary key column named ``_id``."""

    _id = sa.Column(
        sa.Integer,
        primary_key=True,
        autoincrement=True,
        nullable=False,
    )
|
22
|
+
|
23
|
+
|
24
|
+
##
|
25
|
+
|
26
|
+
|
27
|
+
class utcnow(sa.sql.expression.FunctionElement):  # noqa
    """SQL function element for the current UTC time, typed as ``TIMESTAMP``.

    Its SQL text is supplied by the ``@compiles(utcnow)`` hook ``_compile_utcnow``.
    """

    type = sa.TIMESTAMP()
    inherit_cache = True
|
30
|
+
|
31
|
+
|
32
|
+
@sa.ext.compiler.compiles(utcnow)
def _compile_utcnow(
        element: utcnow,
        compiler: sa.sql.compiler.SQLCompiler,
        **kw: ta.Any,
) -> str:
    """Render :class:`utcnow` as the PostgreSQL expression ``timezone('utc', now())``.

    ``element``, ``compiler`` and ``kw`` are unused; the rendering is constant.
    """

    # NOTE(review): ``timezone('utc', now())`` yields a timestamp *without* time zone; when used as a default for a
    # ``timestamptz`` column it is interpreted in the session TimeZone - confirm sessions run in UTC.
    rendered = "timezone('utc', now())"
    return rendered
|
39
|
+
|
40
|
+
|
41
|
+
##
|
42
|
+
|
43
|
+
|
44
|
+
class TimestampsMixin:
    """Declarative mixin contributing non-null, timezone-aware ``created_at`` / ``updated_at`` columns.

    Both default server-side to :class:`utcnow`. ``updated_at`` also declares a ``FetchedValue`` for UPDATE, i.e.
    its new value is produced by the database (see ``install_updated_at_trigger``), not by SQLAlchemy.
    """

    created_at = sa.Column(
        sa.TIMESTAMP(timezone=True),
        nullable=False,
        server_default=utcnow(),
    )

    updated_at = sa.Column(
        sa.TIMESTAMP(timezone=True),
        nullable=False,
        server_default=utcnow(),
        server_onupdate=sa.schema.FetchedValue(for_update=True),
    )
|
57
|
+
|
58
|
+
|
59
|
+
##
|
60
|
+
|
61
|
+
|
62
|
+
CREATE_UPDATED_AT_FUNCTION_STATEMENT = textwrap.dedent("""
|
63
|
+
create or replace function set_updated_at_timestamp()
|
64
|
+
returns trigger as $$
|
65
|
+
begin
|
66
|
+
new.updated_at = now() at time zone 'utc';
|
67
|
+
return new;
|
68
|
+
end;
|
69
|
+
$$ language 'plpgsql';
|
70
|
+
""")
|
71
|
+
|
72
|
+
|
73
|
+
##
|
74
|
+
|
75
|
+
|
76
|
+
def get_update_at_trigger_name(table_name: str) -> str:
    """Return the name used for the ``updated_at`` trigger on *table_name*."""

    return 'trigger__updated_at_{}'.format(table_name)
|
78
|
+
|
79
|
+
|
80
|
+
#
|
81
|
+
|
82
|
+
|
83
|
+
class CreateUpdateAtTrigger(sa.schema.DDLElement):
    """DDL element that creates a table's row-level ``updated_at`` trigger.

    Compiled by ``_compile_create_update_at_trigger``.
    """

    inherit_cache = False

    def __init__(self, table_name: str) -> None:
        super().__init__()
        # Interpolated verbatim (unquoted) into the generated DDL.
        self.table_name = table_name
|
89
|
+
|
90
|
+
|
91
|
+
@sa.ext.compiler.compiles(CreateUpdateAtTrigger)
def _compile_create_update_at_trigger(
        element: CreateUpdateAtTrigger,
        compiler: sa.sql.compiler.SQLCompiler,
        **kw: ta.Any,
) -> str:
    """Render ``CREATE OR REPLACE TRIGGER`` DDL for *element*'s table.

    The trigger runs ``set_updated_at_timestamp()`` before every row update. That function must already exist; see
    ``CREATE_UPDATED_AT_FUNCTION_STATEMENT``.

    NOTE: ``CREATE OR REPLACE TRIGGER`` requires PostgreSQL 14 or newer.
    NOTE(review): the table name is interpolated unquoted; names needing quoting (reserved words, mixed case) will
    not work as-is.
    """

    trigger_name = get_update_at_trigger_name(element.table_name)
    return textwrap.dedent(f"""
        create or replace trigger {trigger_name}
        before update
        on {element.table_name}
        for each row
        execute procedure set_updated_at_timestamp()
    """)
|
104
|
+
|
105
|
+
|
106
|
+
#
|
107
|
+
|
108
|
+
|
109
|
+
class DropUpdateAtTrigger(sa.schema.DDLElement):
    """DDL element that drops a table's ``updated_at`` trigger if it exists.

    Compiled by ``_compile_drop_update_at_trigger``.
    """

    inherit_cache = False

    def __init__(self, table_name: str) -> None:
        super().__init__()
        # Interpolated verbatim (unquoted) into the generated DDL.
        self.table_name = table_name
|
115
|
+
|
116
|
+
|
117
|
+
@sa.ext.compiler.compiles(DropUpdateAtTrigger)
def _compile_drop_update_at_trigger(
        element: DropUpdateAtTrigger,
        compiler: sa.sql.compiler.SQLCompiler,
        **kw: ta.Any,
) -> str:
    """Render ``DROP TRIGGER IF EXISTS`` for the ``updated_at`` trigger on *element*'s table."""

    table = element.table_name
    trigger = get_update_at_trigger_name(table)
    return f'drop trigger if exists {trigger} on {table}'
|
124
|
+
|
125
|
+
|
126
|
+
#
|
127
|
+
|
128
|
+
|
129
|
+
def install_updated_at_trigger(metadata: sa.MetaData, table_name: str) -> None:
    """Hook trigger DDL for *table_name* into *metadata*'s create/drop lifecycle.

    The trigger is created after ``metadata`` creates its tables and dropped before it drops them. This does not
    create the ``set_updated_at_timestamp()`` function the trigger calls.
    """

    hooks = (
        ('after_create', CreateUpdateAtTrigger(table_name)),
        ('before_drop', DropUpdateAtTrigger(table_name)),
    )
    for event_name, ddl in hooks:
        sa.event.listen(metadata, event_name, ddl)
|
omserv/secrets.py
ADDED
@@ -0,0 +1,12 @@
|
|
1
|
+
import os.path
|
2
|
+
import typing as ta
|
3
|
+
|
4
|
+
import yaml
|
5
|
+
|
6
|
+
|
7
|
+
# Location of the secrets YAML file; the SECRETS_PATH environment variable overrides the per-user default.
SECRETS_PATH = os.environ.get('SECRETS_PATH', os.path.expanduser('~/Dropbox/.dotfiles/secrets.yml'))
|
8
|
+
|
9
|
+
|
10
|
+
def load_secrets() -> dict[str, ta.Any]:
    """Load the secrets mapping from the YAML file at ``SECRETS_PATH``.

    An empty file yields ``{}`` rather than ``None`` (``yaml.safe_load`` returns ``None`` for empty input), so the
    declared dict return type holds.

    Raises:
        OSError: if the file cannot be opened (e.g. ``FileNotFoundError``).
    """

    # Decode explicitly as UTF-8 instead of depending on the platform's locale encoding.
    with open(SECRETS_PATH, encoding='utf-8') as f:
        data = yaml.safe_load(f)

    return data if data is not None else {}
|
@@ -0,0 +1,18 @@
|
|
1
|
+
"""
|
2
|
+
Based on https://github.com/pgjones/hypercorn
|
3
|
+
|
4
|
+
TODO:
|
5
|
+
- !!! error handling jfc
|
6
|
+
- add ssl back lol
|
7
|
+
- events as dc's
|
8
|
+
- injectify
|
9
|
+
- lifecycle / otp-ify
|
10
|
+
- configify
|
11
|
+
|
12
|
+
Lookit:
|
13
|
+
- https://github.com/davidbrochart/anycorn
|
14
|
+
- https://github.com/encode/starlette
|
15
|
+
- https://github.com/tiangolo/fastapi
|
16
|
+
"""
|
17
|
+
from .config import Config # noqa
|
18
|
+
from .workers import serve # noqa
|
omserv/server/config.py
ADDED
@@ -0,0 +1,51 @@
|
|
1
|
+
import dataclasses as dc
|
2
|
+
import typing as ta
|
3
|
+
|
4
|
+
|
5
|
+
BYTES = 1
|
6
|
+
OCTETS = 1
|
7
|
+
SECONDS = 1.0
|
8
|
+
|
9
|
+
|
10
|
+
@dc.dataclass(frozen=True, kw_only=True)
|
11
|
+
class Config:
|
12
|
+
bind: ta.Sequence[str] = ('127.0.0.1:8000',)
|
13
|
+
|
14
|
+
umask: int | None = None
|
15
|
+
user: int | None = None
|
16
|
+
group: int | None = None
|
17
|
+
|
18
|
+
workers: int = 0
|
19
|
+
|
20
|
+
max_app_queue_size: int = 10
|
21
|
+
|
22
|
+
startup_timeout = 60 * SECONDS
|
23
|
+
shutdown_timeout = 60 * SECONDS
|
24
|
+
|
25
|
+
server_names: ta.Sequence[str] = ()
|
26
|
+
|
27
|
+
max_requests: int | None = None
|
28
|
+
max_requests_jitter: int = 0
|
29
|
+
|
30
|
+
backlog: int = 100
|
31
|
+
|
32
|
+
graceful_timeout: float = 3 * SECONDS
|
33
|
+
|
34
|
+
keep_alive_timeout: float = 5 * SECONDS
|
35
|
+
keep_alive_max_requests: int = 1000
|
36
|
+
|
37
|
+
read_timeout: int | None = None
|
38
|
+
|
39
|
+
h11_max_incomplete_size: int = 16 * 1024 * BYTES
|
40
|
+
h11_pass_raw_headers: bool = False
|
41
|
+
|
42
|
+
h2_max_concurrent_streams: int = 100
|
43
|
+
h2_max_header_list_size: int = 2 ** 16
|
44
|
+
h2_max_inbound_frame_size: int = 2 ** 14 * OCTETS
|
45
|
+
|
46
|
+
include_date_header: bool = True
|
47
|
+
include_server_header: bool = True
|
48
|
+
alt_svc_headers: list[str] = dc.field(default_factory=list)
|
49
|
+
|
50
|
+
websocket_max_message_size: int = 16 * 1024 * 1024 * BYTES
|
51
|
+
websocket_ping_interval: float | None = None
|
omserv/server/debug.py
ADDED
@@ -0,0 +1,14 @@
|
|
1
|
+
import logging
|
2
|
+
import sys
|
3
|
+
|
4
|
+
from omlish.diag import pydevd as pdu
|
5
|
+
|
6
|
+
|
7
|
+
log = logging.getLogger(__name__)


def handle_error_debug(e: BaseException) -> None:
    """Hand exception *e* to ``pdu.debug_unhandled_exception``, logging before and after.

    NOTE(review): whether this actually attaches/blocks on a debugger depends on ``omlish.diag.pydevd`` - confirm.
    """

    # Build exc_info from *e* itself: ``sys.exc_info()`` is only populated while an exception is being handled, so
    # calling this outside an ``except`` block would previously have passed ``(None, None, None)``.
    exc_info = (type(e), e, e.__traceback__)
    log.warning('Launching debugger')
    pdu.debug_unhandled_exception(exc_info)
    log.warning('Done debugging')
|
omserv/server/events.py
ADDED
@@ -0,0 +1,83 @@
|
|
1
|
+
import dataclasses as dc
|
2
|
+
|
3
|
+
from omlish import lang
|
4
|
+
|
5
|
+
|
6
|
+
##
|
7
|
+
|
8
|
+
|
9
|
+
class ServerEvent(lang.Abstract):
    """Abstract base for connection-level server events (raw data, closure, idle-state updates)."""
|
11
|
+
|
12
|
+
|
13
|
+
@dc.dataclass(frozen=True)
class RawData(ServerEvent):
    """A chunk of raw connection bytes, optionally tagged with a ``(str, int)`` address."""

    data: bytes
    address: tuple[str, int] | None = None  # None when no address accompanies the data
|
17
|
+
|
18
|
+
|
19
|
+
@dc.dataclass(frozen=True)
class Closed(ServerEvent):
    """Payload-free marker event for connection closure."""
|
22
|
+
|
23
|
+
|
24
|
+
@dc.dataclass(frozen=True)
class Updated(ServerEvent):
    """Event carrying the connection's current idle flag."""

    idle: bool
|
27
|
+
|
28
|
+
|
29
|
+
##
|
30
|
+
|
31
|
+
|
32
|
+
@dc.dataclass(frozen=True)
class ProtocolEvent:
    """Base for immutable per-stream protocol events; each carries the integer id of its stream."""

    stream_id: int
|
35
|
+
|
36
|
+
|
37
|
+
@dc.dataclass(frozen=True)
class Request(ProtocolEvent):
    """Start of a request on a stream: headers, HTTP version, method and the raw (bytes) path."""

    headers: list[tuple[bytes, bytes]]  # (name, value) pairs
    http_version: str
    method: str
    raw_path: bytes
|
43
|
+
|
44
|
+
|
45
|
+
@dc.dataclass(frozen=True)
class Body(ProtocolEvent):
    """A chunk of body bytes on a stream."""

    data: bytes
|
48
|
+
|
49
|
+
|
50
|
+
@dc.dataclass(frozen=True)
class EndBody(ProtocolEvent):
    """Marks the end of the body on a stream."""
|
53
|
+
|
54
|
+
|
55
|
+
@dc.dataclass(frozen=True)
class Data(ProtocolEvent):
    """A chunk of data bytes on a stream."""

    data: bytes
|
58
|
+
|
59
|
+
|
60
|
+
@dc.dataclass(frozen=True)
class EndData(ProtocolEvent):
    """Marks the end of data on a stream."""
|
63
|
+
|
64
|
+
|
65
|
+
@dc.dataclass(frozen=True)
class Response(ProtocolEvent):
    """Response status code and headers for a stream."""

    headers: list[tuple[bytes, bytes]]  # (name, value) pairs
    status_code: int
|
69
|
+
|
70
|
+
|
71
|
+
@dc.dataclass(frozen=True)
class InformationalResponse(ProtocolEvent):
    """An informational (1XX) response on a stream; any other status code is rejected at construction."""

    headers: list[tuple[bytes, bytes]]  # (name, value) pairs
    status_code: int

    def __post_init__(self) -> None:
        code = self.status_code
        # Only the 100..199 range is acceptable here.
        if code < 100 or code >= 200:
            raise ValueError(f'Status code must be 1XX not {self.status_code}')
|
79
|
+
|
80
|
+
|
81
|
+
@dc.dataclass(frozen=True)
class StreamClosed(ProtocolEvent):
    """Marks a stream as closed."""
|
omserv/server/headers.py
ADDED
@@ -0,0 +1,36 @@
|
|
1
|
+
import time
|
2
|
+
import wsgiref.handlers
|
3
|
+
|
4
|
+
from .config import Config
|
5
|
+
|
6
|
+
|
7
|
+
def _now() -> float:
|
8
|
+
return time.time()
|
9
|
+
|
10
|
+
|
11
|
+
def response_headers(config: Config, protocol: str) -> list[tuple[bytes, bytes]]:
    """Build the server-supplied response headers, in order ``date``, ``server``, then any ``alt-svc``.

    ``date`` and ``server`` are emitted only when enabled in *config*; the server header value is
    ``omlicorn-<protocol>``. Returns a fresh list of ``(name, value)`` byte pairs.
    """

    result: list[tuple[bytes, bytes]] = []

    if config.include_date_header:
        stamp = wsgiref.handlers.format_date_time(_now())
        result.append((b'date', stamp.encode('ascii')))

    if config.include_server_header:
        result.append((b'server', f'omlicorn-{protocol}'.encode('ascii')))

    result.extend((b'alt-svc', value.encode()) for value in config.alt_svc_headers)
    return result
|
22
|
+
|
23
|
+
|
24
|
+
def filter_pseudo_headers(headers: list[tuple[bytes, bytes]]) -> list[tuple[bytes, bytes]]:
    """Drop pseudo-headers (names starting with ``:``) and put a single ``host`` header first.

    The ``host`` value is the last ``:authority`` seen if any, otherwise the last ``host`` header, otherwise ``b''``.
    All other headers keep their relative order.
    """

    authority: bytes | None = None
    host = b''
    regular: list[tuple[bytes, bytes]] = []

    for header_name, header_value in headers:
        if header_name == b':authority':  # h2 & h3 libraries validate this is present
            authority = header_value
        elif header_name == b'host':
            host = header_value
        elif header_name[0] != ord(':'):
            regular.append((header_name, header_value))

    chosen_host = host if authority is None else authority
    return [(b'host', chosen_host), *regular]
|
@@ -0,0 +1,132 @@
|
|
1
|
+
import logging
|
2
|
+
import typing as ta
|
3
|
+
|
4
|
+
import anyio
|
5
|
+
import anyio.abc
|
6
|
+
import anyio.from_thread
|
7
|
+
import anyio.to_thread
|
8
|
+
|
9
|
+
from .config import Config
|
10
|
+
from .debug import handle_error_debug
|
11
|
+
from .types import AppWrapper
|
12
|
+
from .types import AsgiReceiveEvent
|
13
|
+
from .types import AsgiSendEvent
|
14
|
+
from .types import LifespanScope
|
15
|
+
from .types import UnexpectedMessageError
|
16
|
+
|
17
|
+
|
18
|
+
log: logging.Logger = logging.getLogger(__name__)  # module-level logger for lifespan handling
|
19
|
+
|
20
|
+
|
21
|
+
class LifespanTimeoutError(Exception):
    """Raised when a lifespan stage (``startup`` / ``shutdown``) is not completed within its configured timeout."""

    def __init__(self, stage: str) -> None:
        message = (
            f'Timeout whilst awaiting {stage}. Your application may not support the Asgi Lifespan '
            f'protocol correctly, alternatively the {stage}_timeout configuration is incorrect.'
        )
        super().__init__(message)
|
27
|
+
|
28
|
+
|
29
|
+
class LifespanFailureError(Exception):
    """Raised when the application reports ``lifespan.<stage>.failed``; carries the application's message."""

    def __init__(self, stage: str, message: str) -> None:
        text = f"Lifespan failure in {stage}. '{message}'"
        super().__init__(text)
|
32
|
+
|
33
|
+
|
34
|
+
class Lifespan:
    """Run an Asgi application's ``lifespan`` scope and coordinate its startup/shutdown handshake.

    ``handle_lifespan`` invokes the application with a ``lifespan`` scope. ``wait_for_startup`` and
    ``wait_for_shutdown`` push ``lifespan.startup`` / ``lifespan.shutdown`` onto an in-memory channel, which the
    application reads via ``asgi_receive``. They then wait, bounded by the configured timeouts, for the corresponding
    event to be set, either by a ``*.complete`` message through ``asgi_send`` or by ``handle_lifespan`` exiting.
    """

    def __init__(self, app: AppWrapper, config: Config) -> None:
        super().__init__()

        self.app = app
        self.config = config

        # Set by ``asgi_send`` on ``*.complete`` messages, and unconditionally when ``handle_lifespan`` finishes.
        self.startup = anyio.Event()
        self.shutdown = anyio.Event()

        # Server -> application lifespan events, buffered up to ``config.max_app_queue_size``.
        self.app_send_channel, self.app_receive_channel = anyio.create_memory_object_stream[ta.Any](
            config.max_app_queue_size,
        )

        # Cleared when the application errors out of the lifespan scope; the ``wait_for_*`` methods then return
        # immediately.
        self.supported = True

    async def handle_lifespan(
            self,
            *,
            task_status: anyio.abc.TaskStatus[ta.Any] = anyio.TASK_STATUS_IGNORED,
    ) -> None:
        """Run the application's lifespan scope until it returns or fails.

        ``LifespanFailureError`` and cancellation are re-raised, including when wrapped in an exception group. Any
        other ``Exception`` or exception group is logged and marks lifespan as unsupported. On exit both events are
        set and both channel ends closed, so waiters are always released.
        """

        task_status.started()
        scope: LifespanScope = {
            'type': 'lifespan',
            'asgi': {'spec_version': '2.0', 'version': '3.0'},
        }

        try:
            await self.app(
                scope,
                self.asgi_receive,
                self.asgi_send,
                anyio.to_thread.run_sync,
                anyio.from_thread.run,
            )

        except (LifespanFailureError, anyio.get_cancelled_exc_class()):
            raise

        except (BaseExceptionGroup, Exception) as error:
            handle_error_debug(error)

            if isinstance(error, BaseExceptionGroup):
                failure_error = error.subgroup(LifespanFailureError)
                if failure_error is not None:
                    # Lifespan failures should crash the server
                    raise failure_error  # noqa

                reraise_error = error.subgroup((LifespanFailureError, anyio.get_cancelled_exc_class()))
                if reraise_error is not None:
                    raise reraise_error  # noqa

            self.supported = False
            if not self.startup.is_set():
                log.warning('Asgi Framework Lifespan error, continuing without Lifespan support')
            elif not self.shutdown.is_set():
                log.exception('Asgi Framework Lifespan error, shutdown without Lifespan support')
            else:
                log.exception('Asgi Framework Lifespan errored after shutdown.')

        finally:
            self.startup.set()
            self.shutdown.set()
            await self.app_send_channel.aclose()
            await self.app_receive_channel.aclose()

    async def wait_for_startup(self) -> None:
        """Send ``lifespan.startup`` and wait for the startup event; a no-op when lifespan is unsupported.

        Raises:
            LifespanTimeoutError: if the event is not set within ``config.startup_timeout``.
        """

        if not self.supported:
            return

        await self.app_send_channel.send({'type': 'lifespan.startup'})
        try:
            with anyio.fail_after(self.config.startup_timeout):
                await self.startup.wait()
        except TimeoutError as error:
            raise LifespanTimeoutError('startup') from error

    async def wait_for_shutdown(self) -> None:
        """Send ``lifespan.shutdown`` and wait for the shutdown event; a no-op when lifespan is unsupported.

        Raises:
            LifespanTimeoutError: if the event is not set within ``config.shutdown_timeout``.
        """

        if not self.supported:
            return

        await self.app_send_channel.send({'type': 'lifespan.shutdown'})
        try:
            with anyio.fail_after(self.config.shutdown_timeout):
                await self.shutdown.wait()
        except TimeoutError as error:
            # Previously reported as 'startup', pointing users at the wrong stage and timeout setting.
            raise LifespanTimeoutError('shutdown') from error

    async def asgi_receive(self) -> AsgiReceiveEvent:
        """Asgi ``receive`` callable: return the next lifespan event queued for the application."""

        return await self.app_receive_channel.receive()

    async def asgi_send(self, message: AsgiSendEvent) -> None:
        """Asgi ``send`` callable: handle a lifespan message emitted by the application.

        Raises:
            LifespanFailureError: for ``lifespan.startup.failed`` / ``lifespan.shutdown.failed``.
            UnexpectedMessageError: for any other unrecognized message type.
        """

        if message['type'] == 'lifespan.startup.complete':
            self.startup.set()

        elif message['type'] == 'lifespan.shutdown.complete':
            self.shutdown.set()

        elif message['type'] == 'lifespan.startup.failed':
            raise LifespanFailureError('startup', message.get('message', ''))

        elif message['type'] == 'lifespan.shutdown.failed':
            raise LifespanFailureError('shutdown', message.get('message', ''))

        else:
            raise UnexpectedMessageError(message['type'])
|
@@ -0,0 +1,157 @@
|
|
1
|
+
import functools
|
2
|
+
import multiprocessing
|
3
|
+
import multiprocessing.connection
|
4
|
+
import multiprocessing.context
|
5
|
+
import multiprocessing.synchronize
|
6
|
+
import pickle
|
7
|
+
import platform
|
8
|
+
import signal
|
9
|
+
import time
|
10
|
+
import typing as ta
|
11
|
+
|
12
|
+
import anyio
|
13
|
+
|
14
|
+
from .config import Config
|
15
|
+
from .sockets import Sockets
|
16
|
+
from .sockets import create_sockets
|
17
|
+
from .types import AsgiFramework
|
18
|
+
from .types import wrap_app
|
19
|
+
from .workers import serve
|
20
|
+
|
21
|
+
|
22
|
+
async def check_multiprocess_shutdown_event(
        shutdown_event: multiprocessing.synchronize.Event,
        sleep: ta.Callable[[float], ta.Awaitable[ta.Any]],
) -> None:
    """Poll *shutdown_event* every 0.1s using the supplied async *sleep*, returning once it is set."""

    while not shutdown_event.is_set():
        await sleep(0.1)
|
30
|
+
|
31
|
+
|
32
|
+
def _multiprocess_serve(
        app: AsgiFramework,
        config: Config,
        sockets: Sockets | None = None,
        shutdown_event: multiprocessing.synchronize.Event | None = None,
) -> None:
    """Worker-process entry point: listen on the given sockets and run ``serve`` under ``anyio.run``.

    When *shutdown_event* is given, serving stops once it is set (polled by ``check_multiprocess_shutdown_event``).
    """

    if sockets is not None:
        # Begin listening on the passed-in sockets with the configured backlog.
        for sock in sockets.insecure_sockets:
            sock.listen(config.backlog)

    shutdown_trigger = (
        functools.partial(check_multiprocess_shutdown_event, shutdown_event, anyio.sleep)
        if shutdown_event is not None
        else None
    )

    serve_main = functools.partial(
        serve,
        wrap_app(app),
        config,
        sockets=sockets,
        shutdown_trigger=shutdown_trigger,
    )
    anyio.run(serve_main)  # alternative: backend='trio'
|
56
|
+
|
57
|
+
|
58
|
+
def serve_multiprocess(
        app: AsgiFramework,
        config: Config,
) -> int:
    """Serve *app* from spawned worker processes that share one set of sockets.

    Workers are (re)spawned up to the configured count via ``_populate``. SIGINT/SIGTERM/SIGBREAK in this process, or
    a worker exiting non-zero, sets a shared shutdown event and ends the supervision loop. Remaining workers are then
    terminated and the sockets closed. This also happens if spawning or waiting raises.

    Returns 0 on clean shutdown, otherwise the non-zero exit code of the worker that triggered shutdown.
    """

    sockets = create_sockets(config)

    exitcode = 0
    ctx = multiprocessing.get_context('spawn')

    active = True
    shutdown_event = ctx.Event()

    def shutdown(*args: ta.Any) -> None:
        nonlocal active
        shutdown_event.set()
        active = False

    processes: list[multiprocessing.Process] = []
    try:
        while active:
            # Ignore SIGINT before creating the processes, so that they inherit the signal handling. This means that
            # the shutdown function controls the shutdown.
            signal.signal(signal.SIGINT, signal.SIG_IGN)

            _populate(
                processes,
                app,
                config,
                _multiprocess_serve,
                sockets,
                shutdown_event,
                ctx,
            )

            for signal_name in ('SIGINT', 'SIGTERM', 'SIGBREAK'):
                if hasattr(signal, signal_name):
                    signal.signal(getattr(signal, signal_name), shutdown)

            # Block until at least one worker has exited.
            multiprocessing.connection.wait(process.sentinel for process in processes)

            exitcode = _join_exited(processes)
            if exitcode != 0:
                shutdown_event.set()
                active = False

    finally:
        for process in processes:
            process.terminate()

        if exitcode != 0:
            # Reap workers that have already exited, but keep the failure code that triggered shutdown: the workers
            # just terminated usually have not exited yet and are skipped, so ``_join_exited`` would typically return
            # 0 and previously masked the failure.
            _join_exited(processes)

        for sock in sockets.insecure_sockets:
            sock.close()

    return exitcode
|
111
|
+
|
112
|
+
|
113
|
+
def _populate(
        processes: list[multiprocessing.Process],
        app: AsgiFramework,
        config: Config,
        worker_func: ta.Callable,
        sockets: Sockets,
        shutdown_event: multiprocessing.synchronize.Event,
        ctx: multiprocessing.context.BaseContext,
) -> None:
    """Start daemon worker processes until *processes* holds the target worker count, appending each one.

    The target is ``config.workers``: 0 means 1, and a negative value means ``multiprocessing.cpu_count()``. Each
    worker runs ``worker_func(app=..., config=..., shutdown_event=..., sockets=...)``.

    Raises:
        RuntimeError: if ``process.start()`` raises ``pickle.PicklingError`` (arguments not picklable for spawn).
    """

    target = config.workers or 1
    if target < 0:
        target = multiprocessing.cpu_count()

    for _ in range(target - len(processes)):
        worker = ctx.Process(  # type: ignore
            target=worker_func,
            kwargs=dict(
                app=app,
                config=config,
                shutdown_event=shutdown_event,
                sockets=sockets,
            ),
        )
        worker.daemon = True

        try:
            worker.start()
        except pickle.PicklingError as error:
            raise RuntimeError(
                'Cannot pickle the config, see https://docs.python.org/3/library/pickle.html#pickle-picklable',  # noqa: E501
            ) from error

        processes.append(worker)

        # Stagger worker start-up on Windows.
        if platform.system() == 'Windows':
            time.sleep(0.1)
|
146
|
+
|
147
|
+
|
148
|
+
def _join_exited(processes: list[multiprocessing.Process]) -> int:
|
149
|
+
exitcode = 0
|
150
|
+
for index in reversed(range(len(processes))):
|
151
|
+
worker = processes[index]
|
152
|
+
if worker.exitcode is not None:
|
153
|
+
worker.join()
|
154
|
+
exitcode = worker.exitcode if exitcode == 0 else exitcode
|
155
|
+
del processes[index]
|
156
|
+
|
157
|
+
return exitcode
|
@@ -0,0 +1 @@
|
|
1
|
+
from .protocols import ProtocolWrapper # noqa
|