mrok 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43):
  1. mrok/agent/devtools/inspector/__main__.py +3 -23
  2. mrok/agent/devtools/inspector/app.py +407 -112
  3. mrok/agent/devtools/inspector/utils.py +149 -0
  4. mrok/cli/commands/admin/bootstrap.py +2 -2
  5. mrok/cli/commands/admin/register/extensions.py +7 -9
  6. mrok/cli/commands/admin/register/instances.py +13 -16
  7. mrok/cli/commands/admin/unregister/extensions.py +7 -11
  8. mrok/cli/commands/admin/unregister/instances.py +12 -12
  9. mrok/cli/commands/agent/run/asgi.py +1 -1
  10. mrok/cli/commands/frontend/run.py +1 -1
  11. mrok/cli/main.py +17 -1
  12. mrok/cli/utils.py +26 -0
  13. mrok/conf.py +15 -7
  14. mrok/constants.py +21 -0
  15. mrok/controller/app.py +12 -10
  16. mrok/controller/auth/__init__.py +11 -0
  17. mrok/controller/auth/backends.py +60 -0
  18. mrok/controller/auth/base.py +38 -0
  19. mrok/controller/auth/manager.py +31 -0
  20. mrok/controller/auth/registry.py +17 -0
  21. mrok/frontend/app.py +94 -26
  22. mrok/frontend/main.py +8 -5
  23. mrok/frontend/middleware.py +35 -0
  24. mrok/frontend/utils.py +83 -0
  25. mrok/logging.py +24 -22
  26. mrok/proxy/app.py +13 -5
  27. mrok/proxy/middleware.py +7 -8
  28. mrok/proxy/models.py +36 -10
  29. mrok/proxy/ziticorn.py +8 -17
  30. mrok/ziti/api.py +4 -4
  31. mrok/ziti/bootstrap.py +0 -5
  32. mrok/ziti/identities.py +11 -10
  33. mrok/ziti/services.py +6 -6
  34. {mrok-0.6.0.dist-info → mrok-0.8.0.dist-info}/METADATA +9 -3
  35. {mrok-0.6.0.dist-info → mrok-0.8.0.dist-info}/RECORD +38 -35
  36. mrok/agent/devtools/__main__.py +0 -34
  37. mrok/cli/commands/agent/utils.py +0 -5
  38. mrok/controller/auth.py +0 -87
  39. mrok/proxy/constants.py +0 -22
  40. mrok/proxy/utils.py +0 -90
  41. {mrok-0.6.0.dist-info → mrok-0.8.0.dist-info}/WHEEL +0 -0
  42. {mrok-0.6.0.dist-info → mrok-0.8.0.dist-info}/entry_points.txt +0 -0
  43. {mrok-0.6.0.dist-info → mrok-0.8.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,31 @@
1
+ from dynaconf.utils.boxing import DynaBox
2
+ from fastapi import Request
3
+
4
+ from mrok.controller.auth.base import UNAUTHORIZED_EXCEPTION, AuthIdentity, BaseHTTPAuthBackend
5
+ from mrok.controller.auth.registry import get_authentication_backend
6
+
7
+
8
class HTTPAuthManager:
    """Authenticate requests by trying each configured backend in turn.

    ``auth_settings["backends"]`` lists registered backend names; an absent
    key means no backends. Each backend class is constructed with its own
    settings section, ``auth_settings[name]``, or ``{}`` when that section is
    missing.
    """

    def __init__(self, auth_settings: DynaBox):
        self.auth_settings = auth_settings
        self.active_backends: list[BaseHTTPAuthBackend] = []
        self._setup_backends()

    def _setup_backends(self):
        """Instantiate the enabled backends in configuration order.

        Raises:
            ValueError: if a listed name has no registered backend class.
        """
        for name in self.auth_settings.get("backends", []):
            backend_cls = get_authentication_backend(name)
            if not backend_cls:
                raise ValueError(f"Backend '{name}' is not registered.")
            backend = backend_cls(self.auth_settings.get(name, {}))
            self.active_backends.append(backend)

    async def __call__(self, request: Request) -> AuthIdentity:
        """Return the identity produced by the first backend that accepts ``request``.

        Backends are awaited in order. The first truthy result wins. If none
        yields an identity, the shared ``UNAUTHORIZED_EXCEPTION`` is raised.
        """
        for backend in self.active_backends:
            if identity := await backend(request):
                return identity
        raise UNAUTHORIZED_EXCEPTION
@@ -0,0 +1,17 @@
1
+ from mrok.controller.auth.base import BaseHTTPAuthBackend
2
+
3
# Maps a backend's configuration key to the class implementing it.
BACKEND_REGISTRY: dict[str, type[BaseHTTPAuthBackend]] = {}


def register_authentication_backend(name: str):
    """Build a class decorator that records the decorated class under ``name``.

    Registering the same ``name`` twice keeps the most recent class. The
    decorated class is returned unchanged.
    """

    def _register(backend_cls: type[BaseHTTPAuthBackend]):
        BACKEND_REGISTRY[name] = backend_cls
        return backend_cls

    return _register


def get_authentication_backend(name: str) -> type[BaseHTTPAuthBackend] | None:
    """Return the backend class registered under ``name``, or ``None`` if unknown."""
    try:
        return BACKEND_REGISTRY[name]
    except KeyError:
        return None
mrok/frontend/app.py CHANGED
@@ -1,14 +1,21 @@
1
- import re
1
+ from http import HTTPStatus
2
+ from pathlib import Path
3
+ from typing import Any
2
4
 
3
5
  from httpcore import AsyncConnectionPool
6
+ from jinja2 import Environment, FileSystemLoader, select_autoescape
4
7
 
5
8
  from mrok.conf import get_settings
9
+ from mrok.frontend.utils import get_target_name, parse_accept_header
6
10
  from mrok.proxy.app import ProxyAppBase
7
11
  from mrok.proxy.backend import AIOZitiNetworkBackend
8
12
  from mrok.proxy.exceptions import InvalidTargetError
9
- from mrok.types.proxy import Scope
13
+ from mrok.types.proxy import ASGISend, Scope
10
14
 
11
- RE_SUBDOMAIN = re.compile(r"(?i)^(?:EXT-\d{4}-\d{4}|INS-\d{4}-\d{4}-\d{4})$")
15
+ ERROR_TEMPLATE_FORMATS = {
16
+ "application/json": "json",
17
+ "text/html": "html",
18
+ }
12
19
 
13
20
 
14
21
  class FrontendProxyApp(ProxyAppBase):
@@ -19,10 +26,11 @@ class FrontendProxyApp(ProxyAppBase):
19
26
  max_connections: int | None = 10,
20
27
  max_keepalive_connections: int | None = None,
21
28
  keepalive_expiry: float | None = None,
22
- retries=0,
29
+ retries: int = 0,
23
30
  ):
24
31
  self._identity_file = identity_file
25
- self._proxy_domain = self._get_proxy_domain()
32
+ self._jinja_env_cache: dict[Path, Environment] = {}
33
+ self._templates_by_error = get_settings().frontend.get("errors", {})
26
34
  super().__init__(
27
35
  max_connections=max_connections,
28
36
  max_keepalive_connections=max_keepalive_connections,
@@ -46,30 +54,90 @@ class FrontendProxyApp(ProxyAppBase):
46
54
  )
47
55
 
48
56
  def get_upstream_base_url(self, scope: Scope) -> str:
49
- target = self._get_target_name(
57
+ target = get_target_name(
50
58
  {k.decode("latin1"): v.decode("latin1") for k, v in scope.get("headers", {})}
51
59
  )
60
+ if not target:
61
+ raise InvalidTargetError()
62
+
52
63
  return f"http://{target.lower()}"
53
64
 
54
- def _get_proxy_domain(self):
55
- settings = get_settings()
56
- return (
57
- settings.proxy.domain
58
- if settings.proxy.domain[0] == "."
59
- else f".{settings.proxy.domain}"
60
- )
65
+ async def send_error_response(
66
+ self,
67
+ scope: Scope,
68
+ send: ASGISend,
69
+ http_status: int,
70
+ body: str,
71
+ headers: list[tuple[bytes, bytes]] | None = None,
72
+ ):
73
+ request_headers = {
74
+ k.decode("latin1"): v.decode("latin1") for k, v in scope.get("headers", {})
75
+ }
76
+ accept_header = request_headers.get("accept")
77
+ if not (accept_header and str(http_status) in self._templates_by_error):
78
+ return await super().send_error_response(scope, send, http_status, body)
61
79
 
62
- def _get_target_from_header(self, headers: dict[str, str], name: str) -> str | None:
63
- header_value = headers.get(name, "")
64
- if self._proxy_domain in header_value:
65
- if ":" in header_value:
66
- header_value, _ = header_value.split(":", 1)
67
- return header_value[: -len(self._proxy_domain)]
80
+ available_templates = self._templates_by_error[str(http_status)]
68
81
 
69
- def _get_target_name(self, headers: dict[str, str]) -> str:
70
- target = self._get_target_from_header(headers, "x-forwarded-host")
71
- if not target:
72
- target = self._get_target_from_header(headers, "host")
73
- if not target or not RE_SUBDOMAIN.fullmatch(target):
74
- raise InvalidTargetError()
75
- return target
82
+ media_types = parse_accept_header(accept_header)
83
+ for media_type in media_types:
84
+ template_format = ERROR_TEMPLATE_FORMATS.get(media_type)
85
+ if template_format and template_format in available_templates:
86
+ template_path = available_templates[template_format]
87
+ rendered = await self._render_error_template(
88
+ Path(template_path), scope, http_status, body
89
+ )
90
+ return await super().send_error_response(
91
+ scope,
92
+ send,
93
+ http_status,
94
+ rendered,
95
+ headers=[(b"content-type", media_type.encode("latin-1"))],
96
+ )
97
+
98
+ return await super().send_error_response(scope, send, http_status, body)
99
+
100
async def _render_error_template(
    self, template_path: Path, scope: Scope, http_status: int, body: str
) -> str:
    """Render the Jinja2 error template at ``template_path`` and return the text.

    The template receives:
      * ``status``: the numeric HTTP status code.
      * ``status_title``: a human-readable title such as ``"Bad Gateway"``.
      * ``body``: the plain error message the proxy would otherwise send.
      * ``request``: the dict built by ``_extract_request_context``.
    """
    env = self._get_jinja_env(template_path)
    template = env.get_template(template_path.name)
    try:
        status_title = HTTPStatus(http_status).name.replace("_", " ").title()
    except ValueError:
        # Codes with no HTTPStatus member (e.g. 499, 599) would otherwise make
        # rendering the error page itself fail; use a generic title instead.
        status_title = f"HTTP {http_status}"
    context = {
        "status": http_status,
        "status_title": status_title,
        "body": body,
        "request": self._extract_request_context(scope),
    }

    return await template.render_async(context)
114
+
115
def _extract_request_context(self, scope: Scope) -> dict[str, Any]:
    """Collect request details from the ASGI ``scope`` for error templates.

    The returned dict contains:
      * ``method``, ``path``, ``scheme``, ``client``, ``server`` and
        ``http_version``, copied as-is from the scope.
      * ``raw_path`` and ``query_string``, decoded as latin-1.
      * ``headers``, decoded as latin-1. For repeated header names only the
        last value is kept.

    NOTE(review): ``headers`` includes every request header (e.g. Cookie,
    Authorization); templates should avoid echoing them back to the client.
    """
    headers = {
        name.decode("latin-1"): value.decode("latin-1")
        for name, value in scope.get("headers", [])
    }

    return {
        "method": scope.get("method"),
        "path": scope.get("path"),
        # ASGI lets servers omit raw_path or set it to None; treat both as
        # empty instead of failing on ``None.decode``.
        "raw_path": (scope.get("raw_path") or b"").decode("latin-1"),
        "query_string": (scope.get("query_string") or b"").decode("latin-1"),
        "scheme": scope.get("scheme"),
        "headers": headers,
        "client": scope.get("client"),
        "server": scope.get("server"),
        "http_version": scope.get("http_version"),
    }
129
+
130
def _get_jinja_env(self, template_path: Path) -> Environment:
    """Return the cached async Jinja2 environment for the template's directory.

    There is one environment per directory. Autoescaping is enabled for
    ``.html`` and ``.xml`` templates only, so other formats (e.g. JSON) are
    rendered without escaping.
    """
    template_dir = template_path.parent
    env = self._jinja_env_cache.get(template_dir)
    if env is None:
        env = Environment(
            loader=FileSystemLoader(str(template_dir)),
            autoescape=select_autoescape(
                enabled_extensions=("html", "xml"),
                default_for_string=False,
            ),
            enable_async=True,
        )
        self._jinja_env_cache[template_dir] = env
    return env
mrok/frontend/main.py CHANGED
@@ -7,6 +7,7 @@ from uvicorn_worker import UvicornWorker
7
7
 
8
8
  from mrok.conf import get_settings
9
9
  from mrok.frontend.app import FrontendProxyApp
10
+ from mrok.frontend.middleware import HealthCheckMiddleware
10
11
  from mrok.logging import get_logging_config
11
12
 
12
13
 
@@ -42,11 +43,13 @@ def run(
42
43
  max_keepalive_connections: int | None,
43
44
  keepalive_expiry: float | None,
44
45
  ):
45
- app = FrontendProxyApp(
46
- str(identity_file),
47
- max_connections=max_connections,
48
- max_keepalive_connections=max_keepalive_connections,
49
- keepalive_expiry=keepalive_expiry,
46
+ app = HealthCheckMiddleware(
47
+ FrontendProxyApp(
48
+ str(identity_file),
49
+ max_connections=max_connections,
50
+ max_keepalive_connections=max_keepalive_connections,
51
+ keepalive_expiry=keepalive_expiry,
52
+ )
50
53
  )
51
54
 
52
55
  options = {
@@ -0,0 +1,35 @@
1
+ import json
2
+
3
+ from mrok.frontend.utils import get_target_name
4
+ from mrok.types.proxy import ASGIApp, ASGIReceive, ASGISend, Scope
5
+
6
+
7
class HealthCheckMiddleware:
    """ASGI middleware answering ``/healthcheck`` for the frontend itself.

    An HTTP request for ``/healthcheck`` whose Host / X-Forwarded-Host does
    not resolve to a proxy target is answered directly with a 200 JSON body.
    Every other request, including ``/healthcheck`` on a target host, is
    forwarded to the wrapped app.
    """

    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: ASGIReceive, send: ASGISend):
        # Non-HTTP scopes (e.g. lifespan) have no "path"; check the type first.
        if scope["type"] != "http" or scope["path"] != "/healthcheck":
            await self.app(scope, receive, send)
            return

        request_headers = {
            name.decode("latin1"): value.decode("latin1")
            for name, value in scope.get("headers", {})
        }
        if get_target_name(request_headers):
            # A target's own /healthcheck must reach the target.
            await self.app(scope, receive, send)
            return

        start_message = {
            "type": "http.response.start",
            "status": 200,
            "headers": [
                [b"content-type", b"application/json"],
            ],
        }
        await send(start_message)
        body_message = {
            "type": "http.response.body",
            "body": json.dumps({"status": "healthy"}).encode("utf-8"),
        }
        await send(body_message)
mrok/frontend/utils.py ADDED
@@ -0,0 +1,83 @@
1
+ import re
2
+
3
+ from mrok.conf import get_settings
4
+
5
+
6
+ def parse_accept_header(accept: str | None) -> list[str]:
7
+ if not accept:
8
+ return ["*/*"]
9
+
10
+ result: list[tuple[str, float, int]] = []
11
+
12
+ for index, item in enumerate(accept.split(",")):
13
+ item = item.strip()
14
+ if not item:
15
+ continue
16
+
17
+ parts = [p.strip() for p in item.split(";")]
18
+ media_type = parts[0].lower()
19
+
20
+ q = 1.0
21
+ for param in parts[1:]:
22
+ if param.startswith("q="):
23
+ try:
24
+ q = float(param[2:])
25
+ except ValueError:
26
+ q = 0.0
27
+
28
+ result.append((media_type, q, index))
29
+
30
+ # Sort by:
31
+ # 1) q value (desc)
32
+ # 2) specificity (more specific first)
33
+ # 3) original order (stable)
34
+ result.sort(
35
+ key=lambda x: (
36
+ -x[1],
37
+ -_media_type_specificity(x[0]),
38
+ x[2],
39
+ )
40
+ )
41
+
42
+ return [media_type for media_type, _, _ in result]
43
+
44
+
45
+ def _media_type_specificity(media_type: str) -> int:
46
+ if media_type == "*/*":
47
+ return 0
48
+ if media_type.endswith("/*"):
49
+ return 1
50
+ return 2
51
+
52
+
53
def get_frontend_domain():
    """Return the configured frontend domain, normalized to start with ``"."``.

    Raises:
        ValueError: if the ``frontend.domain`` setting is empty. Previously
            this surfaced as an opaque IndexError from ``domain[0]``.
    """
    domain = get_settings().frontend.domain
    if not domain:
        raise ValueError("frontend.domain setting must not be empty")
    return domain if domain.startswith(".") else f".{domain}"
60
+
61
+
62
+ def _get_target_from_header(headers: dict[str, str], name: str) -> str | None:
63
+ domain_name = get_frontend_domain()
64
+ header_value = headers.get(name, "")
65
+ if domain_name in header_value:
66
+ if ":" in header_value:
67
+ header_value, _ = header_value.split(":", 1)
68
+ return header_value[: -len(domain_name)]
69
+
70
+
71
def get_target_name(headers: dict[str, str]) -> str | None:
    """Resolve the proxy target (extension or instance id) for a request.

    ``X-Forwarded-Host`` is consulted first. ``Host`` is used only when the
    former yields nothing. The candidate is accepted only if it fully matches
    the configured extension or instance identifier regex.

    Args:
        headers: Request headers keyed by lower-cased name.

    Returns:
        The matching target label, or ``None``.
    """
    settings = get_settings()

    candidate = _get_target_from_header(
        headers, "x-forwarded-host"
    ) or _get_target_from_header(headers, "host")
    if not candidate:
        return None

    identifiers = settings.identifiers
    is_known_target = re.fullmatch(
        identifiers.extension.regex, candidate
    ) or re.fullmatch(identifiers.instance.regex, candidate)
    return candidate if is_known_target else None
mrok/logging.py CHANGED
@@ -1,12 +1,14 @@
1
+ import logging
1
2
  import logging.config
2
3
 
3
- from rich.console import Console
4
- from rich.logging import RichHandler
5
- from textual_serve.server import LogHighlighter
6
-
7
4
  from mrok.conf import Settings
8
5
 
9
6
 
7
class HealthCheckFilter(logging.Filter):
    """Logging filter that suppresses records mentioning ``/healthcheck``.

    The check runs against the fully formatted message, so a path passed as
    a %-style argument is caught as well.
    """

    def filter(self, record):
        rendered = record.getMessage()
        return rendered.find("/healthcheck") == -1
10
+
11
+
10
12
  def get_logging_config(settings: Settings, cli_mode: bool = False) -> dict:
11
13
  log_level = "DEBUG" if settings.logging.debug else "INFO"
12
14
  handler = "rich" if settings.logging.rich else "console"
@@ -30,6 +32,11 @@ def get_logging_config(settings: Settings, cli_mode: bool = False) -> dict:
30
32
  },
31
33
  "plain": {"format": "%(message)s"},
32
34
  },
35
+ "filters": {
36
+ "healthcheck_filter": {
37
+ "()": HealthCheckFilter,
38
+ }
39
+ },
33
40
  "handlers": {
34
41
  "console": {
35
42
  "class": "logging.StreamHandler",
@@ -58,17 +65,30 @@ def get_logging_config(settings: Settings, cli_mode: bool = False) -> dict:
58
65
  "handlers": [handler],
59
66
  "level": log_level,
60
67
  "propagate": False,
68
+ "filters": ["healthcheck_filter"],
61
69
  },
62
70
  "gunicorn.error": {
63
71
  "handlers": [handler],
64
72
  "level": log_level,
65
73
  "propagate": False,
66
74
  },
75
+ "uvicorn.access": {
76
+ "handlers": [handler],
77
+ "level": log_level,
78
+ "propagate": False,
79
+ "filters": ["healthcheck_filter"],
80
+ },
67
81
  "mrok": {
68
82
  "handlers": [mrok_handler],
69
83
  "level": log_level,
70
84
  "propagate": False,
71
85
  },
86
+ "mrok.access": {
87
+ "handlers": [mrok_handler],
88
+ "level": log_level,
89
+ "propagate": False,
90
+ "filters": ["healthcheck_filter"],
91
+ },
72
92
  },
73
93
  }
74
94
 
@@ -78,21 +98,3 @@ def get_logging_config(settings: Settings, cli_mode: bool = False) -> dict:
78
98
  def setup_logging(settings: Settings, cli_mode: bool = False) -> None:
79
99
  logging_config = get_logging_config(settings, cli_mode)
80
100
  logging.config.dictConfig(logging_config)
81
-
82
-
83
- def setup_inspector_logging(console: Console) -> None:
84
- logging.basicConfig(
85
- level="WARNING",
86
- format="%(message)s",
87
- datefmt="%Y-%m-%d %H:%M:%S",
88
- handlers=[
89
- RichHandler(
90
- show_path=False,
91
- show_time=True,
92
- rich_tracebacks=True,
93
- tracebacks_show_locals=True,
94
- highlighter=LogHighlighter(),
95
- console=console,
96
- )
97
- ],
98
- )
mrok/proxy/app.py CHANGED
@@ -57,7 +57,7 @@ class ProxyAppBase(abc.ABC):
57
57
  return
58
58
 
59
59
  if scope.get("type") != "http":
60
- await self._send_error(send, 500, "Unsupported")
60
+ await self.send_error_response(scope, send, 500, "Unsupported")
61
61
  return
62
62
 
63
63
  try:
@@ -105,15 +105,23 @@ class ProxyAppBase(abc.ABC):
105
105
  await response.aclose()
106
106
 
107
107
  except ProxyError as pe:
108
- await self._send_error(send, pe.http_status, pe.message)
108
+ await self.send_error_response(scope, send, pe.http_status, pe.message)
109
109
 
110
110
  except Exception:
111
111
  logger.exception("Unexpected error in forwarder")
112
- await self._send_error(send, 502, "Bad Gateway")
112
+ await self.send_error_response(scope, send, 502, "Bad Gateway")
113
113
 
114
- async def _send_error(self, send: ASGISend, http_status: int, body: str):
114
+ async def send_error_response(
115
+ self,
116
+ scope: Scope,
117
+ send: ASGISend,
118
+ http_status: int,
119
+ body: str,
120
+ headers: list[tuple[bytes, bytes]] | None = None,
121
+ ):
122
+ headers = headers or [(b"content-type", b"text/plain")]
115
123
  try:
116
- await send({"type": "http.response.start", "status": http_status, "headers": []})
124
+ await send({"type": "http.response.start", "status": http_status, "headers": headers})
117
125
  await send({"type": "http.response.body", "body": body.encode()})
118
126
  except Exception as e: # pragma: no cover
119
127
  logger.error(f"Cannot send error response: {e}")
mrok/proxy/middleware.py CHANGED
@@ -2,10 +2,8 @@ import asyncio
2
2
  import logging
3
3
  import time
4
4
 
5
- from mrok.proxy.constants import MAX_REQUEST_BODY_BYTES, MAX_RESPONSE_BODY_BYTES
6
5
  from mrok.proxy.metrics import MetricsCollector
7
6
  from mrok.proxy.models import FixedSizeByteBuffer, HTTPHeaders, HTTPRequest, HTTPResponse
8
- from mrok.proxy.utils import must_capture_request, must_capture_response
9
7
  from mrok.types.proxy import (
10
8
  ASGIApp,
11
9
  ASGIReceive,
@@ -15,6 +13,9 @@ from mrok.types.proxy import (
15
13
  Scope,
16
14
  )
17
15
 
16
+ MAX_REQUEST_BODY_BYTES = 2 * 1024 * 1024
17
+ MAX_RESPONSE_BODY_BYTES = 5 * 1024 * 1024
18
+
18
19
  logger = logging.getLogger("mrok.proxy")
19
20
 
20
21
 
@@ -42,7 +43,7 @@ class CaptureMiddleware:
42
43
  state = {}
43
44
 
44
45
  req_buf = FixedSizeByteBuffer(MAX_REQUEST_BODY_BYTES)
45
- capture_req_body = must_capture_request(method, req_headers)
46
+ capture_req_body = method.upper() not in ("GET", "HEAD", "OPTIONS", "TRACE")
46
47
 
47
48
  request = HTTPRequest(
48
49
  method=method,
@@ -68,9 +69,7 @@ class CaptureMiddleware:
68
69
  resp_headers = HTTPHeaders.from_asgi(msg.get("headers", []))
69
70
  state["resp_headers_raw"] = resp_headers
70
71
 
71
- state["capture_resp_body"] = must_capture_response(resp_headers)
72
-
73
- if state["capture_resp_body"] and msg["type"] == "http.response.body":
72
+ if msg["type"] == "http.response.body":
74
73
  body = msg.get("body", b"")
75
74
  resp_buf.write(body)
76
75
 
@@ -91,8 +90,8 @@ class CaptureMiddleware:
91
90
  status=state["status"] or 0,
92
91
  headers=state["resp_headers_raw"],
93
92
  duration=duration,
94
- body=resp_buf.getvalue() if state["capture_resp_body"] else None,
95
- body_truncated=resp_buf.overflow if state["capture_resp_body"] else None,
93
+ body=resp_buf.getvalue(),
94
+ body_truncated=resp_buf.overflow,
96
95
  )
97
96
  asyncio.create_task(self._on_response_complete(response))
98
97
 
mrok/proxy/models.py CHANGED
@@ -1,13 +1,40 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import base64
4
+ import binascii
3
5
  import json
6
+ from collections.abc import Mapping
4
7
  from pathlib import Path
5
- from typing import Literal
8
+ from typing import Annotated, Any, Literal
6
9
 
7
10
  from pydantic import BaseModel, ConfigDict, Field, field_validator
11
+ from pydantic.functional_serializers import PlainSerializer
12
+ from pydantic.functional_validators import PlainValidator
8
13
  from pydantic_core import core_schema
9
14
 
10
15
 
16
def serialize_b64(v: bytes) -> str:
    """Encode ``v`` with the standard base64 alphabet and return ASCII text."""
    encoded = base64.standard_b64encode(v)
    return encoded.decode("ascii")
18
+
19
+
20
def deserialize_b64(v):
    """Coerce ``v`` into ``bytes`` for a ``Base64Bytes`` field.

    Accepted inputs:
      * ``bytes`` is returned unchanged.
      * ``bytearray`` / ``memoryview`` is copied into immutable ``bytes``.
        Previously these were rejected with TypeError.
      * ``str`` is decoded as strict base64 (``validate=True``), so characters
        outside the base64 alphabet are rejected rather than silently skipped.

    Raises:
        ValueError: if a string is not valid base64.
        TypeError: for any other input type.
    """
    if isinstance(v, bytes):
        return v
    if isinstance(v, (bytearray, memoryview)):
        return bytes(v)
    if isinstance(v, str):
        try:
            return base64.b64decode(v, validate=True)
        except binascii.Error as e:  # pragma: no branch
            raise ValueError("Invalid base64 data") from e
    raise TypeError("Expected bytes or base64 string")  # pragma: no cover
29
+
30
+
31
# ``bytes`` field type for pydantic models. It validates from raw bytes or a
# base64 string (via ``deserialize_b64``) and always serializes to a base64
# ASCII string (via ``serialize_b64``), so binary payloads survive JSON.
Base64Bytes = Annotated[
    bytes,
    PlainValidator(func=deserialize_b64),
    PlainSerializer(func=serialize_b64, return_type=str),
]
36
+
37
+
11
38
  class X509Credentials(BaseModel):
12
39
  key: str
13
40
  cert: str
@@ -16,14 +43,13 @@ class X509Credentials(BaseModel):
16
43
  @field_validator("key", "cert", "ca", mode="before")
17
44
  @classmethod
18
45
  def strip_pem_prefix(cls, value: str) -> str:
19
- if isinstance(value, str) and value.startswith("pem:"):
46
+ if isinstance(value, str) and value.startswith("pem:"): # pragma: no branch
20
47
  return value[4:]
21
- return value
48
+ return value # pragma: no cover
22
49
 
23
50
 
24
51
  class ServiceMetadata(BaseModel):
25
52
  model_config = ConfigDict(extra="ignore")
26
- identity: str
27
53
  extension: str
28
54
  instance: str
29
55
  domain: str | None = None
@@ -34,10 +60,10 @@ class Identity(BaseModel):
34
60
  model_config = ConfigDict(extra="ignore")
35
61
  zt_api: str = Field(validation_alias="ztAPI")
36
62
  id: X509Credentials
63
+ mrok: ServiceMetadata
64
+ enable_ha: bool = Field(default=False, validation_alias="enableHa")
37
65
  zt_apis: str | None = Field(default=None, validation_alias="ztAPIs")
38
66
  config_types: str | None = Field(default=None, validation_alias="configTypes")
39
- enable_ha: bool = Field(default=False, validation_alias="enableHa")
40
- mrok: ServiceMetadata | None = None
41
67
 
42
68
  @staticmethod
43
69
  def load_from_file(path: str | Path) -> Identity:
@@ -77,7 +103,7 @@ class FixedSizeByteBuffer:
77
103
 
78
104
 
79
105
  class HTTPHeaders(dict):
80
- def __init__(self, initial=None):
106
+ def __init__(self, initial: Mapping[Any, Any] | None = None):
81
107
  super().__init__()
82
108
  if initial:
83
109
  for k, v in initial.items():
@@ -131,9 +157,9 @@ class HTTPRequest(BaseModel):
131
157
  method: str
132
158
  url: str
133
159
  headers: HTTPHeaders
134
- query_string: bytes
160
+ query_string: Base64Bytes | None = None
135
161
  start_time: float
136
- body: bytes | None = None
162
+ body: Base64Bytes | None = None
137
163
  body_truncated: bool | None = None
138
164
 
139
165
 
@@ -143,7 +169,7 @@ class HTTPResponse(BaseModel):
143
169
  status: int
144
170
  headers: HTTPHeaders
145
171
  duration: float
146
- body: bytes | None = None
172
+ body: Base64Bytes | None = None
147
173
  body_truncated: bool | None = None
148
174
 
149
175
 
mrok/proxy/ziticorn.py CHANGED
@@ -1,4 +1,3 @@
1
- import json
2
1
  import logging
3
2
  import socket
4
3
  from collections.abc import Callable
@@ -10,6 +9,7 @@ from uvicorn import config, server
10
9
  from uvicorn.lifespan.on import LifespanOn
11
10
  from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol as UvHttpToolsProtocol
12
11
 
12
+ from mrok.proxy.models import Identity
13
13
  from mrok.types.proxy import ASGIApp
14
14
 
15
15
  logger = logging.getLogger("mrok.proxy")
@@ -48,10 +48,8 @@ class BackendConfig(config.Config):
48
48
  backlog: int = 2048,
49
49
  ):
50
50
  self.identity_file = identity_file
51
+ self.identity = Identity.load_from_file(self.identity_file)
51
52
  self.ziti_load_timeout_ms = ziti_load_timeout_ms
52
- self.service_name, self.identity_name, self.instance_id = self.get_identity_info(
53
- identity_file
54
- )
55
53
  super().__init__(
56
54
  app,
57
55
  loop="asyncio",
@@ -59,26 +57,19 @@ class BackendConfig(config.Config):
59
57
  backlog=backlog,
60
58
  )
61
59
 
62
- def get_identity_info(self, identity_file: str | Path):
63
- with open(identity_file) as f:
64
- identity_data = json.load(f)
65
- try:
66
- identity_name = identity_data["mrok"]["identity"]
67
- instance_id, service_name = identity_name.split(".", 1)
68
- return service_name, identity_name, instance_id
69
- except KeyError:
70
- raise ValueError("Invalid identity file: identity file is not mrok compatible.")
71
-
72
60
  def bind_socket(self) -> socket.socket:
73
- logger.info(f"Connect to Ziti service '{self.service_name} ({self.instance_id})'")
61
+ logger.info(
62
+ "Connect to Ziti service "
63
+ f"'{self.identity.mrok.extension} ({self.identity.mrok.instance})'"
64
+ )
74
65
 
75
66
  ctx, err = openziti.load(str(self.identity_file), timeout=self.ziti_load_timeout_ms)
76
67
  if err != 0:
77
68
  raise RuntimeError(f"Failed to load Ziti identity from {self.identity_file}: {err}")
78
69
 
79
- sock = ctx.bind(self.service_name)
70
+ sock = ctx.bind(self.identity.mrok.extension)
80
71
  sock.listen(self.backlog)
81
- logger.info(f"listening on ziti service {self.service_name} for connections")
72
+ logger.info(f"listening on ziti service {self.identity.mrok.extension} for connections")
82
73
  return sock
83
74
 
84
75
  def configure_logging(self) -> None: