mrok 0.4.6__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mrok/agent/devtools/inspector/app.py +2 -2
- mrok/agent/sidecar/app.py +61 -35
- mrok/agent/sidecar/main.py +35 -9
- mrok/agent/ziticorn.py +9 -3
- mrok/cli/commands/__init__.py +2 -2
- mrok/cli/commands/admin/bootstrap.py +3 -2
- mrok/cli/commands/admin/utils.py +2 -2
- mrok/cli/commands/agent/run/sidecar.py +59 -1
- mrok/cli/commands/{proxy → frontend}/__init__.py +1 -1
- mrok/cli/commands/frontend/run.py +91 -0
- mrok/constants.py +0 -2
- mrok/controller/openapi/examples.py +13 -0
- mrok/controller/schemas.py +2 -2
- mrok/frontend/__init__.py +3 -0
- mrok/frontend/app.py +75 -0
- mrok/{proxy → frontend}/main.py +12 -10
- mrok/proxy/__init__.py +0 -3
- mrok/proxy/app.py +158 -83
- mrok/proxy/asgi.py +96 -0
- mrok/proxy/backend.py +45 -0
- mrok/proxy/event_publisher.py +66 -0
- mrok/proxy/exceptions.py +22 -0
- mrok/{master.py → proxy/master.py} +36 -81
- mrok/{metrics.py → proxy/metrics.py} +38 -50
- mrok/{http/middlewares.py → proxy/middleware.py} +17 -26
- mrok/{datastructures.py → proxy/models.py} +43 -10
- mrok/proxy/stream.py +68 -0
- mrok/{http → proxy}/utils.py +1 -1
- mrok/proxy/worker.py +64 -0
- mrok/{http/config.py → proxy/ziticorn.py} +29 -6
- mrok/types/proxy.py +20 -0
- mrok/types/ziti.py +1 -0
- mrok/ziti/api.py +15 -18
- mrok/ziti/bootstrap.py +3 -2
- mrok/ziti/identities.py +5 -4
- mrok/ziti/services.py +3 -2
- {mrok-0.4.6.dist-info → mrok-0.6.0.dist-info}/METADATA +2 -5
- {mrok-0.4.6.dist-info → mrok-0.6.0.dist-info}/RECORD +43 -39
- mrok/cli/commands/proxy/run.py +0 -49
- mrok/http/forwarder.py +0 -354
- mrok/http/lifespan.py +0 -39
- mrok/http/pool.py +0 -239
- mrok/http/protocol.py +0 -11
- mrok/http/server.py +0 -14
- mrok/http/types.py +0 -18
- /mrok/{http → proxy}/constants.py +0 -0
- /mrok/{http → types}/__init__.py +0 -0
- {mrok-0.4.6.dist-info → mrok-0.6.0.dist-info}/WHEEL +0 -0
- {mrok-0.4.6.dist-info → mrok-0.6.0.dist-info}/entry_points.txt +0 -0
- {mrok-0.4.6.dist-info → mrok-0.6.0.dist-info}/licenses/LICENSE.txt +0 -0
mrok/proxy/app.py
CHANGED
|
@@ -1,101 +1,176 @@
|
|
|
1
|
-
import
|
|
1
|
+
import abc
|
|
2
2
|
import logging
|
|
3
|
-
from collections.abc import AsyncGenerator
|
|
4
|
-
from contextlib import asynccontextmanager
|
|
5
|
-
from pathlib import Path
|
|
6
3
|
|
|
7
|
-
import
|
|
8
|
-
from openziti.context import ZitiContext
|
|
4
|
+
from httpcore import AsyncConnectionPool, Request
|
|
9
5
|
|
|
10
|
-
from mrok.
|
|
11
|
-
from mrok.
|
|
12
|
-
from mrok.
|
|
13
|
-
from mrok.http.pool import ConnectionPool, PoolManager
|
|
14
|
-
from mrok.http.types import Scope, StreamPair
|
|
15
|
-
from mrok.logging import setup_logging
|
|
6
|
+
from mrok.proxy.exceptions import ProxyError
|
|
7
|
+
from mrok.proxy.stream import ASGIRequestBodyStream
|
|
8
|
+
from mrok.types.proxy import ASGIReceive, ASGISend, Scope
|
|
16
9
|
|
|
17
10
|
logger = logging.getLogger("mrok.proxy")


# Hop-by-hop headers (RFC 2616 §13.5.1): they describe a single transport
# connection, so the proxy drops them from the forwarded request and from the
# relayed response. Entries are lower-case bytes because header names are
# lower-cased before being looked up here.
HOP_BY_HOP_HEADERS = [
    header_name.encode("ascii")
    for header_name in (
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailers",
        "transfer-encoding",
        "upgrade",
    )
]
|
|
22
23
|
|
|
23
24
|
|
|
24
|
-
class
|
|
25
|
+
class ProxyAppBase(abc.ABC):
    """ASGI application that reverse-proxies HTTP requests through an httpcore pool.

    Subclasses provide the connection pool (``setup_connection_pool``) and the
    upstream base URL for each request (``get_upstream_base_url``). Request and
    response bodies are streamed, hop-by-hop headers are dropped in both
    directions, and ``X-Forwarded-*`` headers are added/merged on the request.
    """

    def __init__(
        self,
        *,
        max_connections: int | None = 10,
        max_keepalive_connections: int | None = None,
        keepalive_expiry: float | None = None,
        retries: int = 0,
    ) -> None:
        """Create the upstream pool by calling ``setup_connection_pool``.

        All arguments are passed through unchanged to ``setup_connection_pool``.
        """
        self._pool = self.setup_connection_pool(
            max_connections=max_connections,
            max_keepalive_connections=max_keepalive_connections,
            keepalive_expiry=keepalive_expiry,
            retries=retries,
        )

    @abc.abstractmethod
    def setup_connection_pool(
        self,
        max_connections: int | None,
        max_keepalive_connections: int | None,
        keepalive_expiry: float | None,
        retries: int,
    ) -> AsyncConnectionPool:
        """Return the pool used to send every upstream request."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_upstream_base_url(self, scope: Scope) -> str:
        """Return the base URL the request described by ``scope`` is forwarded to.

        A trailing slash is removed, then the request path and query string are
        appended to it.
        """
        raise NotImplementedError()

    async def __call__(self, scope: Scope, receive: ASGIReceive, send: ASGISend) -> None:
        """Handle one ASGI connection.

        Lifespan scopes are ignored, websocket handshakes are rejected, other
        non-HTTP scopes get a 500 response and HTTP requests are forwarded.
        """
        scope_type = scope.get("type")
        if scope_type == "lifespan":
            return

        if scope_type == "websocket":
            await self._reject_websocket(send)
            return

        if scope_type != "http":
            await self._send_error(send, 500, "Unsupported")
            return

        await self._forward_http(scope, receive, send)

    async def _forward_http(self, scope: Scope, receive: ASGIReceive, send: ASGISend) -> None:
        """Send one HTTP request upstream and stream the response back to the client."""
        # Once "http.response.start" has been sent, an error response can no longer
        # be emitted; in that case failures are only logged.
        response_started = False
        try:
            request = self._build_request(scope, receive)
            response = await self._pool.handle_async_request(request)
            # Lazy %-formatting: the pool repr is only built when DEBUG is enabled.
            logger.debug("connection pool status: %s", self._pool)
            try:
                await send(
                    {
                        "type": "http.response.start",
                        "status": response.status,
                        "headers": self._strip_hop_by_hop(response.headers),
                    }
                )
                response_started = True

                async for chunk in response.stream:  # type: ignore[union-attr]
                    await send(
                        {
                            "type": "http.response.body",
                            "body": chunk,
                            "more_body": True,
                        }
                    )

                await send({"type": "http.response.body", "body": b"", "more_body": False})
            finally:
                # httpcore only returns the connection to the pool once the response
                # is closed, so close it even if the client went away mid-stream.
                await response.aclose()

        except ProxyError as pe:
            if response_started:
                logger.warning(f"Proxy error after response start: {pe.message}")
            else:
                await self._send_error(send, pe.http_status, pe.message)

        except Exception:
            logger.exception("Unexpected error in forwarder")
            if not response_started:
                await self._send_error(send, 502, "Bad Gateway")

    def _build_request(self, scope: Scope, receive: ASGIReceive) -> Request:
        """Translate the ASGI request in ``scope`` into an httpcore request."""
        base_url = self.get_upstream_base_url(scope)
        if base_url.endswith("/"):  # pragma: no cover
            base_url = base_url[:-1]
        url = f"{base_url}{self._format_path(scope)}"
        return Request(
            method=scope.get("method", "GET").encode(),
            url=url,
            headers=self._prepare_headers(scope),
            content=ASGIRequestBodyStream(receive),
        )

    async def _send_error(self, send: ASGISend, http_status: int, body: str):
        """Best-effort plain error response; failures to send are only logged."""
        try:
            await send({"type": "http.response.start", "status": http_status, "headers": []})
            await send({"type": "http.response.body", "body": body.encode()})
        except Exception as e:  # pragma: no cover
            logger.error(f"Cannot send error response: {e}")

    async def _reject_websocket(self, send: ASGISend) -> None:
        """Refuse a websocket handshake; websockets are not proxied."""
        try:
            # "http.response.*" messages are invalid on a websocket scope. Sending
            # "websocket.close" before "websocket.accept" makes the server deny the
            # handshake (HTTP 403 per the ASGI websocket spec).
            await send({"type": "websocket.close"})
        except Exception as e:  # pragma: no cover
            logger.error(f"Cannot reject websocket connection: {e}")

    @staticmethod
    def _strip_hop_by_hop(headers: list[tuple[bytes, bytes]]) -> list[tuple[bytes, bytes]]:
        """Return ``headers`` without the entries listed in ``HOP_BY_HOP_HEADERS``."""
        return [(k, v) for k, v in headers if k.lower() not in HOP_BY_HOP_HEADERS]

    def _prepare_headers(self, scope: Scope) -> list[tuple[bytes, bytes]]:
        """Build the upstream request headers from the ASGI scope."""
        headers = self._strip_hop_by_hop(scope.get("headers", []))
        self._merge_x_forwarded(headers, scope)
        return headers

    def _find_header(self, headers: list[tuple[bytes, bytes]], name: bytes) -> int | None:
        """Return index of header `name` in `headers`, or None if missing."""
        lname = name.lower()
        for i, (k, _) in enumerate(headers):
            if k.lower() == lname:
                return i
        return None

    def _merge_x_forwarded(self, headers: list[tuple[bytes, bytes]], scope: Scope) -> None:
        """Update ``headers`` in place with ``X-Forwarded-*`` values.

        The client IP is appended to (or starts) ``X-Forwarded-For``.
        ``X-Forwarded-Host``/``X-Forwarded-Port`` are taken from ``scope["server"]``
        only when absent. ``X-Forwarded-Proto`` is always forced to ``https``.
        """
        client = scope.get("client")
        if client:
            client_ip = client[0].encode()
            idx = self._find_header(headers, b"x-forwarded-for")
            if idx is None:
                headers.append((b"x-forwarded-for", client_ip))
            else:
                k, v = headers[idx]
                headers[idx] = (k, v + b", " + client_ip)

        server = scope.get("server")
        if server:
            if self._find_header(headers, b"x-forwarded-host") is None:
                headers.append((b"x-forwarded-host", server[0].encode()))
            if server[1] and self._find_header(headers, b"x-forwarded-port") is None:
                headers.append((b"x-forwarded-port", str(server[1]).encode()))

        # Always set the protocol to https for upstream
        idx_proto = self._find_header(headers, b"x-forwarded-proto")
        if idx_proto is None:
            headers.append((b"x-forwarded-proto", b"https"))
        else:
            k, _ = headers[idx_proto]
            headers[idx_proto] = (k, b"https")

    def _format_path(self, scope: Scope) -> str:
        """Return the request target (path plus query string) to send upstream.

        ``raw_path`` is preferred because it keeps the client's original
        encoding. Per the ASGI spec it holds only the path component, so the
        query string must be appended from ``query_string``.
        """
        raw_path = scope.get("raw_path")
        if raw_path:
            path = raw_path.decode()
            if "?" in path:
                # Defensive: a server that includes the query in raw_path must
                # not get it appended twice.
                return path
        else:
            path = scope.get("path", "/")

        query = scope.get("query_string", b"")
        if query:
            return f"{path}?{query.decode()}"
        return path
|
mrok/proxy/asgi.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
from collections.abc import Iterator
|
|
2
|
+
from contextlib import AsyncExitStack, asynccontextmanager
|
|
3
|
+
from typing import Any, ParamSpec, Protocol
|
|
4
|
+
|
|
5
|
+
from mrok.types.proxy import ASGIApp, ASGIReceive, ASGISend, Lifespan, Scope
|
|
6
|
+
|
|
7
|
+
# Extra parameters a middleware factory accepts besides the wrapped app.
P = ParamSpec("P")


class ASGIMiddleware(Protocol[P]):
    """Structural type of a middleware factory.

    It is called with the ASGI app to wrap (positional-only) followed by its
    own arguments, and returns the wrapping ASGI app.
    """

    def __call__(self, app: ASGIApp, /, *args: P.args, **kwargs: P.kwargs) -> ASGIApp:
        ...  # pragma: no cover
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Middleware:
    """A middleware factory together with the arguments to call it with.

    Iterating an instance yields ``(cls, args, kwargs)``, so it can be unpacked
    directly: ``cls, args, kwargs = middleware``.
    """

    def __init__(self, cls: ASGIMiddleware[P], *args: P.args, **kwargs: P.kwargs) -> None:
        self.cls, self.args, self.kwargs = cls, args, kwargs

    def __iter__(self) -> Iterator[Any]:
        yield from (self.cls, self.args, self.kwargs)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ASGIAppWrapper:
    """Wrap an ASGI app with a middleware stack and a merged lifespan.

    Middleware registered with ``add_middleware`` wraps ``app``; the most
    recently added one is the outermost. Lifespan events are handled by the
    wrapper itself. The ``lifespan`` given to the constructor (if any) is
    entered first, then the wrapped app's ``router.lifespan_context`` when it
    has one (Starlette/FastAPI applications).
    """

    def __init__(
        self,
        app: ASGIApp,
        lifespan: Lifespan | None = None,
    ) -> None:
        self.app = app
        self.lifespan = lifespan
        # Registered middleware, outermost first (attribute spelling kept for
        # backward compatibility).
        self.middlware: list[Middleware] = []
        # Built lazily on the first call; reset whenever middleware is added.
        self.middleware_stack: ASGIApp | None = None

    def add_middleware(self, cls: ASGIMiddleware[P], *args: P.args, **kwargs: P.kwargs):
        """Register ``cls(app, *args, **kwargs)`` as the new outermost middleware."""
        self.middlware.insert(0, Middleware(cls, *args, **kwargs))
        # Invalidate an already-built stack so the new middleware takes effect.
        self.middleware_stack = None

    def build_middleware_stack(self):
        """Instantiate the registered middleware around ``self.app``; return the outermost app."""
        app = self.app
        for cls, args, kwargs in reversed(self.middlware):
            app = cls(app, *args, **kwargs)
        return app

    def get_starlette_lifespan(self):
        """Return ``self.app.router.lifespan_context`` if present, else None."""
        router = getattr(self.app, "router", None)
        if router is None:
            return None
        return getattr(router, "lifespan_context", None)

    @asynccontextmanager
    async def merge_lifespan(self, app: ASGIApp):
        """Run the configured lifespan and the wrapped app's lifespan together.

        The configured ``lifespan`` is called with ``app``. The Starlette
        lifespan is called with the wrapped application, because Starlette/FastAPI
        lifespans expect their own app (e.g. to use ``app.state``). State dicts
        yielded by both are merged; the Starlette one wins on key clashes.
        """
        async with AsyncExitStack() as stack:
            state: dict[Any, Any] = {}
            if self.lifespan is not None:
                outer_state = await stack.enter_async_context(self.lifespan(app))
                state.update(outer_state or {})
            starlette_lifespan = self.get_starlette_lifespan()
            if starlette_lifespan is not None:
                inner_state = await stack.enter_async_context(starlette_lifespan(self.app))
                state.update(inner_state or {})
            yield state

    async def handle_lifespan(self, scope: Scope, receive: ASGIReceive, send: ASGISend) -> None:
        """Implement the ASGI lifespan protocol around ``merge_lifespan``."""
        started = False
        app: Any = scope.get("app")
        await receive()  # lifespan.startup
        try:
            async with self.merge_lifespan(app) as state:
                if state:
                    if "state" not in scope:
                        raise RuntimeError('"state" is unsupported by the current ASGI Server.')
                    scope["state"].update(state)
                await send({"type": "lifespan.startup.complete"})
                started = True
                await receive()  # lifespan.shutdown
        except Exception as e:  # pragma: no cover
            if started:
                await send({"type": "lifespan.shutdown.failed", "message": str(e)})
            else:
                await send({"type": "lifespan.startup.failed", "message": str(e)})
            raise
        else:
            await send({"type": "lifespan.shutdown.complete"})

    async def __call__(self, scope: Scope, receive: ASGIReceive, send: ASGISend) -> None:
        # Build the stack once and reuse it. Previously the None-check read a
        # misspelled attribute, so the stack was rebuilt on every call.
        if self.middleware_stack is None:
            self.middleware_stack = self.build_middleware_stack()
        if scope["type"] == "lifespan":
            scope["app"] = self
            await self.handle_lifespan(scope, receive, send)
            return

        await self.middleware_stack(scope, receive, send)
|
mrok/proxy/backend.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from collections.abc import Iterable
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import openziti
|
|
6
|
+
from httpcore import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
|
|
7
|
+
from openziti.context import ZitiContext
|
|
8
|
+
|
|
9
|
+
from mrok.proxy.exceptions import InvalidTargetError, TargetUnavailableError
|
|
10
|
+
from mrok.proxy.stream import AIONetworkStream
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AIOZitiNetworkBackend(AsyncNetworkBackend):
    """httpcore network backend that dials OpenZiti services.

    ``connect_tcp`` uses ``host`` as the Ziti service to dial. ``port``,
    ``timeout``, ``local_address`` and ``socket_options`` are accepted for
    interface compatibility but are not used.
    """

    # Dial error code (first exception argument) meaning the service exists but
    # is not available.
    _ZITI_SERVICE_UNAVAILABLE = -24

    def __init__(self, identity_file: str | Path) -> None:
        self._identity_file = identity_file
        self._ziti_ctx: ZitiContext | None = None

    def _get_ziti_ctx(self) -> ZitiContext:
        """Return the Ziti context, loading it from the identity file on first use.

        Raises:
            Exception: when openziti reports a non-zero error while loading.
        """
        if self._ziti_ctx is None:
            ctx, err = openziti.load(str(self._identity_file), timeout=10_000)
            if err != 0:
                raise Exception(f"Cannot create a Ziti context from the identity file: {err}")
            self._ziti_ctx = ctx
        return self._ziti_ctx

    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: float | None = None,
        local_address: str | None = None,
        socket_options: Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:
        """Dial the Ziti service ``host`` and wrap it in an asyncio network stream.

        Raises:
            TargetUnavailableError: the service exists but is not available.
            InvalidTargetError: any other dial/wrap failure.
        """
        ctx = self._get_ziti_ctx()
        try:
            sock = ctx.connect(host)
        except Exception as e:
            raise self._to_proxy_error(e) from e

        try:
            reader, writer = await asyncio.open_connection(sock=sock)
        except BaseException as e:
            # Close the dialed socket so a failed (or cancelled) wrap does not leak it.
            sock.close()
            if isinstance(e, Exception):
                raise self._to_proxy_error(e) from e
            raise
        return AIONetworkStream(reader, writer)

    @classmethod
    def _to_proxy_error(cls, exc: Exception) -> InvalidTargetError | TargetUnavailableError:
        """Map a dial failure to the proxy error reported to the client."""
        if exc.args and exc.args[0] == cls._ZITI_SERVICE_UNAVAILABLE:
            return TargetUnavailableError()
        return InvalidTargetError()

    async def sleep(self, seconds: float) -> None:
        await asyncio.sleep(seconds)
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import contextlib
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
import zmq
|
|
6
|
+
import zmq.asyncio
|
|
7
|
+
|
|
8
|
+
from mrok.proxy.asgi import ASGIAppWrapper
|
|
9
|
+
from mrok.proxy.metrics import MetricsCollector
|
|
10
|
+
from mrok.proxy.middleware import CaptureMiddleware, MetricsMiddleware
|
|
11
|
+
from mrok.proxy.models import Event, HTTPResponse, ServiceMetadata, Status
|
|
12
|
+
from mrok.types.proxy import ASGIApp
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger("mrok.proxy")


class EventPublisher:
    """Publish a worker's status and response events over ZeroMQ.

    Events are JSON-serialized ``Event`` models sent on a PUB socket connected
    to ``tcp://localhost:<event_publisher_port>``. While running, a background
    task publishes a ``status`` event (metadata plus a metrics snapshot) every
    ``metrics_interval`` seconds. ``setup_middleware`` installs a
    ``CaptureMiddleware`` that publishes a ``response`` event per captured
    response, and a ``MetricsMiddleware`` bound to the metrics collector.
    """

    # Milliseconds close() may spend flushing queued events. Bounded so that
    # Context.term() cannot block forever when the events router is unreachable.
    _CLOSE_LINGER_MS = 1000

    def __init__(
        self,
        worker_id: str,
        meta: ServiceMetadata | None = None,
        event_publisher_port: int = 50000,
        metrics_interval: float = 5.0,
    ):
        self._worker_id = worker_id
        self._meta = meta
        self._metrics_interval = metrics_interval
        self.publisher_port = event_publisher_port
        self._zmq_ctx = zmq.asyncio.Context()
        self._publisher = self._zmq_ctx.socket(zmq.PUB)
        self._metrics_collector = MetricsCollector(self._worker_id)
        self._publish_task: asyncio.Task | None = None

    async def on_startup(self):
        """Connect the PUB socket and start the periodic status publisher."""
        self._publisher.connect(f"tcp://localhost:{self.publisher_port}")
        self._publish_task = asyncio.create_task(self.publish_metrics_event())
        logger.info(f"Events publishing for worker {self._worker_id} started")

    async def on_shutdown(self):
        """Stop the status publisher (if started) and release the ZeroMQ resources."""
        if self._publish_task is not None:
            self._publish_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._publish_task
            self._publish_task = None
        self._publisher.close(linger=self._CLOSE_LINGER_MS)
        self._zmq_ctx.term()
        logger.info(f"Events publishing for worker {self._worker_id} stopped")

    async def publish_metrics_event(self):
        """Publish a ``status`` event every ``metrics_interval`` seconds until cancelled."""
        while True:
            try:
                snap = await self._metrics_collector.snapshot()
                event = Event(type="status", data=Status(meta=self._meta, metrics=snap))
                await self._publisher.send_string(event.model_dump_json())
            except Exception:
                # One failed snapshot/send must not stop periodic status reporting.
                logger.exception(f"Cannot publish status event for worker {self._worker_id}")
            await asyncio.sleep(self._metrics_interval)

    async def publish_response_event(self, response: HTTPResponse):
        """Publish a ``response`` event for one captured HTTP response."""
        event = Event(type="response", data=response)
        await self._publisher.send_string(event.model_dump_json())  # type: ignore[attr-defined]

    def setup_middleware(self, app: ASGIAppWrapper):
        """Install the capture and metrics middleware on ``app``."""
        app.add_middleware(CaptureMiddleware, self.publish_response_event)
        app.add_middleware(MetricsMiddleware, self._metrics_collector)  # type: ignore

    @contextlib.asynccontextmanager
    async def lifespan(self, app: ASGIApp):
        """Lifespan context: start publishing on enter, always clean up on exit."""
        await self.on_startup()  # type: ignore
        try:
            yield
        finally:
            await self.on_shutdown()  # type: ignore
|
mrok/proxy/exceptions.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from http import HTTPStatus
|
|
2
|
+
|
|
3
|
+
from httpcore import ConnectError
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ProxyError(Exception):
    """Base error for failures the proxy turns into an HTTP error response.

    Attributes:
        http_status: Status code of the error response.
        message: Plain-text body of the error response.
    """

    def __init__(self, http_status: HTTPStatus, message: str) -> None:
        # Pass the message up so str(exc), exc.args and tracebacks show it.
        super().__init__(message)
        self.http_status: HTTPStatus = http_status
        self.message: str = message
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class InvalidTargetError(ProxyError):
    """The requested target could not be dialed; reported as 502 Bad Gateway."""

    def __init__(self):
        super().__init__(
            http_status=HTTPStatus.BAD_GATEWAY,
            message="Bad Gateway: invalid target extension.",
        )


class TargetUnavailableError(ProxyError, ConnectError):
    """The target exists but is currently unavailable; reported as 503.

    It is also an httpcore ``ConnectError``, so httpcore handles it as a
    connection failure (presumably so its connect retries apply).
    """

    def __init__(self):
        super().__init__(
            http_status=HTTPStatus.SERVICE_UNAVAILABLE,
            message="Service Unavailable: the target extension is unavailable.",
        )
|
|
@@ -1,6 +1,3 @@
|
|
|
1
|
-
import asyncio
|
|
2
|
-
import contextlib
|
|
3
|
-
import json
|
|
4
1
|
import logging
|
|
5
2
|
import os
|
|
6
3
|
import signal
|
|
@@ -11,21 +8,14 @@ from pathlib import Path
|
|
|
11
8
|
from uuid import uuid4
|
|
12
9
|
|
|
13
10
|
import zmq
|
|
14
|
-
import zmq.asyncio
|
|
15
|
-
from uvicorn.importer import import_from_string
|
|
16
11
|
from watchfiles import watch
|
|
17
12
|
from watchfiles.filters import PythonFilter
|
|
18
13
|
from watchfiles.run import CombinedProcess, start_process
|
|
19
14
|
|
|
20
15
|
from mrok.conf import get_settings
|
|
21
|
-
from mrok.datastructures import Event, HTTPResponse, Meta, Status
|
|
22
|
-
from mrok.http.config import MrokBackendConfig
|
|
23
|
-
from mrok.http.lifespan import LifespanWrapper
|
|
24
|
-
from mrok.http.middlewares import CaptureMiddleware, MetricsMiddleware
|
|
25
|
-
from mrok.http.server import MrokServer
|
|
26
|
-
from mrok.http.types import ASGIApp
|
|
27
16
|
from mrok.logging import setup_logging
|
|
28
|
-
from mrok.
|
|
17
|
+
from mrok.proxy.worker import Worker
|
|
18
|
+
from mrok.types.proxy import ASGIApp
|
|
29
19
|
|
|
30
20
|
logger = logging.getLogger(name="mrok.agent")
|
|
31
21
|
|
|
@@ -52,7 +42,7 @@ def start_events_router(events_pub_port: int, events_sub_port: int):
|
|
|
52
42
|
try:
|
|
53
43
|
logger.info(f"Events router process started: {os.getpid()}")
|
|
54
44
|
zmq.proxy(frontend, backend)
|
|
55
|
-
except KeyboardInterrupt:
|
|
45
|
+
except KeyboardInterrupt: # pragma: no cover
|
|
56
46
|
pass
|
|
57
47
|
finally:
|
|
58
48
|
frontend.close()
|
|
@@ -62,78 +52,43 @@ def start_events_router(events_pub_port: int, events_sub_port: int):
|
|
|
62
52
|
|
|
63
53
|
def start_uvicorn_worker(
    worker_id: str,
    app: ASGIApp | str,
    identity_file: str,
    events_enabled: bool,
    events_pub_port: int,
    metrics_interval: float = 5.0,
):
    """Run one proxy worker in the current process until it exits.

    Args:
        worker_id: Identifier of this worker.
        app: ASGI application object, or an import string for it.
        identity_file: Ziti identity file used by the worker.
        events_enabled: Whether the worker publishes events.
        events_pub_port: Port of the events router the worker publishes to.
        metrics_interval: Seconds between metrics/status events.
    """
    import sys

    # Make modules in the launch directory importable (presumably so that an
    # import-string ``app`` can be resolved by the worker).
    sys.path.insert(0, os.getcwd())

    Worker(
        worker_id,
        app,
        identity_file,
        events_enabled=events_enabled,
        event_publisher_port=events_pub_port,
        metrics_interval=metrics_interval,
    ).run()
|
|
122
74
|
|
|
123
75
|
|
|
124
76
|
class MasterBase(ABC):
|
|
125
77
|
def __init__(
|
|
126
78
|
self,
|
|
127
79
|
identity_file: str,
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
80
|
+
*,
|
|
81
|
+
workers: int = 4,
|
|
82
|
+
reload: bool = False,
|
|
83
|
+
events_enabled: bool = True,
|
|
84
|
+
events_pub_port: int = 50000,
|
|
85
|
+
events_sub_port: int = 50001,
|
|
132
86
|
metrics_interval: float = 5.0,
|
|
133
87
|
):
|
|
134
88
|
self.identity_file = identity_file
|
|
135
89
|
self.workers = workers
|
|
136
90
|
self.reload = reload
|
|
91
|
+
self.events_enabled = events_enabled
|
|
137
92
|
self.events_pub_port = events_pub_port
|
|
138
93
|
self.events_sub_port = events_sub_port
|
|
139
94
|
self.metrics_interval = metrics_interval
|
|
@@ -154,7 +109,7 @@ class MasterBase(ABC):
|
|
|
154
109
|
|
|
155
110
|
@abstractmethod
def get_asgi_app(self):
    """Return the ASGI app (object or import string) each worker will serve."""
    raise NotImplementedError
|
|
158
113
|
|
|
159
114
|
def setup_signals_handler(self):
|
|
160
115
|
for sig in (signal.SIGINT, signal.SIGTERM):
|
|
@@ -173,6 +128,7 @@ class MasterBase(ABC):
|
|
|
173
128
|
worker_id,
|
|
174
129
|
self.get_asgi_app(),
|
|
175
130
|
self.identity_file,
|
|
131
|
+
self.events_enabled,
|
|
176
132
|
self.events_pub_port,
|
|
177
133
|
self.metrics_interval,
|
|
178
134
|
),
|
|
@@ -181,18 +137,6 @@ class MasterBase(ABC):
|
|
|
181
137
|
logger.info(f"Worker {worker_id} [{p.pid}] started")
|
|
182
138
|
return p
|
|
183
139
|
|
|
184
|
-
def start(self):
|
|
185
|
-
self.start_events_router()
|
|
186
|
-
self.start_workers()
|
|
187
|
-
self.monitor_thread.start()
|
|
188
|
-
|
|
189
|
-
def stop(self):
|
|
190
|
-
if self.monitor_thread.is_alive():
|
|
191
|
-
logger.debug("Wait for monitor worker to exit")
|
|
192
|
-
self.monitor_thread.join(timeout=MONITOR_THREAD_JOIN_TIMEOUT)
|
|
193
|
-
self.stop_workers()
|
|
194
|
-
self.stop_events_router()
|
|
195
|
-
|
|
196
140
|
def start_events_router(self):
|
|
197
141
|
self.zmq_pubsub_router_process = start_process(
|
|
198
142
|
start_events_router,
|
|
@@ -204,21 +148,32 @@ class MasterBase(ABC):
|
|
|
204
148
|
None,
|
|
205
149
|
)
|
|
206
150
|
|
|
207
|
-
def stop_events_router(self):
|
|
208
|
-
self.zmq_pubsub_router_process.stop(sigint_timeout=5, sigkill_timeout=1)
|
|
209
|
-
|
|
210
151
|
def start_workers(self):
    """Spawn ``self.workers`` worker processes and record them by worker id."""
    for index in range(self.workers):
        wid = self.worker_identifiers[index]
        self.worker_processes[wid] = self.start_worker(wid)
|
|
215
156
|
|
|
157
|
+
def start(self):
    """Start the events router, then the workers, then the monitor thread."""
    # The router is started first so it is up before workers publish to it.
    self.start_events_router()
    self.start_workers()
    monitor = self.monitor_thread
    monitor.start()
|
|
161
|
+
|
|
162
|
+
def stop_events_router(self):
    """Stop the events router process (5s SIGINT timeout, then 1s SIGKILL timeout)."""
    router = self.zmq_pubsub_router_process
    router.stop(sigkill_timeout=1, sigint_timeout=5)
|
|
164
|
+
|
|
216
165
|
def stop_workers(self):
    """Stop every worker process that is still alive, then forget all of them."""
    processes = self.worker_processes
    for proc in processes.values():
        if not proc.is_alive():  # pragma: no cover
            continue
        proc.stop(sigkill_timeout=1, sigint_timeout=5)
    processes.clear()
|
|
221
170
|
|
|
171
|
+
def stop(self):
    """Wait (bounded) for the monitor thread, then stop the workers and the events router."""
    monitor = self.monitor_thread
    if monitor.is_alive():  # pragma: no branch
        monitor.join(timeout=MONITOR_THREAD_JOIN_TIMEOUT)
    self.stop_workers()
    self.stop_events_router()
|
|
176
|
+
|
|
222
177
|
def restart(self):
|
|
223
178
|
self.pause_event.set()
|
|
224
179
|
self.stop_workers()
|