mrok 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mrok/agent/devtools/inspector/__main__.py +2 -24
- mrok/agent/devtools/inspector/app.py +408 -113
- mrok/agent/devtools/inspector/utils.py +149 -0
- mrok/agent/sidecar/app.py +9 -9
- mrok/agent/sidecar/main.py +31 -5
- mrok/agent/ziticorn.py +8 -2
- mrok/cli/commands/admin/bootstrap.py +3 -2
- mrok/cli/commands/admin/utils.py +2 -2
- mrok/cli/commands/agent/run/sidecar.py +59 -1
- mrok/cli/commands/frontend/run.py +43 -1
- mrok/cli/main.py +17 -1
- mrok/constants.py +21 -0
- mrok/controller/schemas.py +2 -2
- mrok/frontend/app.py +8 -8
- mrok/frontend/main.py +9 -1
- mrok/logging.py +0 -22
- mrok/proxy/app.py +10 -9
- mrok/proxy/asgi.py +96 -0
- mrok/proxy/backend.py +5 -3
- mrok/proxy/event_publisher.py +66 -0
- mrok/proxy/master.py +18 -60
- mrok/proxy/metrics.py +2 -2
- mrok/proxy/{middlewares.py → middleware.py} +11 -42
- mrok/proxy/{datastructures.py → models.py} +43 -17
- mrok/proxy/{streams.py → stream.py} +24 -1
- mrok/proxy/worker.py +64 -0
- mrok/proxy/ziticorn.py +76 -0
- mrok/types/__init__.py +0 -0
- mrok/{proxy/types.py → types/proxy.py} +7 -2
- mrok/types/ziti.py +1 -0
- mrok/ziti/api.py +16 -19
- mrok/ziti/bootstrap.py +3 -7
- mrok/ziti/identities.py +15 -13
- mrok/ziti/services.py +3 -2
- {mrok-0.5.0.dist-info → mrok-0.7.0.dist-info}/METADATA +8 -2
- {mrok-0.5.0.dist-info → mrok-0.7.0.dist-info}/RECORD +39 -40
- mrok/agent/devtools/__main__.py +0 -34
- mrok/cli/commands/agent/utils.py +0 -5
- mrok/proxy/config.py +0 -62
- mrok/proxy/constants.py +0 -22
- mrok/proxy/lifespan.py +0 -10
- mrok/proxy/protocol.py +0 -11
- mrok/proxy/server.py +0 -14
- mrok/proxy/utils.py +0 -90
- {mrok-0.5.0.dist-info → mrok-0.7.0.dist-info}/WHEEL +0 -0
- {mrok-0.5.0.dist-info → mrok-0.7.0.dist-info}/entry_points.txt +0 -0
- {mrok-0.5.0.dist-info → mrok-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
mrok/proxy/asgi.py
ADDED
@@ -0,0 +1,96 @@
+from collections.abc import Iterator
+from contextlib import AsyncExitStack, asynccontextmanager
+from typing import Any, ParamSpec, Protocol
+
+from mrok.types.proxy import ASGIApp, ASGIReceive, ASGISend, Lifespan, Scope
+
+P = ParamSpec("P")
+
+
+class ASGIMiddleware(Protocol[P]):
+    def __call__(
+        self, app: ASGIApp, /, *args: P.args, **kwargs: P.kwargs
+    ) -> ASGIApp: ...  # pragma: no cover
+
+
+class Middleware:
+    def __init__(self, cls: ASGIMiddleware[P], *args: P.args, **kwargs: P.kwargs) -> None:
+        self.cls = cls
+        self.args = args
+        self.kwargs = kwargs
+
+    def __iter__(self) -> Iterator[Any]:
+        as_tuple = (self.cls, self.args, self.kwargs)
+        return iter(as_tuple)
+
+
+class ASGIAppWrapper:
+    def __init__(
+        self,
+        app: ASGIApp,
+        lifespan: Lifespan | None = None,
+    ) -> None:
+        self.app = app
+        self.lifespan = lifespan
+        self.middlware: list[Middleware] = []
+        self.middleare_stack: ASGIApp | None = None
+
+    def add_middleware(self, cls: ASGIMiddleware[P], *args: P.args, **kwargs: P.kwargs):
+        self.middlware.insert(0, Middleware(cls, *args, **kwargs))
+
+    def build_middleware_stack(self):
+        app = self.app
+        for cls, args, kwargs in reversed(self.middlware):
+            app = cls(app, *args, **kwargs)
+        return app
+
+    def get_starlette_lifespan(self):
+        router = getattr(self.app, "router", None)
+        if router is None:
+            return None
+        return getattr(router, "lifespan_context", None)
+
+    @asynccontextmanager
+    async def merge_lifespan(self, app: ASGIApp):
+        async with AsyncExitStack() as stack:
+            state: dict[Any, Any] = {}
+            if self.lifespan is not None:
+                outer_state = await stack.enter_async_context(self.lifespan(app))
+                state.update(outer_state or {})
+            starlette_lifespan = self.get_starlette_lifespan()
+            if starlette_lifespan is not None:
+                inner_state = await stack.enter_async_context(starlette_lifespan(app))
+                state.update(inner_state or {})
+            yield state
+
+    async def handle_lifespan(self, scope: Scope, receive: ASGIReceive, send: ASGISend) -> None:
+        started = False
+        app: Any = scope.get("app")
+        await receive()
+        try:
+            async with self.merge_lifespan(app) as state:
+                if state:
+                    if "state" not in scope:
+                        raise RuntimeError('"state" is unsupported by the current ASGI Server.')
+                    scope["state"].update(state)
+                await send({"type": "lifespan.startup.complete"})
+                started = True
+                await receive()
+        except Exception as e:  # pragma: no cover
+            if started:
+                await send({"type": "lifespan.shutdown.failed", "message": str(e)})
+            else:
+                await send({"type": "lifespan.startup.failed", "message": str(e)})
+            raise
+        else:
+            await send({"type": "lifespan.shutdown.complete"})
+
+    async def __call__(self, scope: Scope, receive: ASGIReceive, send: ASGISend) -> None:
+        if self.middleare_stack is None:  # pragma: no branch
+            self.middleware_stack = self.build_middleware_stack()
+        if scope["type"] == "lifespan":
+            scope["app"] = self
+            await self.handle_lifespan(scope, receive, send)
+            return
+
+        await self.middleware_stack(scope, receive, send)
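`ASGIAppWrapper` composes middleware registered via `add_middleware`, merges an externally supplied lifespan with the wrapped app's own Starlette-style `lifespan_context`, and answers ASGI lifespan events itself. A minimal usage sketch, assuming Starlette is available; `logging_middleware` and `outer_lifespan` below are illustrative placeholders, not part of mrok:

```python
from contextlib import asynccontextmanager

from starlette.applications import Starlette

from mrok.proxy.asgi import ASGIAppWrapper


def logging_middleware(app, prefix: str):
    # Plain ASGI middleware factory: takes the inner app, returns a new ASGI callable.
    async def wrapped(scope, receive, send):
        print(f"{prefix}: {scope['type']}")
        await app(scope, receive, send)

    return wrapped


@asynccontextmanager
async def outer_lifespan(app):
    # Runs around the wrapped app's own lifespan; the yielded dict is merged
    # into the ASGI lifespan state.
    print("outer startup")
    yield {"outer_ready": True}
    print("outer shutdown")


app = ASGIAppWrapper(Starlette(), lifespan=outer_lifespan)
app.add_middleware(logging_middleware, prefix="mrok")
# `app` is itself an ASGI callable, so any ASGI server can serve it.
```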
mrok/proxy/backend.py
CHANGED
@@ -6,8 +6,8 @@ import openziti
 from httpcore import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
 from openziti.context import ZitiContext

-from mrok.proxy.exceptions import TargetUnavailableError
-from mrok.proxy.
+from mrok.proxy.exceptions import InvalidTargetError, TargetUnavailableError
+from mrok.proxy.stream import AIONetworkStream


 class AIOZitiNetworkBackend(AsyncNetworkBackend):
@@ -37,7 +37,9 @@ class AIOZitiNetworkBackend(AsyncNetworkBackend):
             reader, writer = await asyncio.open_connection(sock=sock)
             return AIONetworkStream(reader, writer)
         except Exception as e:
-
+            if e.args and e.args[0] == -24:  # the service exists but is not available
+                raise TargetUnavailableError() from e
+            raise InvalidTargetError() from e

     async def sleep(self, seconds: float) -> None:
         await asyncio.sleep(seconds)
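The new dial error handling distinguishes an OpenZiti failure with error code -24 (service known but currently unavailable) from any other failure. A hedged caller-side sketch; `connect_tcp` is the httpcore `AsyncNetworkBackend` method that `AIOZitiNetworkBackend` implements, and the backend instance, service name, and port are placeholders:

```python
from mrok.proxy.exceptions import InvalidTargetError, TargetUnavailableError


async def dial(backend, service_name: str):
    # `backend` is assumed to be an already-configured AIOZitiNetworkBackend.
    try:
        return await backend.connect_tcp(service_name, 80)
    except TargetUnavailableError:
        # Ziti error -24: the service exists but nothing is terminating it right now.
        return None
    except InvalidTargetError:
        # Anything else: unknown service, bad identity, etc.
        raise
```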
mrok/proxy/event_publisher.py
ADDED
@@ -0,0 +1,66 @@
+import asyncio
+import contextlib
+import logging
+
+import zmq
+import zmq.asyncio
+
+from mrok.proxy.asgi import ASGIAppWrapper
+from mrok.proxy.metrics import MetricsCollector
+from mrok.proxy.middleware import CaptureMiddleware, MetricsMiddleware
+from mrok.proxy.models import Event, HTTPResponse, ServiceMetadata, Status
+from mrok.types.proxy import ASGIApp
+
+logger = logging.getLogger("mrok.proxy")
+
+
+class EventPublisher:
+    def __init__(
+        self,
+        worker_id: str,
+        meta: ServiceMetadata | None = None,
+        event_publisher_port: int = 50000,
+        metrics_interval: float = 5.0,
+    ):
+        self._worker_id = worker_id
+        self._meta = meta
+        self._metrics_interval = metrics_interval
+        self.publisher_port = event_publisher_port
+        self._zmq_ctx = zmq.asyncio.Context()
+        self._publisher = self._zmq_ctx.socket(zmq.PUB)
+        self._metrics_collector = MetricsCollector(self._worker_id)
+        self._publish_task = None
+
+    async def on_startup(self):
+        self._publisher.connect(f"tcp://localhost:{self.publisher_port}")
+        self._publish_task = asyncio.create_task(self.publish_metrics_event())
+        logger.info(f"Events publishing for worker {self._worker_id} started")
+
+    async def on_shutdown(self):
+        self._publish_task.cancel()
+        with contextlib.suppress(asyncio.CancelledError):
+            await self._publish_task
+        self._publisher.close()
+        self._zmq_ctx.term()
+        logger.info(f"Events publishing for worker {self._worker_id} stopped")
+
+    async def publish_metrics_event(self):
+        while True:
+            snap = await self._metrics_collector.snapshot()
+            event = Event(type="status", data=Status(meta=self._meta, metrics=snap))
+            await self._publisher.send_string(event.model_dump_json())
+            await asyncio.sleep(self._metrics_interval)
+
+    async def publish_response_event(self, response: HTTPResponse):
+        event = Event(type="response", data=response)
+        await self._publisher.send_string(event.model_dump_json())  # type: ignore[attr-defined]
+
+    def setup_middleware(self, app: ASGIAppWrapper):
+        app.add_middleware(CaptureMiddleware, self.publish_response_event)
+        app.add_middleware(MetricsMiddleware, self._metrics_collector)  # type: ignore
+
+    @contextlib.asynccontextmanager
+    async def lifespan(self, app: ASGIApp):
+        await self.on_startup()  # type: ignore
+        yield
+        await self.on_shutdown()  # type: ignore
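Each worker's `EventPublisher` connects a ZeroMQ PUB socket to `tcp://localhost:<event_publisher_port>` and streams JSON-serialized `Event` models (periodic status snapshots plus captured responses). A minimal consumer sketch, assuming the receiving side binds the matching socket on the default port 50000; this is illustrative and not part of mrok:

```python
import asyncio
import json

import zmq
import zmq.asyncio


async def consume(port: int = 50000) -> None:
    ctx = zmq.asyncio.Context()
    sub = ctx.socket(zmq.SUB)
    sub.bind(f"tcp://*:{port}")               # workers connect() their PUB sockets here
    sub.setsockopt_string(zmq.SUBSCRIBE, "")  # no topic filtering
    try:
        while True:
            raw = await sub.recv_string()
            event = json.loads(raw)           # e.g. {"type": "status", "data": {...}}
            print(event.get("type"))
    finally:
        sub.close()
        ctx.term()


if __name__ == "__main__":
    asyncio.run(consume())
```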
mrok/proxy/master.py
CHANGED
@@ -1,5 +1,3 @@
-import asyncio
-import contextlib
 import logging
 import os
 import signal
@@ -10,20 +8,14 @@ from pathlib import Path
 from uuid import uuid4

 import zmq
-import zmq.asyncio
-from uvicorn.importer import import_from_string
 from watchfiles import watch
 from watchfiles.filters import PythonFilter
 from watchfiles.run import CombinedProcess, start_process

 from mrok.conf import get_settings
 from mrok.logging import setup_logging
-from mrok.proxy.
-from mrok.proxy
-from mrok.proxy.metrics import WorkerMetricsCollector
-from mrok.proxy.middlewares import CaptureMiddleware, LifespanMiddleware, MetricsMiddleware
-from mrok.proxy.server import MrokServer
-from mrok.proxy.types import ASGIApp
+from mrok.proxy.worker import Worker
+from mrok.types.proxy import ASGIApp

 logger = logging.getLogger("mrok.agent")

@@ -62,76 +54,41 @@ def start_uvicorn_worker(
     worker_id: str,
     app: ASGIApp | str,
     identity_file: str,
+    events_enabled: bool,
     events_pub_port: int,
     metrics_interval: float = 5.0,
 ):
     import sys

     sys.path.insert(0, os.getcwd())
-    asgi_app = app if not isinstance(app, str) else import_from_string(app)

-
-
-
-    pub = ctx.socket(zmq.PUB)
-    pub.connect(f"tcp://localhost:{events_pub_port}")
-    metrics = WorkerMetricsCollector(worker_id)
-
-    task = None
-
-    async def status_sender():  # pragma: no cover
-        while True:
-            snap = await metrics.snapshot()
-            event = Event(type="status", data=Status(meta=identity.mrok, metrics=snap))
-            await pub.send_string(event.model_dump_json())
-            await asyncio.sleep(metrics_interval)
-
-    async def on_startup():  # pragma: no cover
-        nonlocal task
-        await asyncio.sleep(0)
-        task = asyncio.create_task(status_sender())
-
-    async def on_shutdown():  # pragma: no cover
-        await asyncio.sleep(0)
-        if task:
-            task.cancel()
-
-    async def on_response_complete(response: HTTPResponse):  # pragma: no cover
-        event = Event(type="response", data=response)
-        await pub.send_string(event.model_dump_json())
-
-    config = MrokBackendConfig(
-        LifespanMiddleware(
-            MetricsMiddleware(
-                CaptureMiddleware(
-                    asgi_app,
-                    on_response_complete,
-                ),
-                metrics,
-            ),
-            on_startup=on_startup,
-            on_shutdown=on_shutdown,
-        ),
+    worker = Worker(
+        worker_id,
+        app,
         identity_file,
+        events_enabled=events_enabled,
+        event_publisher_port=events_pub_port,
+        metrics_interval=metrics_interval,
     )
-
-    with contextlib.suppress(KeyboardInterrupt, asyncio.CancelledError):
-        server.run()
+    worker.run()


 class MasterBase(ABC):
     def __init__(
         self,
         identity_file: str,
-
-
-
-
+        *,
+        workers: int = 4,
+        reload: bool = False,
+        events_enabled: bool = True,
+        events_pub_port: int = 50000,
+        events_sub_port: int = 50001,
         metrics_interval: float = 5.0,
     ):
         self.identity_file = identity_file
         self.workers = workers
         self.reload = reload
+        self.events_enabled = events_enabled
         self.events_pub_port = events_pub_port
         self.events_sub_port = events_sub_port
         self.metrics_interval = metrics_interval
@@ -171,6 +128,7 @@ class MasterBase(ABC):
                 worker_id,
                 self.get_asgi_app(),
                 self.identity_file,
+                self.events_enabled,
                 self.events_pub_port,
                 self.metrics_interval,
             ),
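`start_uvicorn_worker` now takes `events_enabled` positionally between `identity_file` and `events_pub_port`, and simply delegates to the new `Worker` class. A hedged call-site sketch under those assumptions (the identity path and app import string are placeholders; normally `MasterBase` spawns this in a child process via watchfiles):

```python
from mrok.proxy.master import start_uvicorn_worker

start_uvicorn_worker(
    "worker-1",                 # worker_id
    "myapp.main:app",           # ASGI app or import string (placeholder)
    "/etc/mrok/identity.json",  # Ziti identity file (placeholder)
    True,                       # events_enabled
    50000,                      # events_pub_port
    metrics_interval=5.0,
)
```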
mrok/proxy/metrics.py
CHANGED
@@ -8,7 +8,7 @@ import time
 import psutil
 from hdrh.histogram import HdrHistogram

-from mrok.proxy.
+from mrok.proxy.models import (
     DataTransferMetrics,
     ProcessMetrics,
     RequestsMetrics,
@@ -49,7 +49,7 @@ async def get_process_metrics(interval: float = 0.1) -> ProcessMetrics:
     return await asyncio.to_thread(_collect_process_usage, interval)


-class
+class MetricsCollector:
     def __init__(self, worker_id: str, lowest=1, highest=60000, sigfigs=3):
         self.worker_id = worker_id
         self.total_requests = 0
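The collector rename (`WorkerMetricsCollector` to `MetricsCollector`) keeps the same role: per-worker counters that `MetricsMiddleware` feeds and that `EventPublisher` snapshots on a timer. A tiny hedged sketch, relying only on the constructor and async `snapshot()` visible elsewhere in this diff:

```python
import asyncio

from mrok.proxy.metrics import MetricsCollector


async def main() -> None:
    collector = MetricsCollector("worker-1")
    snap = await collector.snapshot()  # the metrics model that Status events embed
    print(snap)


asyncio.run(main())
```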
mrok/proxy/{middlewares.py → middleware.py}
CHANGED
@@ -2,19 +2,19 @@ import asyncio
 import logging
 import time

-from mrok.proxy.
-from mrok.proxy.
-from mrok.proxy
-from mrok.proxy.types import (
+from mrok.proxy.metrics import MetricsCollector
+from mrok.proxy.models import FixedSizeByteBuffer, HTTPHeaders, HTTPRequest, HTTPResponse
+from mrok.types.proxy import (
     ASGIApp,
     ASGIReceive,
     ASGISend,
-    LifespanCallback,
     Message,
     ResponseCompleteCallback,
     Scope,
 )
-
+
+MAX_REQUEST_BODY_BYTES = 2 * 1024 * 1024
+MAX_RESPONSE_BODY_BYTES = 5 * 1024 * 1024

 logger = logging.getLogger("mrok.proxy")

@@ -43,7 +43,7 @@ class CaptureMiddleware:
         state = {}

         req_buf = FixedSizeByteBuffer(MAX_REQUEST_BODY_BYTES)
-        capture_req_body =
+        capture_req_body = method.upper() not in ("GET", "HEAD", "OPTIONS", "TRACE")

         request = HTTPRequest(
             method=method,
@@ -69,9 +69,7 @@ class CaptureMiddleware:
                 resp_headers = HTTPHeaders.from_asgi(msg.get("headers", []))
                 state["resp_headers_raw"] = resp_headers

-
-
-            if state["capture_resp_body"] and msg["type"] == "http.response.body":
+            if msg["type"] == "http.response.body":
                 body = msg.get("body", b"")
                 resp_buf.write(body)

@@ -92,14 +90,14 @@ class CaptureMiddleware:
                 status=state["status"] or 0,
                 headers=state["resp_headers_raw"],
                 duration=duration,
-                body=resp_buf.getvalue()
-                body_truncated=resp_buf.overflow
+                body=resp_buf.getvalue(),
+                body_truncated=resp_buf.overflow,
             )
             asyncio.create_task(self._on_response_complete(response))


 class MetricsMiddleware:
-    def __init__(self, app: ASGIApp, metrics:
+    def __init__(self, app: ASGIApp, metrics: MetricsCollector):
         self.app = app
         self.metrics = metrics

@@ -133,32 +131,3 @@ class MetricsMiddleware:
             await self.app(scope, wrapped_receive, wrapped_send)
         finally:
             await self.metrics.on_request_end(start_time, status_code)
-
-
-class LifespanMiddleware:
-    def __init__(
-        self,
-        app,
-        on_startup: LifespanCallback | None = None,
-        on_shutdown: LifespanCallback | None = None,
-    ):
-        self.app = app
-        self.on_startup = on_startup
-        self.on_shutdown = on_shutdown
-
-    async def __call__(self, scope, receive, send):
-        if scope["type"] == "lifespan":
-            while True:
-                event = await receive()
-                if event["type"] == "lifespan.startup":
-                    if self.on_startup:  # pragma: no branch
-                        await self.on_startup()
-                    await send({"type": "lifespan.startup.complete"})
-
-                elif event["type"] == "lifespan.shutdown":
-                    if self.on_shutdown:  # pragma: no branch
-                        await self.on_shutdown()
-                    await send({"type": "lifespan.shutdown.complete"})
-                    break
-        else:
-            await self.app(scope, receive, send)
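`CaptureMiddleware` now skips request-body capture for bodyless methods and writes bodies into `FixedSizeByteBuffer`s capped by the new module constants (2 MiB for requests, 5 MiB for responses), reporting truncation via `.overflow`. A hedged sketch of the buffer surface, assuming only the `write`/`getvalue`/`overflow` members used in this diff:

```python
from mrok.proxy.models import FixedSizeByteBuffer

buf = FixedSizeByteBuffer(5)  # tiny cap for illustration; mrok uses 2 MiB / 5 MiB
buf.write(b"hello")
buf.write(b" world")          # goes past the cap
print(buf.getvalue())         # at most 5 bytes retained (exact policy not shown in the diff)
print(buf.overflow)           # True once data was dropped; feeds body_truncated
```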
mrok/proxy/{datastructures.py → models.py}
CHANGED
@@ -1,14 +1,41 @@
 from __future__ import annotations

+import base64
+import binascii
 import json
+from collections.abc import Mapping
 from pathlib import Path
-from typing import Literal
+from typing import Annotated, Any, Literal

 from pydantic import BaseModel, ConfigDict, Field, field_validator
+from pydantic.functional_serializers import PlainSerializer
+from pydantic.functional_validators import PlainValidator
 from pydantic_core import core_schema


-
+def serialize_b64(v: bytes) -> str:
+    return base64.b64encode(v).decode("ascii")
+
+
+def deserialize_b64(v):
+    if isinstance(v, bytes):
+        return v
+    if isinstance(v, str):
+        try:
+            return base64.b64decode(v, validate=True)
+        except binascii.Error as e:  # pragma: no branch
+            raise ValueError("Invalid base64 data") from e
+    raise TypeError("Expected bytes or base64 string")  # pragma: no cover
+
+
+Base64Bytes = Annotated[
+    bytes,
+    PlainValidator(deserialize_b64),
+    PlainSerializer(serialize_b64, return_type=str),
+]
+
+
+class X509Credentials(BaseModel):
     key: str
     cert: str
     ca: str
@@ -16,35 +43,34 @@ class ZitiId(BaseModel):
     @field_validator("key", "cert", "ca", mode="before")
     @classmethod
     def strip_pem_prefix(cls, value: str) -> str:
-        if isinstance(value, str) and value.startswith("pem:"):
+        if isinstance(value, str) and value.startswith("pem:"):  # pragma: no branch
            return value[4:]
-        return value
+        return value  # pragma: no cover


-class
+class ServiceMetadata(BaseModel):
     model_config = ConfigDict(extra="ignore")
-    identity: str
     extension: str
     instance: str
     domain: str | None = None
     tags: dict[str, str | bool | None] | None = None


-class
+class Identity(BaseModel):
     model_config = ConfigDict(extra="ignore")
     zt_api: str = Field(validation_alias="ztAPI")
-    id:
+    id: X509Credentials
+    mrok: ServiceMetadata
+    enable_ha: bool = Field(default=False, validation_alias="enableHa")
     zt_apis: str | None = Field(default=None, validation_alias="ztAPIs")
     config_types: str | None = Field(default=None, validation_alias="configTypes")
-    enable_ha: bool = Field(default=False, validation_alias="enableHa")
-    mrok: ZitiMrokMeta | None = None

     @staticmethod
-    def load_from_file(path: str | Path) ->
+    def load_from_file(path: str | Path) -> Identity:
         path = Path(path)
         with path.open("r", encoding="utf-8") as f:
             data = json.load(f)
-        return
+        return Identity.model_validate(data)


 class FixedSizeByteBuffer:
@@ -77,7 +103,7 @@ class FixedSizeByteBuffer:


 class HTTPHeaders(dict):
-    def __init__(self, initial=None):
+    def __init__(self, initial: Mapping[Any, Any] | None = None):
         super().__init__()
         if initial:
             for k, v in initial.items():
@@ -131,9 +157,9 @@ class HTTPRequest(BaseModel):
     method: str
     url: str
     headers: HTTPHeaders
-    query_string:
+    query_string: Base64Bytes | None = None
     start_time: float
-    body:
+    body: Base64Bytes | None = None
     body_truncated: bool | None = None


@@ -143,7 +169,7 @@ class HTTPResponse(BaseModel):
     status: int
     headers: HTTPHeaders
     duration: float
-    body:
+    body: Base64Bytes | None = None
     body_truncated: bool | None = None


@@ -183,7 +209,7 @@ class WorkerMetrics(BaseModel):

 class Status(BaseModel):
     type: Literal["status"] = "status"
-    meta:
+    meta: ServiceMetadata
     metrics: WorkerMetrics
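`Base64Bytes` lets request/response bodies stay raw bytes in memory while serializing to base64 strings in JSON, and it accepts either form on validation. A self-contained round-trip sketch that re-creates the same annotated-type pattern with plain pydantic (the `Payload` model is illustrative; the real type lives in `mrok.proxy.models`):

```python
import base64
import binascii
from typing import Annotated

from pydantic import BaseModel
from pydantic.functional_serializers import PlainSerializer
from pydantic.functional_validators import PlainValidator


def _to_b64(v: bytes) -> str:
    return base64.b64encode(v).decode("ascii")


def _from_b64(v):
    if isinstance(v, bytes):
        return v
    if isinstance(v, str):
        try:
            return base64.b64decode(v, validate=True)
        except binascii.Error as e:
            raise ValueError("Invalid base64 data") from e
    raise TypeError("Expected bytes or base64 string")


Base64Bytes = Annotated[bytes, PlainValidator(_from_b64), PlainSerializer(_to_b64, return_type=str)]


class Payload(BaseModel):
    body: Base64Bytes


p = Payload(body=b"\x00\x01binary")
as_json = p.model_dump_json()  # {"body":"AAFiaW5hcnk="}
assert Payload.model_validate_json(as_json).body == b"\x00\x01binary"
```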
mrok/proxy/{streams.py → stream.py}
CHANGED
@@ -1,8 +1,24 @@
 import asyncio
+import select
+import sys
+from typing import Any

 from httpcore import AsyncNetworkStream

-from mrok.proxy
+from mrok.types.proxy import ASGIReceive
+
+
+def is_readable(sock):  # pragma: no cover
+    # Stolen from
+    # https://github.com/python-trio/trio/blob/20ee2b1b7376db637435d80e266212a35837ddcc/trio/_socket.py#L471C1-L478C31
+
+    # use select.select on Windows, and select.poll everywhere else
+    if sys.platform == "win32":
+        rready, _, _ = select.select([sock], [], [], 0)
+        return bool(rready)
+    p = select.poll()
+    p.register(sock, select.POLLIN)
+    return bool(p.poll(0))


 class AIONetworkStream(AsyncNetworkStream):
@@ -21,6 +37,13 @@ class AIONetworkStream(AsyncNetworkStream):
         self._writer.close()
         await self._writer.wait_closed()

+    def get_extra_info(self, info: str) -> Any:
+        transport = self._writer.transport
+        if info == "is_readable":
+            sock = transport.get_extra_info("socket")
+            return is_readable(sock)
+        return transport.get_extra_info(info)
+

 class ASGIRequestBodyStream:
     def __init__(self, receive: ASGIReceive):
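`get_extra_info("is_readable")` is the probe httpcore uses to decide whether a pooled connection can be reused; the helper polls the raw socket without blocking. A standard-library-only demonstration of that check using a local socket pair (nothing mrok-specific beyond the imported helper; on some platforms the second check may need a brief delay before the data becomes visible):

```python
import socket

from mrok.proxy.stream import is_readable

a, b = socket.socketpair()
print(is_readable(a))  # False: nothing buffered yet
b.send(b"ping")
print(is_readable(a))  # True: unread data is waiting on the socket
a.close()
b.close()
```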
mrok/proxy/worker.py
ADDED
@@ -0,0 +1,64 @@
+import asyncio
+import contextlib
+import logging
+from pathlib import Path
+
+from uvicorn.importer import import_from_string
+
+from mrok.conf import get_settings
+from mrok.logging import setup_logging
+from mrok.proxy.asgi import ASGIAppWrapper
+from mrok.proxy.event_publisher import EventPublisher
+from mrok.proxy.models import Identity
+from mrok.proxy.ziticorn import BackendConfig, Server
+from mrok.types.proxy import ASGIApp
+
+logger = logging.getLogger("mrok.proxy")
+
+
+class Worker:
+    def __init__(
+        self,
+        worker_id: str,
+        app: ASGIApp | str,
+        identity_file: str | Path,
+        *,
+        events_enabled: bool = True,
+        event_publisher_port: int = 50000,
+        metrics_interval: float = 5.0,
+    ):
+        self._worker_id = worker_id
+        self._identity_file = identity_file
+        self._identity = Identity.load_from_file(self._identity_file)
+        self._app = app
+
+        self._events_enabled = events_enabled
+        self._event_publisher = (
+            EventPublisher(
+                worker_id=worker_id,
+                meta=self._identity.mrok,
+                event_publisher_port=event_publisher_port,
+                metrics_interval=metrics_interval,
+            )
+            if events_enabled
+            else None
+        )
+
+    def setup_app(self):
+        app = ASGIAppWrapper(
+            self._app if not isinstance(self._app, str) else import_from_string(self._app),
+            lifespan=self._event_publisher.lifespan if self._events_enabled else None,
+        )
+
+        if self._events_enabled:
+            self._event_publisher.setup_middleware(app)
+        return app
+
+    def run(self):
+        setup_logging(get_settings())
+        app = self.setup_app()
+
+        config = BackendConfig(app, self._identity_file)
+        server = Server(config)
+        with contextlib.suppress(KeyboardInterrupt, asyncio.CancelledError):
+            server.run()
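`Worker` encapsulates what `start_uvicorn_worker` used to inline: it loads the Ziti identity, wraps the ASGI app in `ASGIAppWrapper`, optionally attaches the `EventPublisher` middleware and lifespan, and runs the ziticorn `Server`. A hedged sketch of running one directly; the app import string and identity path are placeholders, and a real Ziti identity JSON (with an `mrok` metadata block) is required:

```python
from mrok.proxy.worker import Worker

worker = Worker(
    "worker-1",
    "myapp.main:app",           # any ASGI app or import string (placeholder)
    "/etc/mrok/identity.json",  # Ziti identity file (placeholder)
    events_enabled=False,       # skip the ZeroMQ event publisher for a standalone run
)
worker.run()  # blocks, serving the app over the Ziti service
```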