mrok 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mrok/agent/devtools/inspector/__main__.py +3 -23
- mrok/agent/devtools/inspector/app.py +407 -112
- mrok/agent/devtools/inspector/utils.py +149 -0
- mrok/cli/commands/admin/bootstrap.py +2 -2
- mrok/cli/commands/admin/register/extensions.py +7 -9
- mrok/cli/commands/admin/register/instances.py +13 -16
- mrok/cli/commands/admin/unregister/extensions.py +7 -11
- mrok/cli/commands/admin/unregister/instances.py +12 -12
- mrok/cli/commands/agent/run/asgi.py +1 -1
- mrok/cli/commands/frontend/run.py +1 -1
- mrok/cli/main.py +17 -1
- mrok/cli/utils.py +26 -0
- mrok/conf.py +15 -7
- mrok/constants.py +21 -0
- mrok/controller/app.py +12 -10
- mrok/controller/auth/__init__.py +11 -0
- mrok/controller/auth/backends.py +60 -0
- mrok/controller/auth/base.py +38 -0
- mrok/controller/auth/manager.py +31 -0
- mrok/controller/auth/registry.py +17 -0
- mrok/frontend/app.py +94 -26
- mrok/frontend/main.py +8 -5
- mrok/frontend/middleware.py +35 -0
- mrok/frontend/utils.py +83 -0
- mrok/logging.py +24 -22
- mrok/proxy/app.py +13 -5
- mrok/proxy/middleware.py +7 -8
- mrok/proxy/models.py +36 -10
- mrok/proxy/ziticorn.py +8 -17
- mrok/ziti/api.py +4 -4
- mrok/ziti/bootstrap.py +0 -5
- mrok/ziti/identities.py +11 -10
- mrok/ziti/services.py +6 -6
- {mrok-0.6.0.dist-info → mrok-0.8.0.dist-info}/METADATA +9 -3
- {mrok-0.6.0.dist-info → mrok-0.8.0.dist-info}/RECORD +38 -35
- mrok/agent/devtools/__main__.py +0 -34
- mrok/cli/commands/agent/utils.py +0 -5
- mrok/controller/auth.py +0 -87
- mrok/proxy/constants.py +0 -22
- mrok/proxy/utils.py +0 -90
- {mrok-0.6.0.dist-info → mrok-0.8.0.dist-info}/WHEEL +0 -0
- {mrok-0.6.0.dist-info → mrok-0.8.0.dist-info}/entry_points.txt +0 -0
- {mrok-0.6.0.dist-info → mrok-0.8.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from dynaconf.utils.boxing import DynaBox
|
|
2
|
+
from fastapi import Request
|
|
3
|
+
|
|
4
|
+
from mrok.controller.auth.base import UNAUTHORIZED_EXCEPTION, AuthIdentity, BaseHTTPAuthBackend
|
|
5
|
+
from mrok.controller.auth.registry import get_authentication_backend
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class HTTPAuthManager:
|
|
9
|
+
def __init__(self, auth_settings: DynaBox):
|
|
10
|
+
self.auth_settings = auth_settings
|
|
11
|
+
self.active_backends: list[BaseHTTPAuthBackend] = []
|
|
12
|
+
self._setup_backends()
|
|
13
|
+
|
|
14
|
+
def _setup_backends(self):
|
|
15
|
+
enabled_keys = self.auth_settings.get("backends", [])
|
|
16
|
+
|
|
17
|
+
for key in enabled_keys:
|
|
18
|
+
backend_cls = get_authentication_backend(key)
|
|
19
|
+
if not backend_cls:
|
|
20
|
+
raise ValueError(f"Backend '{key}' is not registered.")
|
|
21
|
+
|
|
22
|
+
specific_config = self.auth_settings.get(key, {})
|
|
23
|
+
self.active_backends.append(backend_cls(specific_config))
|
|
24
|
+
|
|
25
|
+
async def __call__(self, request: Request) -> AuthIdentity:
|
|
26
|
+
for backend in self.active_backends:
|
|
27
|
+
identity = await backend(request)
|
|
28
|
+
if identity:
|
|
29
|
+
return identity
|
|
30
|
+
|
|
31
|
+
raise UNAUTHORIZED_EXCEPTION
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from mrok.controller.auth.base import BaseHTTPAuthBackend
|
|
2
|
+
|
|
3
|
+
BACKEND_REGISTRY: dict[str, type[BaseHTTPAuthBackend]] = {}
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def register_authentication_backend(name: str):
|
|
7
|
+
"""Decorator to register a backend class with a unique key."""
|
|
8
|
+
|
|
9
|
+
def decorator(cls: type[BaseHTTPAuthBackend]):
|
|
10
|
+
BACKEND_REGISTRY[name] = cls
|
|
11
|
+
return cls
|
|
12
|
+
|
|
13
|
+
return decorator
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_authentication_backend(name: str) -> type[BaseHTTPAuthBackend] | None:
|
|
17
|
+
return BACKEND_REGISTRY.get(name)
|
mrok/frontend/app.py
CHANGED
|
@@ -1,14 +1,21 @@
|
|
|
1
|
-
import
|
|
1
|
+
from http import HTTPStatus
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Any
|
|
2
4
|
|
|
3
5
|
from httpcore import AsyncConnectionPool
|
|
6
|
+
from jinja2 import Environment, FileSystemLoader, select_autoescape
|
|
4
7
|
|
|
5
8
|
from mrok.conf import get_settings
|
|
9
|
+
from mrok.frontend.utils import get_target_name, parse_accept_header
|
|
6
10
|
from mrok.proxy.app import ProxyAppBase
|
|
7
11
|
from mrok.proxy.backend import AIOZitiNetworkBackend
|
|
8
12
|
from mrok.proxy.exceptions import InvalidTargetError
|
|
9
|
-
from mrok.types.proxy import Scope
|
|
13
|
+
from mrok.types.proxy import ASGISend, Scope
|
|
10
14
|
|
|
11
|
-
|
|
15
|
+
ERROR_TEMPLATE_FORMATS = {
|
|
16
|
+
"application/json": "json",
|
|
17
|
+
"text/html": "html",
|
|
18
|
+
}
|
|
12
19
|
|
|
13
20
|
|
|
14
21
|
class FrontendProxyApp(ProxyAppBase):
|
|
@@ -19,10 +26,11 @@ class FrontendProxyApp(ProxyAppBase):
|
|
|
19
26
|
max_connections: int | None = 10,
|
|
20
27
|
max_keepalive_connections: int | None = None,
|
|
21
28
|
keepalive_expiry: float | None = None,
|
|
22
|
-
retries=0,
|
|
29
|
+
retries: int = 0,
|
|
23
30
|
):
|
|
24
31
|
self._identity_file = identity_file
|
|
25
|
-
self.
|
|
32
|
+
self._jinja_env_cache: dict[Path, Environment] = {}
|
|
33
|
+
self._templates_by_error = get_settings().frontend.get("errors", {})
|
|
26
34
|
super().__init__(
|
|
27
35
|
max_connections=max_connections,
|
|
28
36
|
max_keepalive_connections=max_keepalive_connections,
|
|
@@ -46,30 +54,90 @@ class FrontendProxyApp(ProxyAppBase):
|
|
|
46
54
|
)
|
|
47
55
|
|
|
48
56
|
def get_upstream_base_url(self, scope: Scope) -> str:
|
|
49
|
-
target =
|
|
57
|
+
target = get_target_name(
|
|
50
58
|
{k.decode("latin1"): v.decode("latin1") for k, v in scope.get("headers", {})}
|
|
51
59
|
)
|
|
60
|
+
if not target:
|
|
61
|
+
raise InvalidTargetError()
|
|
62
|
+
|
|
52
63
|
return f"http://{target.lower()}"
|
|
53
64
|
|
|
54
|
-
def
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
65
|
+
async def send_error_response(
|
|
66
|
+
self,
|
|
67
|
+
scope: Scope,
|
|
68
|
+
send: ASGISend,
|
|
69
|
+
http_status: int,
|
|
70
|
+
body: str,
|
|
71
|
+
headers: list[tuple[bytes, bytes]] | None = None,
|
|
72
|
+
):
|
|
73
|
+
request_headers = {
|
|
74
|
+
k.decode("latin1"): v.decode("latin1") for k, v in scope.get("headers", {})
|
|
75
|
+
}
|
|
76
|
+
accept_header = request_headers.get("accept")
|
|
77
|
+
if not (accept_header and str(http_status) in self._templates_by_error):
|
|
78
|
+
return await super().send_error_response(scope, send, http_status, body)
|
|
61
79
|
|
|
62
|
-
|
|
63
|
-
header_value = headers.get(name, "")
|
|
64
|
-
if self._proxy_domain in header_value:
|
|
65
|
-
if ":" in header_value:
|
|
66
|
-
header_value, _ = header_value.split(":", 1)
|
|
67
|
-
return header_value[: -len(self._proxy_domain)]
|
|
80
|
+
available_templates = self._templates_by_error[str(http_status)]
|
|
68
81
|
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
82
|
+
media_types = parse_accept_header(accept_header)
|
|
83
|
+
for media_type in media_types:
|
|
84
|
+
template_format = ERROR_TEMPLATE_FORMATS.get(media_type)
|
|
85
|
+
if template_format and template_format in available_templates:
|
|
86
|
+
template_path = available_templates[template_format]
|
|
87
|
+
rendered = await self._render_error_template(
|
|
88
|
+
Path(template_path), scope, http_status, body
|
|
89
|
+
)
|
|
90
|
+
return await super().send_error_response(
|
|
91
|
+
scope,
|
|
92
|
+
send,
|
|
93
|
+
http_status,
|
|
94
|
+
rendered,
|
|
95
|
+
headers=[(b"content-type", media_type.encode("latin-1"))],
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
return await super().send_error_response(scope, send, http_status, body)
|
|
99
|
+
|
|
100
|
+
async def _render_error_template(
|
|
101
|
+
self, template_path: Path, scope: Scope, http_status: int, body: str
|
|
102
|
+
) -> str:
|
|
103
|
+
env = self._get_jinja_env(template_path)
|
|
104
|
+
template = env.get_template(template_path.name)
|
|
105
|
+
status_title = HTTPStatus(http_status).name.replace("_", " ").title()
|
|
106
|
+
context = {
|
|
107
|
+
"status": http_status,
|
|
108
|
+
"status_title": status_title,
|
|
109
|
+
"body": body,
|
|
110
|
+
"request": self._extract_request_context(scope),
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
return await template.render_async(context)
|
|
114
|
+
|
|
115
|
+
def _extract_request_context(self, scope: Scope) -> dict[str, Any]:
|
|
116
|
+
headers = {k.decode("latin-1"): v.decode("latin-1") for k, v in scope.get("headers", [])}
|
|
117
|
+
|
|
118
|
+
return {
|
|
119
|
+
"method": scope.get("method"),
|
|
120
|
+
"path": scope.get("path"),
|
|
121
|
+
"raw_path": scope.get("raw_path", b"").decode("latin-1"),
|
|
122
|
+
"query_string": scope.get("query_string", b"").decode("latin-1"),
|
|
123
|
+
"scheme": scope.get("scheme"),
|
|
124
|
+
"headers": headers,
|
|
125
|
+
"client": scope.get("client"),
|
|
126
|
+
"server": scope.get("server"),
|
|
127
|
+
"http_version": scope.get("http_version"),
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
def _get_jinja_env(self, template_path: Path) -> Environment:
|
|
131
|
+
template_dir = template_path.parent
|
|
132
|
+
|
|
133
|
+
if template_dir not in self._jinja_env_cache:
|
|
134
|
+
self._jinja_env_cache[template_dir] = Environment(
|
|
135
|
+
loader=FileSystemLoader(str(template_dir)),
|
|
136
|
+
autoescape=select_autoescape(
|
|
137
|
+
enabled_extensions=("html", "xml"),
|
|
138
|
+
default_for_string=False,
|
|
139
|
+
),
|
|
140
|
+
enable_async=True,
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
return self._jinja_env_cache[template_dir]
|
mrok/frontend/main.py
CHANGED
|
@@ -7,6 +7,7 @@ from uvicorn_worker import UvicornWorker
|
|
|
7
7
|
|
|
8
8
|
from mrok.conf import get_settings
|
|
9
9
|
from mrok.frontend.app import FrontendProxyApp
|
|
10
|
+
from mrok.frontend.middleware import HealthCheckMiddleware
|
|
10
11
|
from mrok.logging import get_logging_config
|
|
11
12
|
|
|
12
13
|
|
|
@@ -42,11 +43,13 @@ def run(
|
|
|
42
43
|
max_keepalive_connections: int | None,
|
|
43
44
|
keepalive_expiry: float | None,
|
|
44
45
|
):
|
|
45
|
-
app =
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
46
|
+
app = HealthCheckMiddleware(
|
|
47
|
+
FrontendProxyApp(
|
|
48
|
+
str(identity_file),
|
|
49
|
+
max_connections=max_connections,
|
|
50
|
+
max_keepalive_connections=max_keepalive_connections,
|
|
51
|
+
keepalive_expiry=keepalive_expiry,
|
|
52
|
+
)
|
|
50
53
|
)
|
|
51
54
|
|
|
52
55
|
options = {
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
import json
|
|
2
|
+
|
|
3
|
+
from mrok.frontend.utils import get_target_name
|
|
4
|
+
from mrok.types.proxy import ASGIApp, ASGIReceive, ASGISend, Scope
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class HealthCheckMiddleware:
|
|
8
|
+
def __init__(self, app: ASGIApp):
|
|
9
|
+
self.app = app
|
|
10
|
+
|
|
11
|
+
async def __call__(self, scope: Scope, receive: ASGIReceive, send: ASGISend):
|
|
12
|
+
if scope["type"] == "http" and scope["path"] == "/healthcheck":
|
|
13
|
+
target = get_target_name(
|
|
14
|
+
{k.decode("latin1"): v.decode("latin1") for k, v in scope.get("headers", {})}
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
if not target:
|
|
18
|
+
await send(
|
|
19
|
+
{
|
|
20
|
+
"type": "http.response.start",
|
|
21
|
+
"status": 200,
|
|
22
|
+
"headers": [
|
|
23
|
+
[b"content-type", b"application/json"],
|
|
24
|
+
],
|
|
25
|
+
}
|
|
26
|
+
)
|
|
27
|
+
await send(
|
|
28
|
+
{
|
|
29
|
+
"type": "http.response.body",
|
|
30
|
+
"body": json.dumps({"status": "healthy"}).encode("utf-8"),
|
|
31
|
+
}
|
|
32
|
+
)
|
|
33
|
+
return
|
|
34
|
+
|
|
35
|
+
await self.app(scope, receive, send)
|
mrok/frontend/utils.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
import re
|
|
2
|
+
|
|
3
|
+
from mrok.conf import get_settings
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def parse_accept_header(accept: str | None) -> list[str]:
|
|
7
|
+
if not accept:
|
|
8
|
+
return ["*/*"]
|
|
9
|
+
|
|
10
|
+
result: list[tuple[str, float, int]] = []
|
|
11
|
+
|
|
12
|
+
for index, item in enumerate(accept.split(",")):
|
|
13
|
+
item = item.strip()
|
|
14
|
+
if not item:
|
|
15
|
+
continue
|
|
16
|
+
|
|
17
|
+
parts = [p.strip() for p in item.split(";")]
|
|
18
|
+
media_type = parts[0].lower()
|
|
19
|
+
|
|
20
|
+
q = 1.0
|
|
21
|
+
for param in parts[1:]:
|
|
22
|
+
if param.startswith("q="):
|
|
23
|
+
try:
|
|
24
|
+
q = float(param[2:])
|
|
25
|
+
except ValueError:
|
|
26
|
+
q = 0.0
|
|
27
|
+
|
|
28
|
+
result.append((media_type, q, index))
|
|
29
|
+
|
|
30
|
+
# Sort by:
|
|
31
|
+
# 1) q value (desc)
|
|
32
|
+
# 2) specificity (more specific first)
|
|
33
|
+
# 3) original order (stable)
|
|
34
|
+
result.sort(
|
|
35
|
+
key=lambda x: (
|
|
36
|
+
-x[1],
|
|
37
|
+
-_media_type_specificity(x[0]),
|
|
38
|
+
x[2],
|
|
39
|
+
)
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
return [media_type for media_type, _, _ in result]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _media_type_specificity(media_type: str) -> int:
|
|
46
|
+
if media_type == "*/*":
|
|
47
|
+
return 0
|
|
48
|
+
if media_type.endswith("/*"):
|
|
49
|
+
return 1
|
|
50
|
+
return 2
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_frontend_domain():
|
|
54
|
+
settings = get_settings()
|
|
55
|
+
return (
|
|
56
|
+
settings.frontend.domain
|
|
57
|
+
if settings.frontend.domain[0] == "."
|
|
58
|
+
else f".{settings.frontend.domain}"
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _get_target_from_header(headers: dict[str, str], name: str) -> str | None:
|
|
63
|
+
domain_name = get_frontend_domain()
|
|
64
|
+
header_value = headers.get(name, "")
|
|
65
|
+
if domain_name in header_value:
|
|
66
|
+
if ":" in header_value:
|
|
67
|
+
header_value, _ = header_value.split(":", 1)
|
|
68
|
+
return header_value[: -len(domain_name)]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def get_target_name(headers: dict[str, str]) -> str | None:
|
|
72
|
+
settings = get_settings()
|
|
73
|
+
|
|
74
|
+
target = _get_target_from_header(headers, "x-forwarded-host")
|
|
75
|
+
if not target:
|
|
76
|
+
target = _get_target_from_header(headers, "host")
|
|
77
|
+
|
|
78
|
+
if target and (
|
|
79
|
+
re.fullmatch(settings.identifiers.extension.regex, target)
|
|
80
|
+
or re.fullmatch(settings.identifiers.instance.regex, target)
|
|
81
|
+
):
|
|
82
|
+
return target
|
|
83
|
+
return None
|
mrok/logging.py
CHANGED
|
@@ -1,12 +1,14 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
import logging.config
|
|
2
3
|
|
|
3
|
-
from rich.console import Console
|
|
4
|
-
from rich.logging import RichHandler
|
|
5
|
-
from textual_serve.server import LogHighlighter
|
|
6
|
-
|
|
7
4
|
from mrok.conf import Settings
|
|
8
5
|
|
|
9
6
|
|
|
7
|
+
class HealthCheckFilter(logging.Filter):
|
|
8
|
+
def filter(self, record):
|
|
9
|
+
return "/healthcheck" not in record.getMessage()
|
|
10
|
+
|
|
11
|
+
|
|
10
12
|
def get_logging_config(settings: Settings, cli_mode: bool = False) -> dict:
|
|
11
13
|
log_level = "DEBUG" if settings.logging.debug else "INFO"
|
|
12
14
|
handler = "rich" if settings.logging.rich else "console"
|
|
@@ -30,6 +32,11 @@ def get_logging_config(settings: Settings, cli_mode: bool = False) -> dict:
|
|
|
30
32
|
},
|
|
31
33
|
"plain": {"format": "%(message)s"},
|
|
32
34
|
},
|
|
35
|
+
"filters": {
|
|
36
|
+
"healthcheck_filter": {
|
|
37
|
+
"()": HealthCheckFilter,
|
|
38
|
+
}
|
|
39
|
+
},
|
|
33
40
|
"handlers": {
|
|
34
41
|
"console": {
|
|
35
42
|
"class": "logging.StreamHandler",
|
|
@@ -58,17 +65,30 @@ def get_logging_config(settings: Settings, cli_mode: bool = False) -> dict:
|
|
|
58
65
|
"handlers": [handler],
|
|
59
66
|
"level": log_level,
|
|
60
67
|
"propagate": False,
|
|
68
|
+
"filters": ["healthcheck_filter"],
|
|
61
69
|
},
|
|
62
70
|
"gunicorn.error": {
|
|
63
71
|
"handlers": [handler],
|
|
64
72
|
"level": log_level,
|
|
65
73
|
"propagate": False,
|
|
66
74
|
},
|
|
75
|
+
"uvicorn.access": {
|
|
76
|
+
"handlers": [handler],
|
|
77
|
+
"level": log_level,
|
|
78
|
+
"propagate": False,
|
|
79
|
+
"filters": ["healthcheck_filter"],
|
|
80
|
+
},
|
|
67
81
|
"mrok": {
|
|
68
82
|
"handlers": [mrok_handler],
|
|
69
83
|
"level": log_level,
|
|
70
84
|
"propagate": False,
|
|
71
85
|
},
|
|
86
|
+
"mrok.access": {
|
|
87
|
+
"handlers": [mrok_handler],
|
|
88
|
+
"level": log_level,
|
|
89
|
+
"propagate": False,
|
|
90
|
+
"filters": ["healthcheck_filter"],
|
|
91
|
+
},
|
|
72
92
|
},
|
|
73
93
|
}
|
|
74
94
|
|
|
@@ -78,21 +98,3 @@ def get_logging_config(settings: Settings, cli_mode: bool = False) -> dict:
|
|
|
78
98
|
def setup_logging(settings: Settings, cli_mode: bool = False) -> None:
|
|
79
99
|
logging_config = get_logging_config(settings, cli_mode)
|
|
80
100
|
logging.config.dictConfig(logging_config)
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
def setup_inspector_logging(console: Console) -> None:
|
|
84
|
-
logging.basicConfig(
|
|
85
|
-
level="WARNING",
|
|
86
|
-
format="%(message)s",
|
|
87
|
-
datefmt="%Y-%m-%d %H:%M:%S",
|
|
88
|
-
handlers=[
|
|
89
|
-
RichHandler(
|
|
90
|
-
show_path=False,
|
|
91
|
-
show_time=True,
|
|
92
|
-
rich_tracebacks=True,
|
|
93
|
-
tracebacks_show_locals=True,
|
|
94
|
-
highlighter=LogHighlighter(),
|
|
95
|
-
console=console,
|
|
96
|
-
)
|
|
97
|
-
],
|
|
98
|
-
)
|
mrok/proxy/app.py
CHANGED
|
@@ -57,7 +57,7 @@ class ProxyAppBase(abc.ABC):
|
|
|
57
57
|
return
|
|
58
58
|
|
|
59
59
|
if scope.get("type") != "http":
|
|
60
|
-
await self.
|
|
60
|
+
await self.send_error_response(scope, send, 500, "Unsupported")
|
|
61
61
|
return
|
|
62
62
|
|
|
63
63
|
try:
|
|
@@ -105,15 +105,23 @@ class ProxyAppBase(abc.ABC):
|
|
|
105
105
|
await response.aclose()
|
|
106
106
|
|
|
107
107
|
except ProxyError as pe:
|
|
108
|
-
await self.
|
|
108
|
+
await self.send_error_response(scope, send, pe.http_status, pe.message)
|
|
109
109
|
|
|
110
110
|
except Exception:
|
|
111
111
|
logger.exception("Unexpected error in forwarder")
|
|
112
|
-
await self.
|
|
112
|
+
await self.send_error_response(scope, send, 502, "Bad Gateway")
|
|
113
113
|
|
|
114
|
-
async def
|
|
114
|
+
async def send_error_response(
|
|
115
|
+
self,
|
|
116
|
+
scope: Scope,
|
|
117
|
+
send: ASGISend,
|
|
118
|
+
http_status: int,
|
|
119
|
+
body: str,
|
|
120
|
+
headers: list[tuple[bytes, bytes]] | None = None,
|
|
121
|
+
):
|
|
122
|
+
headers = headers or [(b"content-type", b"text/plain")]
|
|
115
123
|
try:
|
|
116
|
-
await send({"type": "http.response.start", "status": http_status, "headers":
|
|
124
|
+
await send({"type": "http.response.start", "status": http_status, "headers": headers})
|
|
117
125
|
await send({"type": "http.response.body", "body": body.encode()})
|
|
118
126
|
except Exception as e: # pragma: no cover
|
|
119
127
|
logger.error(f"Cannot send error response: {e}")
|
mrok/proxy/middleware.py
CHANGED
|
@@ -2,10 +2,8 @@ import asyncio
|
|
|
2
2
|
import logging
|
|
3
3
|
import time
|
|
4
4
|
|
|
5
|
-
from mrok.proxy.constants import MAX_REQUEST_BODY_BYTES, MAX_RESPONSE_BODY_BYTES
|
|
6
5
|
from mrok.proxy.metrics import MetricsCollector
|
|
7
6
|
from mrok.proxy.models import FixedSizeByteBuffer, HTTPHeaders, HTTPRequest, HTTPResponse
|
|
8
|
-
from mrok.proxy.utils import must_capture_request, must_capture_response
|
|
9
7
|
from mrok.types.proxy import (
|
|
10
8
|
ASGIApp,
|
|
11
9
|
ASGIReceive,
|
|
@@ -15,6 +13,9 @@ from mrok.types.proxy import (
|
|
|
15
13
|
Scope,
|
|
16
14
|
)
|
|
17
15
|
|
|
16
|
+
MAX_REQUEST_BODY_BYTES = 2 * 1024 * 1024
|
|
17
|
+
MAX_RESPONSE_BODY_BYTES = 5 * 1024 * 1024
|
|
18
|
+
|
|
18
19
|
logger = logging.getLogger("mrok.proxy")
|
|
19
20
|
|
|
20
21
|
|
|
@@ -42,7 +43,7 @@ class CaptureMiddleware:
|
|
|
42
43
|
state = {}
|
|
43
44
|
|
|
44
45
|
req_buf = FixedSizeByteBuffer(MAX_REQUEST_BODY_BYTES)
|
|
45
|
-
capture_req_body =
|
|
46
|
+
capture_req_body = method.upper() not in ("GET", "HEAD", "OPTIONS", "TRACE")
|
|
46
47
|
|
|
47
48
|
request = HTTPRequest(
|
|
48
49
|
method=method,
|
|
@@ -68,9 +69,7 @@ class CaptureMiddleware:
|
|
|
68
69
|
resp_headers = HTTPHeaders.from_asgi(msg.get("headers", []))
|
|
69
70
|
state["resp_headers_raw"] = resp_headers
|
|
70
71
|
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
if state["capture_resp_body"] and msg["type"] == "http.response.body":
|
|
72
|
+
if msg["type"] == "http.response.body":
|
|
74
73
|
body = msg.get("body", b"")
|
|
75
74
|
resp_buf.write(body)
|
|
76
75
|
|
|
@@ -91,8 +90,8 @@ class CaptureMiddleware:
|
|
|
91
90
|
status=state["status"] or 0,
|
|
92
91
|
headers=state["resp_headers_raw"],
|
|
93
92
|
duration=duration,
|
|
94
|
-
body=resp_buf.getvalue()
|
|
95
|
-
body_truncated=resp_buf.overflow
|
|
93
|
+
body=resp_buf.getvalue(),
|
|
94
|
+
body_truncated=resp_buf.overflow,
|
|
96
95
|
)
|
|
97
96
|
asyncio.create_task(self._on_response_complete(response))
|
|
98
97
|
|
mrok/proxy/models.py
CHANGED
|
@@ -1,13 +1,40 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import base64
|
|
4
|
+
import binascii
|
|
3
5
|
import json
|
|
6
|
+
from collections.abc import Mapping
|
|
4
7
|
from pathlib import Path
|
|
5
|
-
from typing import Literal
|
|
8
|
+
from typing import Annotated, Any, Literal
|
|
6
9
|
|
|
7
10
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
11
|
+
from pydantic.functional_serializers import PlainSerializer
|
|
12
|
+
from pydantic.functional_validators import PlainValidator
|
|
8
13
|
from pydantic_core import core_schema
|
|
9
14
|
|
|
10
15
|
|
|
16
|
+
def serialize_b64(v: bytes) -> str:
|
|
17
|
+
return base64.b64encode(v).decode("ascii")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def deserialize_b64(v):
|
|
21
|
+
if isinstance(v, bytes):
|
|
22
|
+
return v
|
|
23
|
+
if isinstance(v, str):
|
|
24
|
+
try:
|
|
25
|
+
return base64.b64decode(v, validate=True)
|
|
26
|
+
except binascii.Error as e: # pragma: no branch
|
|
27
|
+
raise ValueError("Invalid base64 data") from e
|
|
28
|
+
raise TypeError("Expected bytes or base64 string") # pragma: no cover
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
Base64Bytes = Annotated[
|
|
32
|
+
bytes,
|
|
33
|
+
PlainValidator(deserialize_b64),
|
|
34
|
+
PlainSerializer(serialize_b64, return_type=str),
|
|
35
|
+
]
|
|
36
|
+
|
|
37
|
+
|
|
11
38
|
class X509Credentials(BaseModel):
|
|
12
39
|
key: str
|
|
13
40
|
cert: str
|
|
@@ -16,14 +43,13 @@ class X509Credentials(BaseModel):
|
|
|
16
43
|
@field_validator("key", "cert", "ca", mode="before")
|
|
17
44
|
@classmethod
|
|
18
45
|
def strip_pem_prefix(cls, value: str) -> str:
|
|
19
|
-
if isinstance(value, str) and value.startswith("pem:"):
|
|
46
|
+
if isinstance(value, str) and value.startswith("pem:"): # pragma: no branch
|
|
20
47
|
return value[4:]
|
|
21
|
-
return value
|
|
48
|
+
return value # pragma: no cover
|
|
22
49
|
|
|
23
50
|
|
|
24
51
|
class ServiceMetadata(BaseModel):
|
|
25
52
|
model_config = ConfigDict(extra="ignore")
|
|
26
|
-
identity: str
|
|
27
53
|
extension: str
|
|
28
54
|
instance: str
|
|
29
55
|
domain: str | None = None
|
|
@@ -34,10 +60,10 @@ class Identity(BaseModel):
|
|
|
34
60
|
model_config = ConfigDict(extra="ignore")
|
|
35
61
|
zt_api: str = Field(validation_alias="ztAPI")
|
|
36
62
|
id: X509Credentials
|
|
63
|
+
mrok: ServiceMetadata
|
|
64
|
+
enable_ha: bool = Field(default=False, validation_alias="enableHa")
|
|
37
65
|
zt_apis: str | None = Field(default=None, validation_alias="ztAPIs")
|
|
38
66
|
config_types: str | None = Field(default=None, validation_alias="configTypes")
|
|
39
|
-
enable_ha: bool = Field(default=False, validation_alias="enableHa")
|
|
40
|
-
mrok: ServiceMetadata | None = None
|
|
41
67
|
|
|
42
68
|
@staticmethod
|
|
43
69
|
def load_from_file(path: str | Path) -> Identity:
|
|
@@ -77,7 +103,7 @@ class FixedSizeByteBuffer:
|
|
|
77
103
|
|
|
78
104
|
|
|
79
105
|
class HTTPHeaders(dict):
|
|
80
|
-
def __init__(self, initial=None):
|
|
106
|
+
def __init__(self, initial: Mapping[Any, Any] | None = None):
|
|
81
107
|
super().__init__()
|
|
82
108
|
if initial:
|
|
83
109
|
for k, v in initial.items():
|
|
@@ -131,9 +157,9 @@ class HTTPRequest(BaseModel):
|
|
|
131
157
|
method: str
|
|
132
158
|
url: str
|
|
133
159
|
headers: HTTPHeaders
|
|
134
|
-
query_string:
|
|
160
|
+
query_string: Base64Bytes | None = None
|
|
135
161
|
start_time: float
|
|
136
|
-
body:
|
|
162
|
+
body: Base64Bytes | None = None
|
|
137
163
|
body_truncated: bool | None = None
|
|
138
164
|
|
|
139
165
|
|
|
@@ -143,7 +169,7 @@ class HTTPResponse(BaseModel):
|
|
|
143
169
|
status: int
|
|
144
170
|
headers: HTTPHeaders
|
|
145
171
|
duration: float
|
|
146
|
-
body:
|
|
172
|
+
body: Base64Bytes | None = None
|
|
147
173
|
body_truncated: bool | None = None
|
|
148
174
|
|
|
149
175
|
|
mrok/proxy/ziticorn.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import json
|
|
2
1
|
import logging
|
|
3
2
|
import socket
|
|
4
3
|
from collections.abc import Callable
|
|
@@ -10,6 +9,7 @@ from uvicorn import config, server
|
|
|
10
9
|
from uvicorn.lifespan.on import LifespanOn
|
|
11
10
|
from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol as UvHttpToolsProtocol
|
|
12
11
|
|
|
12
|
+
from mrok.proxy.models import Identity
|
|
13
13
|
from mrok.types.proxy import ASGIApp
|
|
14
14
|
|
|
15
15
|
logger = logging.getLogger("mrok.proxy")
|
|
@@ -48,10 +48,8 @@ class BackendConfig(config.Config):
|
|
|
48
48
|
backlog: int = 2048,
|
|
49
49
|
):
|
|
50
50
|
self.identity_file = identity_file
|
|
51
|
+
self.identity = Identity.load_from_file(self.identity_file)
|
|
51
52
|
self.ziti_load_timeout_ms = ziti_load_timeout_ms
|
|
52
|
-
self.service_name, self.identity_name, self.instance_id = self.get_identity_info(
|
|
53
|
-
identity_file
|
|
54
|
-
)
|
|
55
53
|
super().__init__(
|
|
56
54
|
app,
|
|
57
55
|
loop="asyncio",
|
|
@@ -59,26 +57,19 @@ class BackendConfig(config.Config):
|
|
|
59
57
|
backlog=backlog,
|
|
60
58
|
)
|
|
61
59
|
|
|
62
|
-
def get_identity_info(self, identity_file: str | Path):
|
|
63
|
-
with open(identity_file) as f:
|
|
64
|
-
identity_data = json.load(f)
|
|
65
|
-
try:
|
|
66
|
-
identity_name = identity_data["mrok"]["identity"]
|
|
67
|
-
instance_id, service_name = identity_name.split(".", 1)
|
|
68
|
-
return service_name, identity_name, instance_id
|
|
69
|
-
except KeyError:
|
|
70
|
-
raise ValueError("Invalid identity file: identity file is not mrok compatible.")
|
|
71
|
-
|
|
72
60
|
def bind_socket(self) -> socket.socket:
|
|
73
|
-
logger.info(
|
|
61
|
+
logger.info(
|
|
62
|
+
"Connect to Ziti service "
|
|
63
|
+
f"'{self.identity.mrok.extension} ({self.identity.mrok.instance})'"
|
|
64
|
+
)
|
|
74
65
|
|
|
75
66
|
ctx, err = openziti.load(str(self.identity_file), timeout=self.ziti_load_timeout_ms)
|
|
76
67
|
if err != 0:
|
|
77
68
|
raise RuntimeError(f"Failed to load Ziti identity from {self.identity_file}: {err}")
|
|
78
69
|
|
|
79
|
-
sock = ctx.bind(self.
|
|
70
|
+
sock = ctx.bind(self.identity.mrok.extension)
|
|
80
71
|
sock.listen(self.backlog)
|
|
81
|
-
logger.info(f"listening on ziti service {self.
|
|
72
|
+
logger.info(f"listening on ziti service {self.identity.mrok.extension} for connections")
|
|
82
73
|
return sock
|
|
83
74
|
|
|
84
75
|
def configure_logging(self) -> None:
|