mrok 0.4.6__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. mrok/agent/devtools/inspector/app.py +2 -2
  2. mrok/agent/sidecar/app.py +61 -35
  3. mrok/agent/sidecar/main.py +35 -9
  4. mrok/agent/ziticorn.py +9 -3
  5. mrok/cli/commands/__init__.py +2 -2
  6. mrok/cli/commands/admin/bootstrap.py +3 -2
  7. mrok/cli/commands/admin/utils.py +2 -2
  8. mrok/cli/commands/agent/run/sidecar.py +59 -1
  9. mrok/cli/commands/{proxy → frontend}/__init__.py +1 -1
  10. mrok/cli/commands/frontend/run.py +91 -0
  11. mrok/constants.py +0 -2
  12. mrok/controller/openapi/examples.py +13 -0
  13. mrok/controller/schemas.py +2 -2
  14. mrok/frontend/__init__.py +3 -0
  15. mrok/frontend/app.py +75 -0
  16. mrok/{proxy → frontend}/main.py +12 -10
  17. mrok/proxy/__init__.py +0 -3
  18. mrok/proxy/app.py +158 -83
  19. mrok/proxy/asgi.py +96 -0
  20. mrok/proxy/backend.py +45 -0
  21. mrok/proxy/event_publisher.py +66 -0
  22. mrok/proxy/exceptions.py +22 -0
  23. mrok/{master.py → proxy/master.py} +36 -81
  24. mrok/{metrics.py → proxy/metrics.py} +38 -50
  25. mrok/{http/middlewares.py → proxy/middleware.py} +17 -26
  26. mrok/{datastructures.py → proxy/models.py} +43 -10
  27. mrok/proxy/stream.py +68 -0
  28. mrok/{http → proxy}/utils.py +1 -1
  29. mrok/proxy/worker.py +64 -0
  30. mrok/{http/config.py → proxy/ziticorn.py} +29 -6
  31. mrok/types/proxy.py +20 -0
  32. mrok/types/ziti.py +1 -0
  33. mrok/ziti/api.py +15 -18
  34. mrok/ziti/bootstrap.py +3 -2
  35. mrok/ziti/identities.py +5 -4
  36. mrok/ziti/services.py +3 -2
  37. {mrok-0.4.6.dist-info → mrok-0.6.0.dist-info}/METADATA +2 -5
  38. {mrok-0.4.6.dist-info → mrok-0.6.0.dist-info}/RECORD +43 -39
  39. mrok/cli/commands/proxy/run.py +0 -49
  40. mrok/http/forwarder.py +0 -354
  41. mrok/http/lifespan.py +0 -39
  42. mrok/http/pool.py +0 -239
  43. mrok/http/protocol.py +0 -11
  44. mrok/http/server.py +0 -14
  45. mrok/http/types.py +0 -18
  46. /mrok/{http → proxy}/constants.py +0 -0
  47. /mrok/{http → types}/__init__.py +0 -0
  48. {mrok-0.4.6.dist-info → mrok-0.6.0.dist-info}/WHEEL +0 -0
  49. {mrok-0.4.6.dist-info → mrok-0.6.0.dist-info}/entry_points.txt +0 -0
  50. {mrok-0.4.6.dist-info → mrok-0.6.0.dist-info}/licenses/LICENSE.txt +0 -0
mrok/proxy/app.py CHANGED
@@ -1,101 +1,176 @@
- import asyncio
+ import abc
  import logging
- from collections.abc import AsyncGenerator
- from contextlib import asynccontextmanager
- from pathlib import Path

- import openziti
- from openziti.context import ZitiContext
+ from httpcore import AsyncConnectionPool, Request

- from mrok.conf import get_settings
- from mrok.constants import RE_SUBDOMAIN
- from mrok.http.forwarder import BackendUnavailableError, ForwardAppBase, InvalidBackendError
- from mrok.http.pool import ConnectionPool, PoolManager
- from mrok.http.types import Scope, StreamPair
- from mrok.logging import setup_logging
+ from mrok.proxy.exceptions import ProxyError
+ from mrok.proxy.stream import ASGIRequestBodyStream
+ from mrok.types.proxy import ASGIReceive, ASGISend, Scope

  logger = logging.getLogger("mrok.proxy")


- class ProxyError(Exception):
-     pass
+ HOP_BY_HOP_HEADERS = [
+     b"connection",
+     b"keep-alive",
+     b"proxy-authenticate",
+     b"proxy-authorization",
+     b"te",
+     b"trailers",
+     b"transfer-encoding",
+     b"upgrade",
+ ]


- class ProxyApp(ForwardAppBase):
+ class ProxyAppBase(abc.ABC):
      def __init__(
          self,
-         identity_file: str | Path,
          *,
-         read_chunk_size: int = 65536,
+         max_connections: int | None = 10,
+         max_keepalive_connections: int | None = None,
+         keepalive_expiry: float | None = None,
+         retries: int = 0,
      ) -> None:
-         super().__init__(read_chunk_size=read_chunk_size)
-         self._identity_file = identity_file
-         settings = get_settings()
-         self._proxy_wildcard_domain = (
-             settings.proxy.domain
-             if settings.proxy.domain[0] == "."
-             else f".{settings.proxy.domain}"
-         )
-         self._ziti_ctx: ZitiContext | None = None
-         self._pool_manager = PoolManager(self.build_connection_pool)
-
-     def get_target_from_header(self, headers: dict[str, str], name: str) -> str | None:
-         header_value = headers.get(name, "")
-         if self._proxy_wildcard_domain in header_value:
-             if ":" in header_value:
-                 header_value, _ = header_value.split(":", 1)
-             return header_value[: -len(self._proxy_wildcard_domain)]
-
-     def get_target_name(self, headers: dict[str, str]) -> str:
-         target = self.get_target_from_header(headers, "x-forwarded-host")
-         if not target:
-             target = self.get_target_from_header(headers, "host")
-         if not target:
-             raise ProxyError("Neither Host nor X-Forwarded-Host contain a valid target name")
-         return target
-
-     def _get_ziti_ctx(self) -> ZitiContext:
-         if self._ziti_ctx is None:
-             ctx, err = openziti.load(str(self._identity_file), timeout=10_000)
-             if err != 0:
-                 raise Exception(f"Cannot create a Ziti context from the identity file: {err}")
-             self._ziti_ctx = ctx
-         return self._ziti_ctx
-
-     async def startup(self):
-         setup_logging(get_settings())
-         self._get_ziti_ctx()
-
-     async def shutdown(self):
-         await self._pool_manager.shutdown()
-
-     async def build_connection_pool(self, key: str) -> ConnectionPool:
-         async def connect():
-             sock = self._get_ziti_ctx().connect(key)
-             reader, writer = await asyncio.open_connection(sock=sock)
-             return reader, writer
-
-         return ConnectionPool(
-             pool_name=key,
-             factory=connect,
-             initial_connections=5,
-             max_size=100,
-             idle_timeout=20.0,
-             reaper_interval=5.0,
+         self._pool = self.setup_connection_pool(
+             max_connections=max_connections,
+             max_keepalive_connections=max_keepalive_connections,
+             keepalive_expiry=keepalive_expiry,
+             retries=retries,
          )

-     @asynccontextmanager
-     async def select_backend(
+     @abc.abstractmethod
+     def setup_connection_pool(
          self,
-         scope: Scope,
-         headers: dict[str, str],
-     ) -> AsyncGenerator[StreamPair]:
-         target_name = self.get_target_name(headers)
-         if not target_name or not RE_SUBDOMAIN.fullmatch(target_name):
-             raise InvalidBackendError()
-         pool = await self._pool_manager.get_pool(target_name)
+         max_connections: int | None,
+         max_keepalive_connections: int | None,
+         keepalive_expiry: float | None,
+         retries: int,
+     ) -> AsyncConnectionPool:
+         raise NotImplementedError()
+
+     @abc.abstractmethod
+     def get_upstream_base_url(self, scope: Scope) -> str:
+         raise NotImplementedError()
+
+     async def __call__(self, scope: Scope, receive: ASGIReceive, send: ASGISend) -> None:
+         if scope.get("type") == "lifespan":
+             return
+
+         if scope.get("type") != "http":
+             await self._send_error(send, 500, "Unsupported")
+             return
+
          try:
-             async with pool.acquire() as (reader, writer):
-                 yield reader, writer
+             base_url = self.get_upstream_base_url(scope)
+             if base_url.endswith("/"):  # pragma: no cover
+                 base_url = base_url[:-1]
+             full_path = self._format_path(scope)
+             url = f"{base_url}{full_path}"
+             method = scope.get("method", "GET").encode()
+             headers = self._prepare_headers(scope)
+
+             body_stream = ASGIRequestBodyStream(receive)
+
+             request = Request(
+                 method=method,
+                 url=url,
+                 headers=headers,
+                 content=body_stream,
+             )
+             response = await self._pool.handle_async_request(request)
+             logger.debug(f"connection pool status: {self._pool}")
+             response_headers = []
+             for k, v in response.headers:
+                 if k.lower() not in HOP_BY_HOP_HEADERS:
+                     response_headers.append((k, v))
+
+             await send(
+                 {
+                     "type": "http.response.start",
+                     "status": response.status,
+                     "headers": response_headers,
+                 }
+             )
+
+             async for chunk in response.stream:  # type: ignore[union-attr]
+                 await send(
+                     {
+                         "type": "http.response.body",
+                         "body": chunk,
+                         "more_body": True,
+                     }
+                 )
+
+             await send({"type": "http.response.body", "body": b"", "more_body": False})
+             await response.aclose()
+
+         except ProxyError as pe:
+             await self._send_error(send, pe.http_status, pe.message)
+
          except Exception:
-             raise BackendUnavailableError()
+             logger.exception("Unexpected error in forwarder")
+             await self._send_error(send, 502, "Bad Gateway")
+
+     async def _send_error(self, send: ASGISend, http_status: int, body: str):
+         try:
+             await send({"type": "http.response.start", "status": http_status, "headers": []})
+             await send({"type": "http.response.body", "body": body.encode()})
+         except Exception as e:  # pragma: no cover
+             logger.error(f"Cannot send error response: {e}")
+
+     def _prepare_headers(self, scope: Scope) -> list[tuple[bytes, bytes]]:
+         headers: list[tuple[bytes, bytes]] = []
+         scope_headers = scope.get("headers", [])
+
+         for k, v in scope_headers:
+             if k.lower() not in HOP_BY_HOP_HEADERS:
+                 headers.append((k, v))
+
+         self._merge_x_forwarded(headers, scope)
+
+         return headers
+
+     def _find_header(self, headers: list[tuple[bytes, bytes]], name: bytes) -> int | None:
+         """Return index of header `name` in `headers`, or None if missing."""
+         lname = name.lower()
+         for i, (k, _) in enumerate(headers):
+             if k.lower() == lname:
+                 return i
+         return None
+
+     def _merge_x_forwarded(self, headers: list[tuple[bytes, bytes]], scope: Scope) -> None:
+         client = scope.get("client")
+         if client:
+             client_ip = client[0].encode()
+             idx = self._find_header(headers, b"x-forwarded-for")
+             if idx is None:
+                 headers.append((b"x-forwarded-for", client_ip))
+             else:
+                 k, v = headers[idx]
+                 headers[idx] = (k, v + b", " + client_ip)
+
+         server = scope.get("server")
+         if server:
+             if self._find_header(headers, b"x-forwarded-host") is None:
+                 headers.append((b"x-forwarded-host", server[0].encode()))
+             if server[1] and self._find_header(headers, b"x-forwarded-port") is None:
+                 headers.append((b"x-forwarded-port", str(server[1]).encode()))
+
+         # Always set the protocol to https for upstream
+         idx_proto = self._find_header(headers, b"x-forwarded-proto")
+         if idx_proto is None:
+             headers.append((b"x-forwarded-proto", b"https"))
+         else:
+             k, _ = headers[idx_proto]
+             headers[idx_proto] = (k, b"https")
+
+     def _format_path(self, scope: Scope) -> str:
+         raw_path = scope.get("raw_path")
+         if raw_path:
+             return raw_path.decode()
+         q = scope.get("query_string", b"")
+         path = scope.get("path", "/")
+         path_qs = path
+         if q:
+             path_qs += "?" + q.decode()
+         return path_qs
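
For orientation, the rewritten ProxyAppBase leaves only two hooks to a concrete proxy: building the httpcore pool and resolving the upstream base URL per request. A minimal sketch of a subclass, assuming a plain pool and a fixed upstream (the StaticUpstreamProxy name and the URL are illustrative, not part of the package):

from httpcore import AsyncConnectionPool

from mrok.proxy.app import ProxyAppBase
from mrok.types.proxy import Scope


class StaticUpstreamProxy(ProxyAppBase):
    """Hypothetical subclass that forwards every request to a single upstream."""

    def setup_connection_pool(
        self,
        max_connections: int | None,
        max_keepalive_connections: int | None,
        keepalive_expiry: float | None,
        retries: int,
    ) -> AsyncConnectionPool:
        # Plain httpcore pool; mrok's real subclasses can plug in a custom network backend here.
        return AsyncConnectionPool(
            max_connections=max_connections,
            max_keepalive_connections=max_keepalive_connections,
            keepalive_expiry=keepalive_expiry,
            retries=retries,
        )

    def get_upstream_base_url(self, scope: Scope) -> str:
        # Illustrative fixed target; a real implementation would derive it from the scope.
        return "http://127.0.0.1:8000"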
mrok/proxy/asgi.py ADDED
@@ -0,0 +1,96 @@
+ from collections.abc import Iterator
+ from contextlib import AsyncExitStack, asynccontextmanager
+ from typing import Any, ParamSpec, Protocol
+
+ from mrok.types.proxy import ASGIApp, ASGIReceive, ASGISend, Lifespan, Scope
+
+ P = ParamSpec("P")
+
+
+ class ASGIMiddleware(Protocol[P]):
+     def __call__(
+         self, app: ASGIApp, /, *args: P.args, **kwargs: P.kwargs
+     ) -> ASGIApp: ...  # pragma: no cover
+
+
+ class Middleware:
+     def __init__(self, cls: ASGIMiddleware[P], *args: P.args, **kwargs: P.kwargs) -> None:
+         self.cls = cls
+         self.args = args
+         self.kwargs = kwargs
+
+     def __iter__(self) -> Iterator[Any]:
+         as_tuple = (self.cls, self.args, self.kwargs)
+         return iter(as_tuple)
+
+
+ class ASGIAppWrapper:
+     def __init__(
+         self,
+         app: ASGIApp,
+         lifespan: Lifespan | None = None,
+     ) -> None:
+         self.app = app
+         self.lifespan = lifespan
+         self.middlware: list[Middleware] = []
+         self.middleare_stack: ASGIApp | None = None
+
+     def add_middleware(self, cls: ASGIMiddleware[P], *args: P.args, **kwargs: P.kwargs):
+         self.middlware.insert(0, Middleware(cls, *args, **kwargs))
+
+     def build_middleware_stack(self):
+         app = self.app
+         for cls, args, kwargs in reversed(self.middlware):
+             app = cls(app, *args, **kwargs)
+         return app
+
+     def get_starlette_lifespan(self):
+         router = getattr(self.app, "router", None)
+         if router is None:
+             return None
+         return getattr(router, "lifespan_context", None)
+
+     @asynccontextmanager
+     async def merge_lifespan(self, app: ASGIApp):
+         async with AsyncExitStack() as stack:
+             state: dict[Any, Any] = {}
+             if self.lifespan is not None:
+                 outer_state = await stack.enter_async_context(self.lifespan(app))
+                 state.update(outer_state or {})
+             starlette_lifespan = self.get_starlette_lifespan()
+             if starlette_lifespan is not None:
+                 inner_state = await stack.enter_async_context(starlette_lifespan(app))
+                 state.update(inner_state or {})
+             yield state
+
+     async def handle_lifespan(self, scope: Scope, receive: ASGIReceive, send: ASGISend) -> None:
+         started = False
+         app: Any = scope.get("app")
+         await receive()
+         try:
+             async with self.merge_lifespan(app) as state:
+                 if state:
+                     if "state" not in scope:
+                         raise RuntimeError('"state" is unsupported by the current ASGI Server.')
+                     scope["state"].update(state)
+                 await send({"type": "lifespan.startup.complete"})
+                 started = True
+                 await receive()
+         except Exception as e:  # pragma: no cover
+             if started:
+                 await send({"type": "lifespan.shutdown.failed", "message": str(e)})
+             else:
+                 await send({"type": "lifespan.startup.failed", "message": str(e)})
+             raise
+         else:
+             await send({"type": "lifespan.shutdown.complete"})
+
+     async def __call__(self, scope: Scope, receive: ASGIReceive, send: ASGISend) -> None:
+         if self.middleare_stack is None:  # pragma: no branch
+             self.middleware_stack = self.build_middleware_stack()
+         if scope["type"] == "lifespan":
+             scope["app"] = self
+             await self.handle_lifespan(scope, receive, send)
+             return
+
+         await self.middleware_stack(scope, receive, send)
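
A short usage sketch of the wrapper (the demo_app and demo_lifespan names are placeholders, not part of the package); the middleware stack is built lazily on the first call, and the wrapper merges its own lifespan with any Starlette lifespan found on the wrapped app:

import contextlib

from mrok.proxy.asgi import ASGIAppWrapper


async def demo_app(scope, receive, send):
    # Placeholder ASGI application answering every HTTP request with 200 OK.
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": b"ok"})


@contextlib.asynccontextmanager
async def demo_lifespan(app):
    # Outer lifespan; anything yielded here is merged into the ASGI lifespan state.
    yield {"ready": True}


wrapper = ASGIAppWrapper(demo_app, lifespan=demo_lifespan)
# Additional ASGI middleware can be registered before serving:
# wrapper.add_middleware(SomeMiddleware, some_argument)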
mrok/proxy/backend.py ADDED
@@ -0,0 +1,45 @@
+ import asyncio
+ from collections.abc import Iterable
+ from pathlib import Path
+
+ import openziti
+ from httpcore import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
+ from openziti.context import ZitiContext
+
+ from mrok.proxy.exceptions import InvalidTargetError, TargetUnavailableError
+ from mrok.proxy.stream import AIONetworkStream
+
+
+ class AIOZitiNetworkBackend(AsyncNetworkBackend):
+     def __init__(self, identity_file: str | Path) -> None:
+         self._identity_file = identity_file
+         self._ziti_ctx: ZitiContext | None = None
+
+     def _get_ziti_ctx(self) -> ZitiContext:
+         if self._ziti_ctx is None:
+             ctx, err = openziti.load(str(self._identity_file), timeout=10_000)
+             if err != 0:
+                 raise Exception(f"Cannot create a Ziti context from the identity file: {err}")
+             self._ziti_ctx = ctx
+         return self._ziti_ctx
+
+     async def connect_tcp(
+         self,
+         host: str,
+         port: int,
+         timeout: float | None = None,
+         local_address: str | None = None,
+         socket_options: Iterable[SOCKET_OPTION] | None = None,
+     ) -> AsyncNetworkStream:
+         ctx = self._get_ziti_ctx()
+         try:
+             sock = ctx.connect(host)
+             reader, writer = await asyncio.open_connection(sock=sock)
+             return AIONetworkStream(reader, writer)
+         except Exception as e:
+             if e.args and e.args[0] == -24:  # the service exists but is not available
+                 raise TargetUnavailableError() from e
+             raise InvalidTargetError() from e
+
+     async def sleep(self, seconds: float) -> None:
+         await asyncio.sleep(seconds)
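
This backend slots into httpcore through the standard network_backend argument of AsyncConnectionPool; the host part of the request URL is what gets dialed as a Ziti service via ZitiContext.connect(). A sketch, assuming an identity file at ./identity.json and a service named my-extension (both placeholders):

from httpcore import AsyncConnectionPool

from mrok.proxy.backend import AIOZitiNetworkBackend

# Route the pool's outgoing connections through the Ziti overlay instead of plain TCP.
pool = AsyncConnectionPool(
    network_backend=AIOZitiNetworkBackend("./identity.json"),
    max_connections=10,
)

# The URL host ("my-extension") is handed to connect_tcp() and dialed as a Ziti service.
# response = await pool.request("GET", "http://my-extension/health")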
mrok/proxy/event_publisher.py ADDED
@@ -0,0 +1,66 @@
+ import asyncio
+ import contextlib
+ import logging
+
+ import zmq
+ import zmq.asyncio
+
+ from mrok.proxy.asgi import ASGIAppWrapper
+ from mrok.proxy.metrics import MetricsCollector
+ from mrok.proxy.middleware import CaptureMiddleware, MetricsMiddleware
+ from mrok.proxy.models import Event, HTTPResponse, ServiceMetadata, Status
+ from mrok.types.proxy import ASGIApp
+
+ logger = logging.getLogger("mrok.proxy")
+
+
+ class EventPublisher:
+     def __init__(
+         self,
+         worker_id: str,
+         meta: ServiceMetadata | None = None,
+         event_publisher_port: int = 50000,
+         metrics_interval: float = 5.0,
+     ):
+         self._worker_id = worker_id
+         self._meta = meta
+         self._metrics_interval = metrics_interval
+         self.publisher_port = event_publisher_port
+         self._zmq_ctx = zmq.asyncio.Context()
+         self._publisher = self._zmq_ctx.socket(zmq.PUB)
+         self._metrics_collector = MetricsCollector(self._worker_id)
+         self._publish_task = None
+
+     async def on_startup(self):
+         self._publisher.connect(f"tcp://localhost:{self.publisher_port}")
+         self._publish_task = asyncio.create_task(self.publish_metrics_event())
+         logger.info(f"Events publishing for worker {self._worker_id} started")
+
+     async def on_shutdown(self):
+         self._publish_task.cancel()
+         with contextlib.suppress(asyncio.CancelledError):
+             await self._publish_task
+         self._publisher.close()
+         self._zmq_ctx.term()
+         logger.info(f"Events publishing for worker {self._worker_id} stopped")
+
+     async def publish_metrics_event(self):
+         while True:
+             snap = await self._metrics_collector.snapshot()
+             event = Event(type="status", data=Status(meta=self._meta, metrics=snap))
+             await self._publisher.send_string(event.model_dump_json())
+             await asyncio.sleep(self._metrics_interval)
+
+     async def publish_response_event(self, response: HTTPResponse):
+         event = Event(type="response", data=response)
+         await self._publisher.send_string(event.model_dump_json())  # type: ignore[attr-defined]
+
+     def setup_middleware(self, app: ASGIAppWrapper):
+         app.add_middleware(CaptureMiddleware, self.publish_response_event)
+         app.add_middleware(MetricsMiddleware, self._metrics_collector)  # type: ignore
+
+     @contextlib.asynccontextmanager
+     async def lifespan(self, app: ASGIApp):
+         await self.on_startup()  # type: ignore
+         yield
+         await self.on_shutdown()  # type: ignore
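
A wiring sketch (the demo_app name is a placeholder): the publisher's lifespan starts and stops the periodic metrics loop, while setup_middleware installs the capture and metrics middleware that feed the ZeroMQ PUB socket.

from mrok.proxy.asgi import ASGIAppWrapper
from mrok.proxy.event_publisher import EventPublisher


async def demo_app(scope, receive, send):
    # Placeholder ASGI application.
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": b"ok"})


publisher = EventPublisher("worker-1", event_publisher_port=50000, metrics_interval=5.0)

# Periodic "status" events come from the lifespan task; one "response" event is published
# per completed HTTP exchange by the capture middleware.
app = ASGIAppWrapper(demo_app, lifespan=publisher.lifespan)
publisher.setup_middleware(app)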
mrok/proxy/exceptions.py ADDED
@@ -0,0 +1,22 @@
+ from http import HTTPStatus
+
+ from httpcore import ConnectError
+
+
+ class ProxyError(Exception):
+     def __init__(self, http_status: HTTPStatus, message: str) -> None:
+         self.http_status: HTTPStatus = http_status
+         self.message: str = message
+
+
+ class InvalidTargetError(ProxyError):
+     def __init__(self):
+         super().__init__(HTTPStatus.BAD_GATEWAY, "Bad Gateway: invalid target extension.")
+
+
+ class TargetUnavailableError(ProxyError, ConnectError):
+     def __init__(self):
+         super().__init__(
+             HTTPStatus.SERVICE_UNAVAILABLE,
+             "Service Unavailable: the target extension is unavailable.",
+         )
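
ProxyAppBase.__call__ catches ProxyError and turns it into a plain error response using the carried status and message; a hypothetical subclass-specific error (not part of the package) would follow the same pattern:

from http import HTTPStatus

from mrok.proxy.exceptions import ProxyError


class UpstreamTimeoutError(ProxyError):
    # Illustrative only: would surface as a 504 with this message through _send_error().
    def __init__(self) -> None:
        super().__init__(HTTPStatus.GATEWAY_TIMEOUT, "Gateway Timeout: the target did not respond.")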
mrok/{master.py → proxy/master.py} RENAMED
@@ -1,6 +1,3 @@
- import asyncio
- import contextlib
- import json
  import logging
  import os
  import signal
@@ -11,21 +8,14 @@ from pathlib import Path
  from uuid import uuid4

  import zmq
- import zmq.asyncio
- from uvicorn.importer import import_from_string
  from watchfiles import watch
  from watchfiles.filters import PythonFilter
  from watchfiles.run import CombinedProcess, start_process

  from mrok.conf import get_settings
- from mrok.datastructures import Event, HTTPResponse, Meta, Status
- from mrok.http.config import MrokBackendConfig
- from mrok.http.lifespan import LifespanWrapper
- from mrok.http.middlewares import CaptureMiddleware, MetricsMiddleware
- from mrok.http.server import MrokServer
- from mrok.http.types import ASGIApp
  from mrok.logging import setup_logging
- from mrok.metrics import WorkerMetricsCollector
+ from mrok.proxy.worker import Worker
+ from mrok.types.proxy import ASGIApp

  logger = logging.getLogger("mrok.agent")

@@ -52,7 +42,7 @@ def start_events_router(events_pub_port: int, events_sub_port: int):
      try:
          logger.info(f"Events router process started: {os.getpid()}")
          zmq.proxy(frontend, backend)
-     except KeyboardInterrupt:
+     except KeyboardInterrupt:  # pragma: no cover
          pass
      finally:
          frontend.close()
@@ -62,78 +52,43 @@ def start_events_router(events_pub_port: int, events_sub_port: int):

  def start_uvicorn_worker(
      worker_id: str,
-     app: ASGIApp,
+     app: ASGIApp | str,
      identity_file: str,
+     events_enabled: bool,
      events_pub_port: int,
      metrics_interval: float = 5.0,
  ):
      import sys

      sys.path.insert(0, os.getcwd())
-     if isinstance(app, str):
-         app = import_from_string(app)

-     setup_logging(get_settings())
-     identity = json.load(open(identity_file))
-     meta = Meta(**identity["mrok"])
-     ctx = zmq.asyncio.Context()
-     pub = ctx.socket(zmq.PUB)
-     pub.connect(f"tcp://localhost:{events_pub_port}")
-     metrics = WorkerMetricsCollector(worker_id)
-
-     task = None
-
-     async def status_sender():
-         while True:
-             snap = await metrics.snapshot()
-             event = Event(type="status", data=Status(meta=meta, metrics=snap))
-             await pub.send_string(event.model_dump_json())
-             await asyncio.sleep(metrics_interval)
-
-     async def on_startup():  # noqa
-         nonlocal task
-         task = asyncio.create_task(status_sender())
-
-     async def on_shutdown():  # noqa
-         if task:
-             task.cancel()
-
-     async def on_response_complete(response: HTTPResponse):
-         event = Event(type="response", data=response)
-         await pub.send_string(event.model_dump_json())
-
-     config = MrokBackendConfig(
-         LifespanWrapper(
-             MetricsMiddleware(
-                 CaptureMiddleware(
-                     app,
-                     on_response_complete,
-                 ),
-                 metrics,
-             ),
-             on_startup=on_startup,
-             on_shutdown=on_shutdown,
-         ),
+     worker = Worker(
+         worker_id,
+         app,
          identity_file,
+         events_enabled=events_enabled,
+         event_publisher_port=events_pub_port,
+         metrics_interval=metrics_interval,
      )
-     server = MrokServer(config)
-     with contextlib.suppress(KeyboardInterrupt, asyncio.CancelledError):
-         server.run()
+     worker.run()


  class MasterBase(ABC):
      def __init__(
          self,
          identity_file: str,
-         workers: int,
-         reload: bool,
-         events_pub_port: int,
-         events_sub_port: int,
+         *,
+         workers: int = 4,
+         reload: bool = False,
+         events_enabled: bool = True,
+         events_pub_port: int = 50000,
+         events_sub_port: int = 50001,
          metrics_interval: float = 5.0,
      ):
          self.identity_file = identity_file
          self.workers = workers
          self.reload = reload
+         self.events_enabled = events_enabled
          self.events_pub_port = events_pub_port
          self.events_sub_port = events_sub_port
          self.metrics_interval = metrics_interval
@@ -154,7 +109,7 @@ class MasterBase(ABC):

      @abstractmethod
      def get_asgi_app(self):
-         pass
+         raise NotImplementedError()

      def setup_signals_handler(self):
          for sig in (signal.SIGINT, signal.SIGTERM):
@@ -173,6 +128,7 @@ class MasterBase(ABC):
                  worker_id,
                  self.get_asgi_app(),
                  self.identity_file,
+                 self.events_enabled,
                  self.events_pub_port,
                  self.metrics_interval,
              ),
@@ -181,18 +137,6 @@ class MasterBase(ABC):
          logger.info(f"Worker {worker_id} [{p.pid}] started")
          return p

-     def start(self):
-         self.start_events_router()
-         self.start_workers()
-         self.monitor_thread.start()
-
-     def stop(self):
-         if self.monitor_thread.is_alive():
-             logger.debug("Wait for monitor worker to exit")
-             self.monitor_thread.join(timeout=MONITOR_THREAD_JOIN_TIMEOUT)
-         self.stop_workers()
-         self.stop_events_router()
-
      def start_events_router(self):
          self.zmq_pubsub_router_process = start_process(
              start_events_router,
@@ -204,21 +148,32 @@ class MasterBase(ABC):
              None,
          )

-     def stop_events_router(self):
-         self.zmq_pubsub_router_process.stop(sigint_timeout=5, sigkill_timeout=1)
-
      def start_workers(self):
          for i in range(self.workers):
              worker_id = self.worker_identifiers[i]
              p = self.start_worker(worker_id)
              self.worker_processes[worker_id] = p

+     def start(self):
+         self.start_events_router()
+         self.start_workers()
+         self.monitor_thread.start()
+
+     def stop_events_router(self):
+         self.zmq_pubsub_router_process.stop(sigint_timeout=5, sigkill_timeout=1)
+
      def stop_workers(self):
          for process in self.worker_processes.values():
-             if process.is_alive():
+             if process.is_alive():  # pragma: no branch
                  process.stop(sigint_timeout=5, sigkill_timeout=1)
          self.worker_processes.clear()

+     def stop(self):
+         if self.monitor_thread.is_alive():  # pragma: no branch
+             self.monitor_thread.join(timeout=MONITOR_THREAD_JOIN_TIMEOUT)
+         self.stop_workers()
+         self.stop_events_router()
+
      def restart(self):
          self.pause_event.set()
          self.stop_workers()
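
As a rough sketch of the reworked keyword-only constructor (the DemoMaster class and the import string are placeholders; only the get_asgi_app hook visible in these hunks is implemented, and the base class may define further abstract members):

from mrok.proxy.master import MasterBase


class DemoMaster(MasterBase):
    def get_asgi_app(self):
        # Either an import string resolved inside the worker process or an ASGI app object.
        return "myproject.asgi:app"


master = DemoMaster("./identity.json", workers=2, events_enabled=False)
master.start()  # starts the events router, the uvicorn workers and the monitor thread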