mrok 0.2.3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mrok/agent/devtools/__init__.py +0 -0
- mrok/agent/devtools/__main__.py +34 -0
- mrok/agent/devtools/inspector/__init__.py +0 -0
- mrok/agent/devtools/inspector/__main__.py +25 -0
- mrok/agent/devtools/inspector/app.py +556 -0
- mrok/agent/devtools/inspector/server.py +18 -0
- mrok/agent/sidecar/app.py +9 -10
- mrok/agent/sidecar/main.py +35 -16
- mrok/agent/ziticorn.py +27 -18
- mrok/cli/commands/__init__.py +2 -1
- mrok/cli/commands/admin/list/instances.py +24 -4
- mrok/cli/commands/admin/register/extensions.py +2 -2
- mrok/cli/commands/admin/register/instances.py +3 -3
- mrok/cli/commands/admin/unregister/extensions.py +2 -2
- mrok/cli/commands/admin/unregister/instances.py +2 -2
- mrok/cli/commands/agent/__init__.py +2 -0
- mrok/cli/commands/agent/dev/__init__.py +7 -0
- mrok/cli/commands/agent/dev/console.py +25 -0
- mrok/cli/commands/agent/dev/web.py +37 -0
- mrok/cli/commands/agent/run/asgi.py +35 -16
- mrok/cli/commands/agent/run/sidecar.py +29 -13
- mrok/cli/commands/agent/utils.py +5 -0
- mrok/cli/commands/controller/run.py +1 -5
- mrok/cli/commands/proxy/__init__.py +6 -0
- mrok/cli/commands/proxy/run.py +49 -0
- mrok/cli/utils.py +5 -0
- mrok/conf.py +6 -0
- mrok/controller/auth.py +2 -2
- mrok/controller/routes/extensions.py +9 -7
- mrok/datastructures.py +159 -0
- mrok/http/config.py +3 -6
- mrok/http/constants.py +22 -0
- mrok/http/forwarder.py +62 -23
- mrok/http/lifespan.py +29 -0
- mrok/http/middlewares.py +143 -0
- mrok/http/types.py +43 -0
- mrok/http/utils.py +90 -0
- mrok/logging.py +22 -0
- mrok/master.py +269 -0
- mrok/metrics.py +139 -0
- mrok/proxy/__init__.py +3 -0
- mrok/proxy/app.py +73 -0
- mrok/proxy/dataclasses.py +12 -0
- mrok/proxy/main.py +58 -0
- mrok/proxy/streams.py +124 -0
- mrok/proxy/types.py +12 -0
- mrok/proxy/ziti.py +173 -0
- mrok/ziti/identities.py +50 -20
- mrok/ziti/services.py +8 -8
- {mrok-0.2.3.dist-info → mrok-0.4.0.dist-info}/METADATA +7 -1
- mrok-0.4.0.dist-info/RECORD +92 -0
- {mrok-0.2.3.dist-info → mrok-0.4.0.dist-info}/WHEEL +1 -1
- mrok/http/master.py +0 -132
- mrok-0.2.3.dist-info/RECORD +0 -66
- {mrok-0.2.3.dist-info → mrok-0.4.0.dist-info}/entry_points.txt +0 -0
- {mrok-0.2.3.dist-info → mrok-0.4.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -14,8 +14,8 @@ from mrok.ziti.errors import (
|
|
|
14
14
|
ServiceAlreadyRegisteredError,
|
|
15
15
|
ServiceNotFoundError,
|
|
16
16
|
)
|
|
17
|
-
from mrok.ziti.identities import
|
|
18
|
-
from mrok.ziti.services import
|
|
17
|
+
from mrok.ziti.identities import register_identity, unregister_identity
|
|
18
|
+
from mrok.ziti.services import register_service, unregister_service
|
|
19
19
|
|
|
20
20
|
logger = logging.getLogger("mrok.controller")
|
|
21
21
|
|
|
@@ -83,7 +83,7 @@ async def create_extension(
|
|
|
83
83
|
],
|
|
84
84
|
):
|
|
85
85
|
try:
|
|
86
|
-
service = await
|
|
86
|
+
service = await register_service(settings, mgmt_api, data.extension.id, data.tags)
|
|
87
87
|
return ExtensionRead(
|
|
88
88
|
id=service["id"],
|
|
89
89
|
name=service["name"],
|
|
@@ -149,7 +149,7 @@ async def delete_instance_by_id_or_extension_id(
|
|
|
149
149
|
id_or_extension_id: str,
|
|
150
150
|
):
|
|
151
151
|
try:
|
|
152
|
-
await
|
|
152
|
+
await unregister_service(settings, mgmt_api, id_or_extension_id)
|
|
153
153
|
except ServiceNotFoundError:
|
|
154
154
|
raise HTTPException(
|
|
155
155
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
@@ -203,6 +203,7 @@ async def get_extensions(
|
|
|
203
203
|
tags=["Instances"],
|
|
204
204
|
)
|
|
205
205
|
async def create_extension_instances(
|
|
206
|
+
settings: AppSettings,
|
|
206
207
|
mgmt_api: ZitiManagementAPI,
|
|
207
208
|
client_api: ZitiClientAPI,
|
|
208
209
|
id_or_extension_id: str,
|
|
@@ -223,8 +224,8 @@ async def create_extension_instances(
|
|
|
223
224
|
],
|
|
224
225
|
):
|
|
225
226
|
service = await fetch_extension_or_404(mgmt_api, id_or_extension_id)
|
|
226
|
-
identity, identity_file = await
|
|
227
|
-
mgmt_api, client_api, service["name"], data.instance.id, data.tags
|
|
227
|
+
identity, identity_file = await register_identity(
|
|
228
|
+
settings, mgmt_api, client_api, service["name"], data.instance.id, data.tags
|
|
228
229
|
)
|
|
229
230
|
return InstanceRead(
|
|
230
231
|
id=identity["id"],
|
|
@@ -299,10 +300,11 @@ async def get_instance_by_id_or_instance_id(
|
|
|
299
300
|
tags=["Instances"],
|
|
300
301
|
)
|
|
301
302
|
async def delete_instance_by_id_or_instance_id(
|
|
303
|
+
settings: AppSettings,
|
|
302
304
|
mgmt_api: ZitiManagementAPI,
|
|
303
305
|
id_or_extension_id: str,
|
|
304
306
|
id_or_instance_id: str,
|
|
305
307
|
):
|
|
306
308
|
identity = await fetch_instance_or_404(mgmt_api, id_or_extension_id, id_or_instance_id)
|
|
307
309
|
instance_id, extension_id = identity["name"].split(".")
|
|
308
|
-
await
|
|
310
|
+
await unregister_identity(settings, mgmt_api, extension_id, instance_id)
|
mrok/datastructures.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
from pydantic_core import core_schema
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class FixedSizeByteBuffer:
    """Append-only byte buffer capped at ``max_size`` bytes.

    Bytes that do not fit are dropped and ``overflow`` is flipped to ``True`` so
    callers can tell that the stored content is incomplete.
    """

    def __init__(self, max_size: int):
        self._max_size = max_size
        self._buf = bytearray()
        self.overflow = False

    def write(self, data: bytes) -> None:
        """Append as much of *data* as still fits; flag ``overflow`` if any is dropped.

        Empty input is ignored and never sets ``overflow``.
        """
        if not data:
            return
        room = self._max_size - len(self._buf)
        fits = len(data) <= room
        if room > 0:
            self._buf.extend(data if fits else data[:room])
        if not fits:
            self.overflow = True

    def getvalue(self) -> bytes:
        """Return an immutable copy of the bytes stored so far."""
        snapshot = bytes(self._buf)
        return snapshot

    def clear(self) -> None:
        """Drop all stored bytes and reset the ``overflow`` flag."""
        self._buf.clear()
        self.overflow = False
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class HTTPHeaders(dict):
    """Case-insensitive ``str -> str`` mapping for HTTP headers.

    Keys are lower-cased and keys/values coerced to ``str`` on every write path
    (constructor, ``[]=``, ``update``, ``setdefault``). Lookups (``[]``, ``in``,
    ``get``, ``pop``, ``del``) lower-case the requested key, so
    ``h["Content-Type"]`` and ``h["content-type"]`` address the same entry.
    Being a ``dict``, repeated header names collapse to a single entry.
    """

    def __init__(self, initial=None):
        """Build from a mapping or an iterable of ``(name, value)`` pairs (optional)."""
        super().__init__()
        if initial:
            self.update(initial)

    def __getitem__(self, key: str) -> str:
        return super().__getitem__(key.lower())

    def __setitem__(self, key: str, value: str) -> None:
        super().__setitem__(str(key).lower(), str(value))

    def __delitem__(self, key: str) -> None:
        super().__delitem__(key.lower())

    def __contains__(self, key: object) -> bool:
        # dict.__contains__ does not go through __getitem__, so normalise here too.
        # Only str keys are ever stored, so anything else cannot be present.
        return isinstance(key, str) and super().__contains__(key.lower())

    def get(self, k, default=None):
        return super().get(k.lower(), default)

    def pop(self, key: str, *default):
        """Remove *key* (case-insensitively) and return its value, like ``dict.pop``."""
        return super().pop(key.lower(), *default)

    def setdefault(self, key: str, default=None):
        """Like ``dict.setdefault`` but with key/value normalised as in ``__setitem__``."""
        normalized = str(key).lower()
        if not super().__contains__(normalized):
            super().__setitem__(normalized, str(default))
        return super().__getitem__(normalized)

    def update(self, other=(), /, **kwargs) -> None:
        """Merge headers from a mapping / iterable of pairs and keyword args.

        ``dict.update`` would bypass ``__setitem__`` and store un-normalised keys,
        so every item is routed through ``self[...] = ...`` instead.
        """
        items = other.items() if hasattr(other, "items") else other
        for k, v in items:
            self[k] = v
        for k, v in kwargs.items():
            self[k] = v

    def copy(self) -> HTTPHeaders:
        """Return a shallow copy that is still an ``HTTPHeaders``."""
        return type(self)(self)

    @classmethod
    def from_asgi(cls, items: list[tuple[bytes, bytes]]) -> HTTPHeaders:
        """Build from ASGI raw headers (``(name, value)`` byte pairs, latin-1 decoded)."""
        d = {k.decode("latin-1"): v.decode("latin-1") for k, v in items}
        return cls(d)

    @classmethod
    def __get_pydantic_core_schema__(cls, source, handler):
        """Provide a pydantic-core schema so Pydantic treats this as a mapping of str->str.

        We generate the schema for `dict[str, str]` using the provided handler and wrap
        it with a validator that converts the validated dict into `HTTPHeaders`.
        """
        # handler may be a callable or an object with `generate_schema`; handle both
        try:
            dict_schema = handler.generate_schema(dict[str, str])
        except AttributeError:
            dict_schema = handler(dict[str, str])

        def _wrap(v, validator):
            # `validator` will validate input according to `dict_schema` and return a dict
            validated = validator(input_value=v)
            if isinstance(validated, HTTPHeaders):
                return validated
            return cls(validated)

        return core_schema.no_info_wrap_validator_function(
            _wrap,
            dict_schema,
            serialization=core_schema.plain_serializer_function_ser_schema(lambda v: dict(v)),
        )
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class HTTPRequest(BaseModel):
    """Snapshot of an HTTP request as captured for inspection.

    ``body`` / ``body_truncated`` default to ``None`` and are only filled in when
    the capturer recorded the body.
    """

    method: str
    # NOTE(review): CaptureMiddleware stores scope["path"] here, i.e. the path without the query string.
    url: str
    headers: HTTPHeaders
    # Raw query string bytes, kept separately from ``url``.
    query_string: bytes
    # CaptureMiddleware sets this from time.time() (epoch seconds).
    start_time: float
    body: bytes | None = None
    # True when the captured body was cut at the capture size limit (see FixedSizeByteBuffer.overflow).
    body_truncated: bool | None = None
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class HTTPResponse(BaseModel):
    """Captured HTTP response, embedding the ``HTTPRequest`` that produced it."""

    # Constant tag; Event.data uses it as the union discriminator.
    type: Literal["response"] = "response"
    request: HTTPRequest
    status: int
    headers: HTTPHeaders
    # Seconds: CaptureMiddleware stores the difference between two time readings.
    duration: float
    body: bytes | None = None
    # True when the captured body was cut at the capture size limit.
    body_truncated: bool | None = None
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class ProcessMetrics(BaseModel):
    """Process resource usage reported inside ``WorkerMetrics``."""

    # NOTE(review): units not visible here (presumably percentages) — confirm against mrok.metrics.
    cpu: float
    mem: float
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class DataTransferMetrics(BaseModel):
    """Byte counters for traffic handled by a worker."""

    # NOTE(review): presumably request-body bytes in / response-body bytes out — confirm in mrok.metrics.
    bytes_in: int
    bytes_out: int
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class RequestsMetrics(BaseModel):
    """Request counters reported inside ``WorkerMetrics``."""

    # Requests per second, stored as an integer.
    rps: int
    total: int
    # NOTE(review): the success/failure criterion (e.g. status-code ranges) is defined in mrok.metrics — confirm.
    successful: int
    failed: int
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class ResponseTimeMetrics(BaseModel):
    """Response-time summary statistics (average, extremes, percentiles)."""

    avg: float
    # NOTE(review): the integer fields suggest a unit such as milliseconds — TODO confirm in mrok.metrics.
    # ``min``/``max`` shadow builtins only as field names; they are part of the serialized schema.
    min: int
    max: int
    p50: int
    p90: int
    p99: int
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class WorkerMetrics(BaseModel):
    """All metrics reported for one worker, grouped by category."""

    worker_id: str
    data_transfer: DataTransferMetrics
    requests: RequestsMetrics
    response_time: ResponseTimeMetrics
    process: ProcessMetrics
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class Meta(BaseModel):
    """Identifying information attached to a ``Status`` message."""

    # NOTE(review): presumably the Ziti identity name plus its extension/instance ids — confirm with callers.
    identity: str
    extension: str
    instance: str
    domain: str
    # Optional free-form string tags.
    tags: dict[str, str] | None = None
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class Status(BaseModel):
    """Status message pairing ``Meta`` identification with ``WorkerMetrics``."""

    # Constant tag; Event.data uses it as the union discriminator.
    type: Literal["status"] = "status"
    meta: Meta
    metrics: WorkerMetrics
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class Event(BaseModel):
    """Envelope carrying either a ``Status`` or an ``HTTPResponse`` payload.

    ``data`` is a discriminated union keyed on the payload's own ``type`` field.
    NOTE(review): nothing here checks that ``Event.type`` agrees with ``data.type``.
    """

    type: Literal["status", "response"]
    data: Status | HTTPResponse = Field(discriminator="type")
|
mrok/http/config.py
CHANGED
|
@@ -8,21 +8,18 @@ from typing import Any
|
|
|
8
8
|
import openziti
|
|
9
9
|
from uvicorn import config
|
|
10
10
|
|
|
11
|
-
from mrok.conf import get_settings
|
|
12
11
|
from mrok.http.protocol import MrokHttpToolsProtocol
|
|
13
|
-
from mrok.
|
|
12
|
+
from mrok.http.types import ASGIApp
|
|
14
13
|
|
|
15
14
|
logger = logging.getLogger("mrok.proxy")
|
|
16
15
|
|
|
17
16
|
config.LIFESPAN["auto"] = "mrok.http.lifespan:MrokLifespan"
|
|
18
17
|
|
|
19
|
-
ASGIApplication = config.ASGIApplication
|
|
20
|
-
|
|
21
18
|
|
|
22
19
|
class MrokBackendConfig(config.Config):
|
|
23
20
|
def __init__(
|
|
24
21
|
self,
|
|
25
|
-
app:
|
|
22
|
+
app: ASGIApp | Callable[..., Any] | str,
|
|
26
23
|
identity_file: str | Path,
|
|
27
24
|
ziti_load_timeout_ms: int = 5000,
|
|
28
25
|
backlog: int = 2048,
|
|
@@ -62,4 +59,4 @@ class MrokBackendConfig(config.Config):
|
|
|
62
59
|
return sock
|
|
63
60
|
|
|
64
61
|
def configure_logging(self) -> None:
|
|
65
|
-
|
|
62
|
+
return
|
mrok/http/constants.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# Upper bounds, in bytes, on how much of a request / response body is kept
# when bodies are captured: 2 MiB and 5 MiB respectively.
MAX_REQUEST_BODY_BYTES = 2 << 20
MAX_RESPONSE_BODY_BYTES = 5 << 20

# Content-type classification tables (consumed by helpers outside this module).
# Exact media types regarded as binary.
BINARY_CONTENT_TYPES = {"application/pdf", "application/octet-stream"}

# Media-type family prefixes regarded as binary.
BINARY_PREFIXES = tuple(f"{family}/" for family in ("image", "video", "audio"))

# Exact media types regarded as textual even though they are not ``text/*``.
TEXTUAL_CONTENT_TYPES = {
    "application/x-www-form-urlencoded",
    "application/javascript",
    "application/xml",
    "application/json",
}

# Media-type family prefixes regarded as textual.
TEXTUAL_PREFIXES = ("text/",)
|
mrok/http/forwarder.py
CHANGED
|
@@ -1,14 +1,10 @@
|
|
|
1
1
|
import abc
|
|
2
2
|
import asyncio
|
|
3
3
|
import logging
|
|
4
|
-
from collections.abc import Awaitable, Callable
|
|
5
|
-
from typing import Any
|
|
6
4
|
|
|
7
|
-
|
|
5
|
+
from mrok.http.types import ASGIReceive, ASGISend, Scope, StreamReader, StreamWriter
|
|
8
6
|
|
|
9
|
-
|
|
10
|
-
ASGIReceive = Callable[[], Awaitable[dict[str, Any]]]
|
|
11
|
-
ASGISend = Callable[[dict[str, Any]], Awaitable[None]]
|
|
7
|
+
logger = logging.getLogger("mrok.proxy")
|
|
12
8
|
|
|
13
9
|
|
|
14
10
|
class BackendNotFoundError(Exception):
|
|
@@ -24,24 +20,71 @@ class ForwardAppBase(abc.ABC):
|
|
|
24
20
|
and streaming logic (requests and responses).
|
|
25
21
|
"""
|
|
26
22
|
|
|
27
|
-
def __init__(
|
|
28
|
-
|
|
23
|
+
def __init__(
|
|
24
|
+
self,
|
|
25
|
+
read_chunk_size: int = 65536,
|
|
26
|
+
lifespan_timeout: float = 10.0,
|
|
27
|
+
) -> None:
|
|
29
28
|
self._read_chunk_size: int = int(read_chunk_size)
|
|
29
|
+
self._lifespan_timeout = lifespan_timeout
|
|
30
|
+
|
|
31
|
+
async def handle_lifespan(self, receive: ASGIReceive, send: ASGISend) -> None:
    """Serve the ASGI ``lifespan`` protocol using the ``startup``/``shutdown`` hooks.

    Each hook is bounded by ``self._lifespan_timeout`` seconds via
    ``asyncio.wait_for``. Startup is answered with ``lifespan.startup.complete``
    or ``lifespan.startup.failed``; after a failed startup the loop keeps
    waiting for further events. The shutdown event is answered with
    ``lifespan.shutdown.complete`` or ``lifespan.shutdown.failed`` and always
    ends the loop. Other event types are ignored.
    """
    while True:
        event = await receive()
        etype = event.get("type")

        if etype == "lifespan.startup":
            try:
                await asyncio.wait_for(self.startup(), self._lifespan_timeout)
            except asyncio.TimeoutError:
                # Before Python 3.11, wait_for raises asyncio.TimeoutError, which is
                # NOT the builtin TimeoutError; the asyncio name works on every
                # version (it is an alias of the builtin from 3.11 on).
                logger.exception("Lifespan startup timed out")
                await send({"type": "lifespan.startup.failed", "message": "startup timeout"})
                continue
            except Exception as e:
                logger.exception("Exception during lifespan startup")
                # str() of an argument-less exception is empty; fall back to its class name.
                await send(
                    {"type": "lifespan.startup.failed", "message": str(e) or type(e).__name__}
                )
                continue
            await send({"type": "lifespan.startup.complete"})

        elif etype == "lifespan.shutdown":
            try:
                await asyncio.wait_for(self.shutdown(), self._lifespan_timeout)
            except asyncio.TimeoutError:
                logger.exception("Lifespan shutdown timed out")
                await send({"type": "lifespan.shutdown.failed", "message": "shutdown timeout"})
                return
            except Exception as exc:
                logger.exception("Exception during lifespan shutdown")
                await send(
                    {"type": "lifespan.shutdown.failed", "message": str(exc) or type(exc).__name__}
                )
                return
            await send({"type": "lifespan.shutdown.complete"})
            return
|
|
30
62
|
|
|
31
63
|
@abc.abstractmethod
|
|
32
64
|
async def select_backend(
|
|
33
65
|
self,
|
|
34
66
|
scope: Scope,
|
|
35
67
|
headers: dict[str, str],
|
|
36
|
-
) -> tuple[
|
|
68
|
+
) -> tuple[StreamReader, StreamWriter] | tuple[None, None]:
|
|
37
69
|
"""Return (reader, writer) connected to the target backend."""
|
|
38
70
|
|
|
71
|
+
async def startup(self):
    """Hook awaited by ``handle_lifespan`` on ``lifespan.startup``; a no-op here."""


async def shutdown(self):
    """Hook awaited by ``handle_lifespan`` on ``lifespan.shutdown``; a no-op here."""
|
|
76
|
+
|
|
39
77
|
async def __call__(self, scope: Scope, receive: ASGIReceive, send: ASGISend) -> None:
|
|
40
78
|
"""ASGI callable entry point.
|
|
41
79
|
|
|
42
80
|
Delegates to smaller helper methods for readability. Subclasses only
|
|
43
81
|
need to implement backend selection.
|
|
44
82
|
"""
|
|
83
|
+
scope_type = scope.get("type")
|
|
84
|
+
if scope_type == "lifespan":
|
|
85
|
+
await self.handle_lifespan(receive, send)
|
|
86
|
+
return
|
|
87
|
+
|
|
45
88
|
if scope.get("type") != "http":
|
|
46
89
|
await send({"type": "http.response.start", "status": 500, "headers": []})
|
|
47
90
|
await send({"type": "http.response.body", "body": b"Unsupported"})
|
|
@@ -116,7 +159,7 @@ class ForwardAppBase(abc.ABC):
|
|
|
116
159
|
|
|
117
160
|
async def write_request_line_and_headers(
|
|
118
161
|
self,
|
|
119
|
-
writer:
|
|
162
|
+
writer: StreamWriter,
|
|
120
163
|
method: str,
|
|
121
164
|
path_qs: str,
|
|
122
165
|
headers: list[tuple[bytes, bytes]],
|
|
@@ -130,7 +173,7 @@ class ForwardAppBase(abc.ABC):
|
|
|
130
173
|
await writer.drain()
|
|
131
174
|
|
|
132
175
|
async def stream_request_body(
|
|
133
|
-
self, receive: ASGIReceive, writer:
|
|
176
|
+
self, receive: ASGIReceive, writer: StreamWriter, use_chunked: bool
|
|
134
177
|
) -> None:
|
|
135
178
|
if use_chunked:
|
|
136
179
|
await self.stream_request_chunked(receive, writer)
|
|
@@ -138,9 +181,7 @@ class ForwardAppBase(abc.ABC):
|
|
|
138
181
|
|
|
139
182
|
await self.stream_request_until_end(receive, writer)
|
|
140
183
|
|
|
141
|
-
async def stream_request_chunked(
|
|
142
|
-
self, receive: ASGIReceive, writer: asyncio.StreamWriter
|
|
143
|
-
) -> None:
|
|
184
|
+
async def stream_request_chunked(self, receive: ASGIReceive, writer: StreamWriter) -> None:
|
|
144
185
|
"""Send request body to backend using HTTP/1.1 chunked encoding."""
|
|
145
186
|
while True:
|
|
146
187
|
event = await receive()
|
|
@@ -160,9 +201,7 @@ class ForwardAppBase(abc.ABC):
|
|
|
160
201
|
writer.write(b"0\r\n\r\n")
|
|
161
202
|
await writer.drain()
|
|
162
203
|
|
|
163
|
-
async def stream_request_until_end(
|
|
164
|
-
self, receive: ASGIReceive, writer: asyncio.StreamWriter
|
|
165
|
-
) -> None:
|
|
204
|
+
async def stream_request_until_end(self, receive: ASGIReceive, writer: StreamWriter) -> None:
|
|
166
205
|
"""Send request body to backend when content length/transfer-encoding
|
|
167
206
|
already provided (no chunking).
|
|
168
207
|
"""
|
|
@@ -180,7 +219,7 @@ class ForwardAppBase(abc.ABC):
|
|
|
180
219
|
return
|
|
181
220
|
|
|
182
221
|
async def read_status_and_headers(
|
|
183
|
-
self, reader:
|
|
222
|
+
self, reader: StreamReader, first_line: bytes
|
|
184
223
|
) -> tuple[int, list[tuple[bytes, bytes]], dict[bytes, bytes]]:
|
|
185
224
|
parts = first_line.decode(errors="ignore").split(" ", 2)
|
|
186
225
|
status = int(parts[1]) if len(parts) >= 2 and parts[1].isdigit() else 502
|
|
@@ -217,14 +256,14 @@ class ForwardAppBase(abc.ABC):
|
|
|
217
256
|
except Exception:
|
|
218
257
|
return None
|
|
219
258
|
|
|
220
|
-
async def drain_trailers(self, reader:
|
|
259
|
+
async def drain_trailers(self, reader: StreamReader) -> None:
|
|
221
260
|
"""Consume trailer header lines until an empty line is encountered."""
|
|
222
261
|
while True:
|
|
223
262
|
trailer = await reader.readline()
|
|
224
263
|
if trailer in (b"\r\n", b"\n", b""):
|
|
225
264
|
break
|
|
226
265
|
|
|
227
|
-
async def stream_response_chunked(self, reader:
|
|
266
|
+
async def stream_response_chunked(self, reader: StreamReader, send: ASGISend) -> None:
|
|
228
267
|
"""Read chunked-encoded response from reader, decode and forward to ASGI send."""
|
|
229
268
|
while True:
|
|
230
269
|
size_line = await reader.readline()
|
|
@@ -253,7 +292,7 @@ class ForwardAppBase(abc.ABC):
|
|
|
253
292
|
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
|
254
293
|
|
|
255
294
|
async def stream_response_with_content_length(
|
|
256
|
-
self, reader:
|
|
295
|
+
self, reader: StreamReader, send: ASGISend, content_length: int
|
|
257
296
|
) -> None:
|
|
258
297
|
"""Read exactly content_length bytes and forward to ASGI send events."""
|
|
259
298
|
remaining = content_length
|
|
@@ -272,7 +311,7 @@ class ForwardAppBase(abc.ABC):
|
|
|
272
311
|
if not sent_final:
|
|
273
312
|
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
|
274
313
|
|
|
275
|
-
async def stream_response_until_eof(self, reader:
|
|
314
|
+
async def stream_response_until_eof(self, reader: StreamReader, send: ASGISend) -> None:
|
|
276
315
|
"""Read from reader until EOF and forward chunks to ASGI send events."""
|
|
277
316
|
while True:
|
|
278
317
|
chunk = await reader.read(self._read_chunk_size)
|
|
@@ -282,7 +321,7 @@ class ForwardAppBase(abc.ABC):
|
|
|
282
321
|
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
|
283
322
|
|
|
284
323
|
async def stream_response_body(
|
|
285
|
-
self, reader:
|
|
324
|
+
self, reader: StreamReader, send: ASGISend, raw_headers: dict[bytes, bytes]
|
|
286
325
|
) -> None:
|
|
287
326
|
te = raw_headers.get(b"transfer-encoding", b"").lower()
|
|
288
327
|
cl = raw_headers.get(b"content-length")
|
mrok/http/lifespan.py
CHANGED
|
@@ -1,10 +1,39 @@
|
|
|
1
1
|
import logging
|
|
2
|
+
from collections.abc import Awaitable, Callable
|
|
2
3
|
|
|
3
4
|
from uvicorn.config import Config
|
|
4
5
|
from uvicorn.lifespan.on import LifespanOn
|
|
5
6
|
|
|
7
|
+
# Zero-argument async hook, used for LifespanWrapper's on_startup/on_shutdown.
AsyncCallback = Callable[[], Awaitable[None]]
|
|
8
|
+
|
|
6
9
|
|
|
7
10
|
class MrokLifespan(LifespanOn):
    """uvicorn ``LifespanOn`` that uses the ``mrok.proxy`` logger.

    mrok.http.config registers this class as uvicorn's ``"auto"`` lifespan
    implementation (``config.LIFESPAN["auto"]``).
    """

    def __init__(self, config: Config) -> None:
        super().__init__(config)
        # Overwrite ``self.logger`` after the base initialiser so lifespan messages
        # go to mrok's logger (presumably replacing the one LifespanOn set up).
        self.logger = logging.getLogger("mrok.proxy")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class LifespanWrapper:
    """ASGI wrapper that answers ``lifespan`` events itself and forwards everything else.

    ``on_startup`` / ``on_shutdown`` are optional zero-argument async hooks
    awaited on ``lifespan.startup`` / ``lifespan.shutdown``. Lifespan events are
    NOT forwarded to the wrapped ``app``; all other scopes are passed through
    unchanged.

    A hook that raises is reported to the server as ``lifespan.startup.failed``
    / ``lifespan.shutdown.failed`` (mirroring ``ForwardAppBase.handle_lifespan``)
    instead of escaping. Some servers, e.g. uvicorn with ``lifespan="auto"``,
    treat an escaping exception as "lifespan unsupported" and keep serving.
    """

    def __init__(
        self, app, on_startup: AsyncCallback | None = None, on_shutdown: AsyncCallback | None = None
    ):
        self.app = app
        self.on_startup = on_startup
        self.on_shutdown = on_shutdown

    async def __call__(self, scope, receive, send):
        if scope["type"] != "lifespan":
            await self.app(scope, receive, send)
            return

        while True:
            event = await receive()
            if event["type"] == "lifespan.startup":
                try:
                    if self.on_startup:
                        await self.on_startup()
                except Exception as exc:
                    logging.getLogger("mrok.proxy").exception("Exception during lifespan startup")
                    await send(
                        {
                            "type": "lifespan.startup.failed",
                            "message": str(exc) or type(exc).__name__,
                        }
                    )
                    continue
                await send({"type": "lifespan.startup.complete"})

            elif event["type"] == "lifespan.shutdown":
                try:
                    if self.on_shutdown:
                        await self.on_shutdown()
                except Exception as exc:
                    logging.getLogger("mrok.proxy").exception("Exception during lifespan shutdown")
                    await send(
                        {
                            "type": "lifespan.shutdown.failed",
                            "message": str(exc) or type(exc).__name__,
                        }
                    )
                    break
                await send({"type": "lifespan.shutdown.complete"})
                break
|
mrok/http/middlewares.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import inspect
|
|
3
|
+
import logging
|
|
4
|
+
import time
|
|
5
|
+
from collections.abc import Callable, Coroutine
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from mrok.datastructures import FixedSizeByteBuffer, HTTPHeaders, HTTPRequest, HTTPResponse
|
|
9
|
+
from mrok.http.constants import MAX_REQUEST_BODY_BYTES, MAX_RESPONSE_BODY_BYTES
|
|
10
|
+
from mrok.http.types import ASGIApp, ASGIReceive, ASGISend, Message, Scope
|
|
11
|
+
from mrok.http.utils import must_capture_request, must_capture_response
|
|
12
|
+
from mrok.metrics import WorkerMetricsCollector
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger("mrok.proxy")

# Callback that receives each finished HTTPResponse. CaptureMiddleware.handle_callback
# awaits it when it is a coroutine function and otherwise runs it in the default executor.
# NOTE(review): mrok.http.types also defines a ResponseCompleteCallback alias (returning
# ``Awaitable | None``); consider reusing that one instead of keeping two definitions.
ResponseCompleteCallback = Callable[[HTTPResponse], Coroutine[Any, Any, None] | None]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class CaptureMiddleware:
    """ASGI middleware that records each HTTP exchange and reports it to a callback.

    For ``http`` scopes, ``receive``/``send`` are wrapped to copy request and
    response bodies into bounded buffers. Copying happens only when
    ``must_capture_request`` / ``must_capture_response`` allow it. After the
    wrapped app returns, an ``HTTPResponse`` (embedding the ``HTTPRequest``) is
    built and handed to ``on_response_complete``. Other scope types pass through
    untouched. If the wrapped app raises, the exception propagates and no record
    is produced.
    """

    def __init__(
        self,
        app: ASGIApp,
        on_response_complete: ResponseCompleteCallback,
    ):
        self.app = app
        self._on_response_complete = on_response_complete

    async def __call__(self, scope: Scope, receive: ASGIReceive, send: ASGISend):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Wall-clock timestamp for the record; monotonic clock for the duration so
        # system clock adjustments cannot skew it.
        start_time = time.time()
        start_mono = time.monotonic()
        method = scope["method"]
        path = scope["path"]
        query_string = scope.get("query_string", b"")
        req_headers = HTTPHeaders.from_asgi(scope.get("headers", []))

        req_buf = FixedSizeByteBuffer(MAX_REQUEST_BODY_BYTES)
        capture_req_body = must_capture_request(method, req_headers)

        request = HTTPRequest(
            method=method,
            url=path,
            headers=req_headers,
            query_string=query_string,
            start_time=start_time,
        )

        # Response capture. Defaults keep finalisation working even if the app
        # returns without ever sending http.response.start.
        resp_buf = FixedSizeByteBuffer(MAX_RESPONSE_BODY_BYTES)
        state = {"status": 0, "resp_headers": HTTPHeaders(), "capture_resp_body": False}

        async def receive_wrapper() -> Message:
            msg = await receive()
            if capture_req_body and msg["type"] == "http.request":
                req_buf.write(msg.get("body", b""))
            return msg

        async def send_wrapper(msg: Message):
            if msg["type"] == "http.response.start":
                state["status"] = msg["status"]
                resp_headers = HTTPHeaders.from_asgi(msg.get("headers", []))
                state["resp_headers"] = resp_headers
                state["capture_resp_body"] = must_capture_response(resp_headers)
            elif state["capture_resp_body"] and msg["type"] == "http.response.body":
                resp_buf.write(msg.get("body", b""))
            await send(msg)

        await self.app(scope, receive_wrapper, send_wrapper)

        # Finalise request
        request.body = req_buf.getvalue() if capture_req_body else None
        request.body_truncated = req_buf.overflow if capture_req_body else None

        # Finalise response
        capture_resp_body = state["capture_resp_body"]
        response = HTTPResponse(
            request=request,
            status=state["status"] or 0,
            headers=state["resp_headers"],
            duration=time.monotonic() - start_mono,
            body=resp_buf.getvalue() if capture_resp_body else None,
            body_truncated=resp_buf.overflow if capture_resp_body else None,
        )
        # Awaiting a freshly created task is equivalent to awaiting the coroutine directly.
        await self.handle_callback(response)

    async def handle_callback(self, response: HTTPResponse):
        """Deliver *response* to the completion callback; errors are logged, not raised.

        Coroutine functions are awaited. Any other callable runs in the default
        executor, and if it returns an awaitable that is awaited too. This covers
        objects with an async ``__call__``, which ``iscoroutinefunction`` does not
        recognise.
        """
        callback = self._on_response_complete
        try:
            if inspect.iscoroutinefunction(callback):
                await callback(response)
            else:
                result = await asyncio.get_running_loop().run_in_executor(None, callback, response)
                if inspect.isawaitable(result):
                    await result
        except Exception:
            logger.exception("Error invoking callback")
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class MetricsMiddleware:
    """ASGI middleware that feeds per-request traffic figures to a metrics collector.

    Non-HTTP scopes are forwarded untouched, and the app's return value is
    returned. For an HTTP request the collector is notified of, in order:
    - the request start;
    - every non-empty request body chunk;
    - the response start, with its status;
    - every response body chunk, by byte length;
    - finally, the request end.

    The end hook runs in a ``finally`` block, so it fires even if the app raises.
    The status reported then is the one from ``http.response.start``, or 500 if
    no response was started.
    """

    def __init__(self, app, metrics: WorkerMetricsCollector):
        self.app = app
        self.metrics = metrics

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        started = await self.metrics.on_request_start(scope)
        # Mutable cell shared with the send wrapper; 500 stands in for
        # "the app failed before starting a response".
        observed = {"status": 500}

        async def counting_receive():
            message = await receive()
            if message["type"] == "http.request":
                chunk = message.get("body")
                if chunk:
                    await self.metrics.on_request_body(len(chunk))
            return message

        async def counting_send(message):
            kind = message["type"]
            if kind == "http.response.start":
                observed["status"] = message["status"]
                await self.metrics.on_response_start(observed["status"])
            elif kind == "http.response.body":
                await self.metrics.on_response_chunk(len(message.get("body", b"")))
            return await send(message)

        try:
            await self.app(scope, counting_receive, counting_send)
        finally:
            await self.metrics.on_request_end(started, observed["status"])
|
mrok/http/types.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from collections.abc import Awaitable, Callable, MutableMapping
|
|
5
|
+
from typing import Any, Protocol
|
|
6
|
+
|
|
7
|
+
from mrok.datastructures import HTTPRequest, HTTPResponse
|
|
8
|
+
|
|
9
|
+
# ASGI connection scope and event message: mutable string-keyed mappings.
Scope = MutableMapping[str, Any]
Message = MutableMapping[str, Any]

# ASGI callables: ``receive()`` awaits the next message, ``send(msg)`` emits one,
# and an application is awaited as ``app(scope, receive, send)``.
ASGIReceive = Callable[[], Awaitable[Message]]
ASGISend = Callable[[Message], Awaitable[None]]
ASGIApp = Callable[[Scope, ASGIReceive, ASGISend], Awaitable[None]]
# Completion hooks may be plain functions (returning None) or return an awaitable.
RequestCompleteCallback = Callable[[HTTPRequest], Awaitable | None]
ResponseCompleteCallback = Callable[[HTTPResponse], Awaitable | None]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class StreamReaderWrapper(Protocol):
    """Structural type for objects that read like ``asyncio.StreamReader``.

    Declares the subset of the reader API used by the proxy code, plus an
    ``underlying`` accessor (presumably the wrapped ``asyncio.StreamReader``).
    """

    async def read(self, n: int = -1) -> bytes: ...
    async def readexactly(self, n: int) -> bytes: ...
    async def readline(self) -> bytes: ...
    def at_eof(self) -> bool: ...

    @property
    def underlying(self) -> asyncio.StreamReader: ...
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class StreamWriterWrapper(Protocol):
    """Structural type for objects that write like ``asyncio.StreamWriter``.

    Declares the write/drain/close subset and ``transport``, plus an
    ``underlying`` accessor (presumably the wrapped ``asyncio.StreamWriter``).
    """

    def write(self, data: bytes) -> None: ...
    async def drain(self) -> None: ...
    def close(self) -> None: ...
    async def wait_closed(self) -> None: ...

    @property
    def transport(self): ...

    @property
    def underlying(self) -> asyncio.StreamWriter: ...
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Stream types accepted by the forwarding code: either a plain asyncio stream
# or any object satisfying the corresponding wrapper protocol.
StreamReader = StreamReaderWrapper | asyncio.StreamReader
StreamWriter = StreamWriterWrapper | asyncio.StreamWriter
|