mrok 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. mrok/agent/devtools/inspector/__main__.py +2 -24
  2. mrok/agent/devtools/inspector/app.py +408 -113
  3. mrok/agent/devtools/inspector/utils.py +149 -0
  4. mrok/agent/sidecar/app.py +9 -9
  5. mrok/agent/sidecar/main.py +31 -5
  6. mrok/agent/ziticorn.py +8 -2
  7. mrok/cli/commands/admin/bootstrap.py +3 -2
  8. mrok/cli/commands/admin/utils.py +2 -2
  9. mrok/cli/commands/agent/run/sidecar.py +59 -1
  10. mrok/cli/commands/frontend/run.py +43 -1
  11. mrok/cli/main.py +17 -1
  12. mrok/constants.py +21 -0
  13. mrok/controller/schemas.py +2 -2
  14. mrok/frontend/app.py +8 -8
  15. mrok/frontend/main.py +9 -1
  16. mrok/logging.py +0 -22
  17. mrok/proxy/app.py +10 -9
  18. mrok/proxy/asgi.py +96 -0
  19. mrok/proxy/backend.py +5 -3
  20. mrok/proxy/event_publisher.py +66 -0
  21. mrok/proxy/master.py +18 -60
  22. mrok/proxy/metrics.py +2 -2
  23. mrok/proxy/{middlewares.py → middleware.py} +11 -42
  24. mrok/proxy/{datastructures.py → models.py} +43 -17
  25. mrok/proxy/{streams.py → stream.py} +24 -1
  26. mrok/proxy/worker.py +64 -0
  27. mrok/proxy/ziticorn.py +76 -0
  28. mrok/types/__init__.py +0 -0
  29. mrok/{proxy/types.py → types/proxy.py} +7 -2
  30. mrok/types/ziti.py +1 -0
  31. mrok/ziti/api.py +16 -19
  32. mrok/ziti/bootstrap.py +3 -7
  33. mrok/ziti/identities.py +15 -13
  34. mrok/ziti/services.py +3 -2
  35. {mrok-0.5.0.dist-info → mrok-0.7.0.dist-info}/METADATA +8 -2
  36. {mrok-0.5.0.dist-info → mrok-0.7.0.dist-info}/RECORD +39 -40
  37. mrok/agent/devtools/__main__.py +0 -34
  38. mrok/cli/commands/agent/utils.py +0 -5
  39. mrok/proxy/config.py +0 -62
  40. mrok/proxy/constants.py +0 -22
  41. mrok/proxy/lifespan.py +0 -10
  42. mrok/proxy/protocol.py +0 -11
  43. mrok/proxy/server.py +0 -14
  44. mrok/proxy/utils.py +0 -90
  45. {mrok-0.5.0.dist-info → mrok-0.7.0.dist-info}/WHEEL +0 -0
  46. {mrok-0.5.0.dist-info → mrok-0.7.0.dist-info}/entry_points.txt +0 -0
  47. {mrok-0.5.0.dist-info → mrok-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
mrok/proxy/asgi.py ADDED
@@ -0,0 +1,96 @@
1
+ from collections.abc import Iterator
2
+ from contextlib import AsyncExitStack, asynccontextmanager
3
+ from typing import Any, ParamSpec, Protocol
4
+
5
+ from mrok.types.proxy import ASGIApp, ASGIReceive, ASGISend, Lifespan, Scope
6
+
7
+ P = ParamSpec("P")
8
+
9
+
10
class ASGIMiddleware(Protocol[P]):
    """Structural type of a middleware factory.

    Any callable that takes the ASGI app to wrap (positional-only) plus extra
    positional/keyword arguments, and returns a new ASGI app, matches.
    """

    def __call__(self, app: ASGIApp, /, *args: P.args, **kwargs: P.kwargs) -> ASGIApp:  # pragma: no cover
        ...
14
+
15
+
16
class Middleware:
    """A middleware factory bundled with the arguments to instantiate it with.

    Unpacks as ``(cls, args, kwargs)`` so it can be consumed with
    ``for cls, args, kwargs in middlewares``.
    """

    def __init__(self, cls: ASGIMiddleware[P], *args: P.args, **kwargs: P.kwargs) -> None:
        self.cls = cls
        self.args = args
        self.kwargs = kwargs

    def __iter__(self) -> Iterator[Any]:
        return iter((self.cls, self.args, self.kwargs))
25
+
26
+
27
class ASGIAppWrapper:
    """Wrap an ASGI app with a middleware stack and a merged lifespan handler.

    * Non-lifespan scopes go through the middleware stack. The stack is built
      once from the registered middleware and cached. The most recently added
      middleware ends up outermost.
    * Lifespan scopes are handled here. The optional ``lifespan`` passed to the
      constructor is entered first, then the wrapped app's Starlette-style
      ``router.lifespan_context`` (if any). The state dicts they yield are
      merged into ``scope["state"]``.
    """

    def __init__(
        self,
        app: ASGIApp,
        lifespan: Lifespan | None = None,
    ) -> None:
        self.app = app
        self.lifespan = lifespan
        self.middleware: list[Middleware] = []
        # Built lazily on the first call; reset whenever middleware is added.
        self.middleware_stack: ASGIApp | None = None

    @property
    def middlware(self) -> list[Middleware]:
        # Backward-compatible alias for the original (misspelled) attribute name.
        return self.middleware

    @middlware.setter
    def middlware(self, value: list[Middleware]) -> None:
        self.middleware = value

    def add_middleware(self, cls: ASGIMiddleware[P], *args: P.args, **kwargs: P.kwargs):
        """Register ``cls(app, *args, **kwargs)`` as the new outermost middleware."""
        # Insert at the front: build_middleware_stack wraps in reversed order,
        # so the last registered middleware becomes the outermost layer.
        self.middleware.insert(0, Middleware(cls, *args, **kwargs))
        # Invalidate a previously built stack so the new middleware takes effect.
        self.middleware_stack = None

    def build_middleware_stack(self):
        """Return the wrapped app with every registered middleware applied."""
        app = self.app
        for cls, args, kwargs in reversed(self.middleware):
            app = cls(app, *args, **kwargs)
        return app

    def get_starlette_lifespan(self):
        """Return ``self.app.router.lifespan_context`` if present, else ``None``."""
        router = getattr(self.app, "router", None)
        if router is None:
            return None
        return getattr(router, "lifespan_context", None)

    @asynccontextmanager
    async def merge_lifespan(self, app: ASGIApp):
        """Enter the extra lifespan, then the Starlette lifespan, yielding merged state.

        Both are exited in reverse order when the context ends. Keys yielded by
        the Starlette lifespan override those of the extra lifespan.
        """
        async with AsyncExitStack() as stack:
            state: dict[Any, Any] = {}
            if self.lifespan is not None:
                outer_state = await stack.enter_async_context(self.lifespan(app))
                state.update(outer_state or {})
            starlette_lifespan = self.get_starlette_lifespan()
            if starlette_lifespan is not None:
                # Framework lifespans expect the framework app itself (they commonly
                # use ``app.state``), not this wrapper, which is what ``app`` is here.
                inner_state = await stack.enter_async_context(starlette_lifespan(self.app))
                state.update(inner_state or {})
            yield state

    async def handle_lifespan(self, scope: Scope, receive: ASGIReceive, send: ASGISend) -> None:
        """Run the ASGI lifespan protocol around :meth:`merge_lifespan`.

        Waits for the startup message, enters the merged lifespan, and publishes
        its state into ``scope["state"]``. It then reports startup complete,
        waits for the next (shutdown) message, and exits the lifespan. On
        failure it reports ``startup.failed`` or ``shutdown.failed`` and
        re-raises.
        """
        started = False
        app: Any = scope.get("app")
        await receive()
        try:
            async with self.merge_lifespan(app) as state:
                if state:
                    if "state" not in scope:
                        raise RuntimeError('"state" is unsupported by the current ASGI Server.')
                    scope["state"].update(state)
                await send({"type": "lifespan.startup.complete"})
                started = True
                await receive()
        except Exception as e:  # pragma: no cover
            if started:
                await send({"type": "lifespan.shutdown.failed", "message": str(e)})
            else:
                await send({"type": "lifespan.startup.failed", "message": str(e)})
            raise
        else:
            await send({"type": "lifespan.shutdown.complete"})

    async def __call__(self, scope: Scope, receive: ASGIReceive, send: ASGISend) -> None:
        # Build the middleware stack only once (the attribute name used here
        # must match the one initialised in __init__).
        if self.middleware_stack is None:  # pragma: no branch
            self.middleware_stack = self.build_middleware_stack()
        if scope["type"] == "lifespan":
            scope["app"] = self
            await self.handle_lifespan(scope, receive, send)
            return

        await self.middleware_stack(scope, receive, send)
mrok/proxy/backend.py CHANGED
@@ -6,8 +6,8 @@ import openziti
6
6
  from httpcore import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
7
7
  from openziti.context import ZitiContext
8
8
 
9
- from mrok.proxy.exceptions import TargetUnavailableError
10
- from mrok.proxy.streams import AIONetworkStream
9
+ from mrok.proxy.exceptions import InvalidTargetError, TargetUnavailableError
10
+ from mrok.proxy.stream import AIONetworkStream
11
11
 
12
12
 
13
13
  class AIOZitiNetworkBackend(AsyncNetworkBackend):
@@ -37,7 +37,9 @@ class AIOZitiNetworkBackend(AsyncNetworkBackend):
37
37
  reader, writer = await asyncio.open_connection(sock=sock)
38
38
  return AIONetworkStream(reader, writer)
39
39
  except Exception as e:
40
- raise TargetUnavailableError() from e
40
+ if e.args and e.args[0] == -24: # the service exists but is not available
41
+ raise TargetUnavailableError() from e
42
+ raise InvalidTargetError() from e
41
43
 
42
44
  async def sleep(self, seconds: float) -> None:
43
45
  await asyncio.sleep(seconds)
@@ -0,0 +1,66 @@
1
+ import asyncio
2
+ import contextlib
3
+ import logging
4
+
5
+ import zmq
6
+ import zmq.asyncio
7
+
8
+ from mrok.proxy.asgi import ASGIAppWrapper
9
+ from mrok.proxy.metrics import MetricsCollector
10
+ from mrok.proxy.middleware import CaptureMiddleware, MetricsMiddleware
11
+ from mrok.proxy.models import Event, HTTPResponse, ServiceMetadata, Status
12
+ from mrok.types.proxy import ASGIApp
13
+
14
+ logger = logging.getLogger("mrok.proxy")
15
+
16
+
17
class EventPublisher:
    """Publish a worker's status and HTTP response events over a ZeroMQ PUB socket.

    The socket connects to ``tcp://localhost:<event_publisher_port>``. While the
    lifespan is active, a background task sends a ``status`` event carrying a
    metrics snapshot every ``metrics_interval`` seconds. ``response`` events
    are sent through :meth:`publish_response_event`, which
    :meth:`setup_middleware` wires into ``CaptureMiddleware``.
    """

    def __init__(
        self,
        worker_id: str,
        meta: ServiceMetadata | None = None,
        event_publisher_port: int = 50000,
        metrics_interval: float = 5.0,
    ):
        """
        Args:
            worker_id: Worker identifier, passed to the ``MetricsCollector``.
            meta: Service metadata included in every status event.
            event_publisher_port: Local TCP port the PUB socket connects to.
            metrics_interval: Seconds to wait between two status events.
        """
        self._worker_id = worker_id
        self._meta = meta
        self._metrics_interval = metrics_interval
        self.publisher_port = event_publisher_port
        self._zmq_ctx = zmq.asyncio.Context()
        self._publisher = self._zmq_ctx.socket(zmq.PUB)
        self._metrics_collector = MetricsCollector(self._worker_id)
        self._publish_task: asyncio.Task | None = None

    async def on_startup(self):
        """Connect the PUB socket and start the periodic status publisher."""
        self._publisher.connect(f"tcp://localhost:{self.publisher_port}")
        self._publish_task = asyncio.create_task(self.publish_metrics_event())
        logger.info("Events publishing for worker %s started", self._worker_id)

    async def on_shutdown(self):
        """Stop the status publisher (if running) and release ZeroMQ resources."""
        try:
            # on_startup may never have run (e.g. a failed lifespan startup).
            if self._publish_task is not None:
                self._publish_task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await self._publish_task
                self._publish_task = None
        finally:
            # linger=0 discards still-queued messages, so closing and term() do not
            # wait on deliveries that may never happen (e.g. the subscriber is gone).
            self._publisher.close(linger=0)
            self._zmq_ctx.term()
        logger.info("Events publishing for worker %s stopped", self._worker_id)

    async def publish_metrics_event(self):
        """Send a ``status`` event every ``metrics_interval`` seconds until cancelled."""
        while True:
            try:
                snap = await self._metrics_collector.snapshot()
                event = Event(type="status", data=Status(meta=self._meta, metrics=snap))
                await self._publisher.send_string(event.model_dump_json())
            except Exception:
                # A single failed snapshot/send must not silently end status
                # reporting for the rest of the worker's life. CancelledError is
                # a BaseException, so cancellation still stops the loop.
                logger.exception("Failed to publish status event for worker %s", self._worker_id)
            await asyncio.sleep(self._metrics_interval)

    async def publish_response_event(self, response: HTTPResponse):
        """Send a ``response`` event for a captured HTTP response."""
        event = Event(type="response", data=response)
        await self._publisher.send_string(event.model_dump_json())  # type: ignore[attr-defined]

    def setup_middleware(self, app: ASGIAppWrapper):
        """Register the capture and metrics middleware on ``app``."""
        app.add_middleware(CaptureMiddleware, self.publish_response_event)
        app.add_middleware(MetricsMiddleware, self._metrics_collector)  # type: ignore

    @contextlib.asynccontextmanager
    async def lifespan(self, app: ASGIApp):
        """Lifespan context: start publishing on enter, always clean up on exit."""
        await self.on_startup()  # type: ignore
        try:
            yield
        finally:
            # Runs even when a later lifespan (or the app) fails, so the task and
            # socket are not leaked.
            await self.on_shutdown()  # type: ignore
mrok/proxy/master.py CHANGED
@@ -1,5 +1,3 @@
1
- import asyncio
2
- import contextlib
3
1
  import logging
4
2
  import os
5
3
  import signal
@@ -10,20 +8,14 @@ from pathlib import Path
10
8
  from uuid import uuid4
11
9
 
12
10
  import zmq
13
- import zmq.asyncio
14
- from uvicorn.importer import import_from_string
15
11
  from watchfiles import watch
16
12
  from watchfiles.filters import PythonFilter
17
13
  from watchfiles.run import CombinedProcess, start_process
18
14
 
19
15
  from mrok.conf import get_settings
20
16
  from mrok.logging import setup_logging
21
- from mrok.proxy.config import MrokBackendConfig
22
- from mrok.proxy.datastructures import Event, HTTPResponse, Status, ZitiIdentity
23
- from mrok.proxy.metrics import WorkerMetricsCollector
24
- from mrok.proxy.middlewares import CaptureMiddleware, LifespanMiddleware, MetricsMiddleware
25
- from mrok.proxy.server import MrokServer
26
- from mrok.proxy.types import ASGIApp
17
+ from mrok.proxy.worker import Worker
18
+ from mrok.types.proxy import ASGIApp
27
19
 
28
20
  logger = logging.getLogger("mrok.agent")
29
21
 
@@ -62,76 +54,41 @@ def start_uvicorn_worker(
62
54
  worker_id: str,
63
55
  app: ASGIApp | str,
64
56
  identity_file: str,
57
+ events_enabled: bool,
65
58
  events_pub_port: int,
66
59
  metrics_interval: float = 5.0,
67
60
  ):
68
61
  import sys
69
62
 
70
63
  sys.path.insert(0, os.getcwd())
71
- asgi_app = app if not isinstance(app, str) else import_from_string(app)
72
64
 
73
- setup_logging(get_settings())
74
- identity = ZitiIdentity.load_from_file(identity_file)
75
- ctx = zmq.asyncio.Context()
76
- pub = ctx.socket(zmq.PUB)
77
- pub.connect(f"tcp://localhost:{events_pub_port}")
78
- metrics = WorkerMetricsCollector(worker_id)
79
-
80
- task = None
81
-
82
- async def status_sender(): # pragma: no cover
83
- while True:
84
- snap = await metrics.snapshot()
85
- event = Event(type="status", data=Status(meta=identity.mrok, metrics=snap))
86
- await pub.send_string(event.model_dump_json())
87
- await asyncio.sleep(metrics_interval)
88
-
89
- async def on_startup(): # pragma: no cover
90
- nonlocal task
91
- await asyncio.sleep(0)
92
- task = asyncio.create_task(status_sender())
93
-
94
- async def on_shutdown(): # pragma: no cover
95
- await asyncio.sleep(0)
96
- if task:
97
- task.cancel()
98
-
99
- async def on_response_complete(response: HTTPResponse): # pragma: no cover
100
- event = Event(type="response", data=response)
101
- await pub.send_string(event.model_dump_json())
102
-
103
- config = MrokBackendConfig(
104
- LifespanMiddleware(
105
- MetricsMiddleware(
106
- CaptureMiddleware(
107
- asgi_app,
108
- on_response_complete,
109
- ),
110
- metrics,
111
- ),
112
- on_startup=on_startup,
113
- on_shutdown=on_shutdown,
114
- ),
65
+ worker = Worker(
66
+ worker_id,
67
+ app,
115
68
  identity_file,
69
+ events_enabled=events_enabled,
70
+ event_publisher_port=events_pub_port,
71
+ metrics_interval=metrics_interval,
116
72
  )
117
- server = MrokServer(config)
118
- with contextlib.suppress(KeyboardInterrupt, asyncio.CancelledError):
119
- server.run()
73
+ worker.run()
120
74
 
121
75
 
122
76
  class MasterBase(ABC):
123
77
  def __init__(
124
78
  self,
125
79
  identity_file: str,
126
- workers: int,
127
- reload: bool,
128
- events_pub_port: int,
129
- events_sub_port: int,
80
+ *,
81
+ workers: int = 4,
82
+ reload: bool = False,
83
+ events_enabled: bool = True,
84
+ events_pub_port: int = 50000,
85
+ events_sub_port: int = 50001,
130
86
  metrics_interval: float = 5.0,
131
87
  ):
132
88
  self.identity_file = identity_file
133
89
  self.workers = workers
134
90
  self.reload = reload
91
+ self.events_enabled = events_enabled
135
92
  self.events_pub_port = events_pub_port
136
93
  self.events_sub_port = events_sub_port
137
94
  self.metrics_interval = metrics_interval
@@ -171,6 +128,7 @@ class MasterBase(ABC):
171
128
  worker_id,
172
129
  self.get_asgi_app(),
173
130
  self.identity_file,
131
+ self.events_enabled,
174
132
  self.events_pub_port,
175
133
  self.metrics_interval,
176
134
  ),
mrok/proxy/metrics.py CHANGED
@@ -8,7 +8,7 @@ import time
8
8
  import psutil
9
9
  from hdrh.histogram import HdrHistogram
10
10
 
11
- from mrok.proxy.datastructures import (
11
+ from mrok.proxy.models import (
12
12
  DataTransferMetrics,
13
13
  ProcessMetrics,
14
14
  RequestsMetrics,
@@ -49,7 +49,7 @@ async def get_process_metrics(interval: float = 0.1) -> ProcessMetrics:
49
49
  return await asyncio.to_thread(_collect_process_usage, interval)
50
50
 
51
51
 
52
- class WorkerMetricsCollector:
52
+ class MetricsCollector:
53
53
  def __init__(self, worker_id: str, lowest=1, highest=60000, sigfigs=3):
54
54
  self.worker_id = worker_id
55
55
  self.total_requests = 0
@@ -2,19 +2,19 @@ import asyncio
2
2
  import logging
3
3
  import time
4
4
 
5
- from mrok.proxy.constants import MAX_REQUEST_BODY_BYTES, MAX_RESPONSE_BODY_BYTES
6
- from mrok.proxy.datastructures import FixedSizeByteBuffer, HTTPHeaders, HTTPRequest, HTTPResponse
7
- from mrok.proxy.metrics import WorkerMetricsCollector
8
- from mrok.proxy.types import (
5
+ from mrok.proxy.metrics import MetricsCollector
6
+ from mrok.proxy.models import FixedSizeByteBuffer, HTTPHeaders, HTTPRequest, HTTPResponse
7
+ from mrok.types.proxy import (
9
8
  ASGIApp,
10
9
  ASGIReceive,
11
10
  ASGISend,
12
- LifespanCallback,
13
11
  Message,
14
12
  ResponseCompleteCallback,
15
13
  Scope,
16
14
  )
17
- from mrok.proxy.utils import must_capture_request, must_capture_response
15
+
16
+ MAX_REQUEST_BODY_BYTES = 2 * 1024 * 1024
17
+ MAX_RESPONSE_BODY_BYTES = 5 * 1024 * 1024
18
18
 
19
19
  logger = logging.getLogger("mrok.proxy")
20
20
 
@@ -43,7 +43,7 @@ class CaptureMiddleware:
43
43
  state = {}
44
44
 
45
45
  req_buf = FixedSizeByteBuffer(MAX_REQUEST_BODY_BYTES)
46
- capture_req_body = must_capture_request(method, req_headers)
46
+ capture_req_body = method.upper() not in ("GET", "HEAD", "OPTIONS", "TRACE")
47
47
 
48
48
  request = HTTPRequest(
49
49
  method=method,
@@ -69,9 +69,7 @@ class CaptureMiddleware:
69
69
  resp_headers = HTTPHeaders.from_asgi(msg.get("headers", []))
70
70
  state["resp_headers_raw"] = resp_headers
71
71
 
72
- state["capture_resp_body"] = must_capture_response(resp_headers)
73
-
74
- if state["capture_resp_body"] and msg["type"] == "http.response.body":
72
+ if msg["type"] == "http.response.body":
75
73
  body = msg.get("body", b"")
76
74
  resp_buf.write(body)
77
75
 
@@ -92,14 +90,14 @@ class CaptureMiddleware:
92
90
  status=state["status"] or 0,
93
91
  headers=state["resp_headers_raw"],
94
92
  duration=duration,
95
- body=resp_buf.getvalue() if state["capture_resp_body"] else None,
96
- body_truncated=resp_buf.overflow if state["capture_resp_body"] else None,
93
+ body=resp_buf.getvalue(),
94
+ body_truncated=resp_buf.overflow,
97
95
  )
98
96
  asyncio.create_task(self._on_response_complete(response))
99
97
 
100
98
 
101
99
  class MetricsMiddleware:
102
- def __init__(self, app: ASGIApp, metrics: WorkerMetricsCollector):
100
+ def __init__(self, app: ASGIApp, metrics: MetricsCollector):
103
101
  self.app = app
104
102
  self.metrics = metrics
105
103
 
@@ -133,32 +131,3 @@ class MetricsMiddleware:
133
131
  await self.app(scope, wrapped_receive, wrapped_send)
134
132
  finally:
135
133
  await self.metrics.on_request_end(start_time, status_code)
136
-
137
-
138
- class LifespanMiddleware:
139
- def __init__(
140
- self,
141
- app,
142
- on_startup: LifespanCallback | None = None,
143
- on_shutdown: LifespanCallback | None = None,
144
- ):
145
- self.app = app
146
- self.on_startup = on_startup
147
- self.on_shutdown = on_shutdown
148
-
149
- async def __call__(self, scope, receive, send):
150
- if scope["type"] == "lifespan":
151
- while True:
152
- event = await receive()
153
- if event["type"] == "lifespan.startup":
154
- if self.on_startup: # pragma: no branch
155
- await self.on_startup()
156
- await send({"type": "lifespan.startup.complete"})
157
-
158
- elif event["type"] == "lifespan.shutdown":
159
- if self.on_shutdown: # pragma: no branch
160
- await self.on_shutdown()
161
- await send({"type": "lifespan.shutdown.complete"})
162
- break
163
- else:
164
- await self.app(scope, receive, send)
@@ -1,14 +1,41 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import base64
4
+ import binascii
3
5
  import json
6
+ from collections.abc import Mapping
4
7
  from pathlib import Path
5
- from typing import Literal
8
+ from typing import Annotated, Any, Literal
6
9
 
7
10
  from pydantic import BaseModel, ConfigDict, Field, field_validator
11
+ from pydantic.functional_serializers import PlainSerializer
12
+ from pydantic.functional_validators import PlainValidator
8
13
  from pydantic_core import core_schema
9
14
 
10
15
 
11
- class ZitiId(BaseModel):
16
def serialize_b64(v: bytes) -> str:
    """Encode ``v`` as standard base64 and return it as an ASCII ``str``."""
    encoded = base64.b64encode(v)
    return str(encoded, "ascii")
18
+
19
+
20
class _Base64InputTypeError(TypeError, ValueError):
    """Raised by ``deserialize_b64`` for input that is neither bytes-like nor ``str``.

    It subclasses ``TypeError`` so existing ``except TypeError`` handlers keep
    working. It also subclasses ``ValueError`` so pydantic reports it as a
    ``ValidationError``; pydantic v2 does not convert a plain ``TypeError``
    raised by a validator.
    """


def deserialize_b64(v):
    """Validate a ``Base64Bytes`` value.

    Args:
        v: ``bytes`` (returned unchanged), a ``bytearray``/``memoryview``
            (copied into ``bytes``), or a base64 ``str`` (decoded strictly).

    Returns:
        The value as ``bytes``.

    Raises:
        ValueError: ``v`` is a string that is not valid base64.
        TypeError: ``v`` has any other type (the exception is also a ``ValueError``).
    """
    if isinstance(v, bytes):
        return v
    if isinstance(v, (bytearray, memoryview)):
        return bytes(v)
    if isinstance(v, str):
        try:
            return base64.b64decode(v, validate=True)
        except binascii.Error as e:  # pragma: no branch
            raise ValueError("Invalid base64 data") from e
    raise _Base64InputTypeError("Expected bytes or base64 string")  # pragma: no cover
29
+
30
+
31
# ``bytes`` field type that accepts raw bytes or a base64 string when validating
# (via deserialize_b64) and is serialized to a base64 ``str`` (via serialize_b64).
Base64Bytes = Annotated[
    bytes,
    PlainValidator(deserialize_b64),  # input: bytes passthrough / base64 str decode
    PlainSerializer(serialize_b64, return_type=str),  # output: base64 text
]
36
+
37
+
38
+ class X509Credentials(BaseModel):
12
39
  key: str
13
40
  cert: str
14
41
  ca: str
@@ -16,35 +43,34 @@ class ZitiId(BaseModel):
16
43
  @field_validator("key", "cert", "ca", mode="before")
17
44
  @classmethod
18
45
  def strip_pem_prefix(cls, value: str) -> str:
19
- if isinstance(value, str) and value.startswith("pem:"):
46
+ if isinstance(value, str) and value.startswith("pem:"): # pragma: no branch
20
47
  return value[4:]
21
- return value
48
+ return value # pragma: no cover
22
49
 
23
50
 
24
- class ZitiMrokMeta(BaseModel):
51
class ServiceMetadata(BaseModel):
    # mrok service metadata read from the ``mrok`` key of an identity file;
    # keys not declared below are ignored rather than rejected.
    model_config = ConfigDict(extra="ignore")

    # Required identifiers.
    extension: str
    instance: str
    # Optional extras.
    domain: str | None = None
    tags: dict[str, str | bool | None] | None = None
32
58
 
33
- class ZitiIdentity(BaseModel):
59
class Identity(BaseModel):
    # Parsed identity JSON file. Aliases map the file's camelCase keys;
    # unknown keys are ignored.
    model_config = ConfigDict(extra="ignore")

    zt_api: str = Field(validation_alias="ztAPI")
    id: X509Credentials
    mrok: ServiceMetadata
    enable_ha: bool = Field(default=False, validation_alias="enableHa")
    zt_apis: str | None = Field(default=None, validation_alias="ztAPIs")
    config_types: str | None = Field(default=None, validation_alias="configTypes")

    @staticmethod
    def load_from_file(path: str | Path) -> Identity:
        # Read the UTF-8 JSON document at ``path`` and validate it into an Identity.
        document = json.loads(Path(path).read_text(encoding="utf-8"))
        return Identity.model_validate(document)
48
74
 
49
75
 
50
76
  class FixedSizeByteBuffer:
@@ -77,7 +103,7 @@ class FixedSizeByteBuffer:
77
103
 
78
104
 
79
105
  class HTTPHeaders(dict):
80
- def __init__(self, initial=None):
106
+ def __init__(self, initial: Mapping[Any, Any] | None = None):
81
107
  super().__init__()
82
108
  if initial:
83
109
  for k, v in initial.items():
@@ -131,9 +157,9 @@ class HTTPRequest(BaseModel):
131
157
  method: str
132
158
  url: str
133
159
  headers: HTTPHeaders
134
- query_string: bytes
160
+ query_string: Base64Bytes | None = None
135
161
  start_time: float
136
- body: bytes | None = None
162
+ body: Base64Bytes | None = None
137
163
  body_truncated: bool | None = None
138
164
 
139
165
 
@@ -143,7 +169,7 @@ class HTTPResponse(BaseModel):
143
169
  status: int
144
170
  headers: HTTPHeaders
145
171
  duration: float
146
- body: bytes | None = None
172
+ body: Base64Bytes | None = None
147
173
  body_truncated: bool | None = None
148
174
 
149
175
 
@@ -183,7 +209,7 @@ class WorkerMetrics(BaseModel):
183
209
 
184
210
class Status(BaseModel):
    # Payload of a "status" event: the service metadata plus a metrics snapshot.
    type: Literal["status"] = "status"  # discriminator value
    meta: ServiceMetadata
    metrics: WorkerMetrics
188
214
 
189
215
 
@@ -1,8 +1,24 @@
1
1
  import asyncio
2
+ import select
3
+ import sys
4
+ from typing import Any
2
5
 
3
6
  from httpcore import AsyncNetworkStream
4
7
 
5
- from mrok.proxy.types import ASGIReceive
8
+ from mrok.types.proxy import ASGIReceive
9
+
10
+
11
def is_readable(sock):  # pragma: no cover
    """Return True if ``sock`` has pending input (or a hang-up) without blocking.

    Adapted from
    https://github.com/python-trio/trio/blob/20ee2b1b7376db637435d80e266212a35837ddcc/trio/_socket.py#L471C1-L478C31
    """
    # poll() everywhere it exists; Windows only offers select().
    if sys.platform != "win32":
        poller = select.poll()
        poller.register(sock, select.POLLIN)
        return bool(poller.poll(0))
    readable, _, _ = select.select([sock], [], [], 0)
    return bool(readable)
6
22
 
7
23
 
8
24
  class AIONetworkStream(AsyncNetworkStream):
@@ -21,6 +37,13 @@ class AIONetworkStream(AsyncNetworkStream):
21
37
  self._writer.close()
22
38
  await self._writer.wait_closed()
23
39
 
40
def get_extra_info(self, info: str) -> Any:
    """Return transport extra info for ``info``.

    ``"is_readable"`` is answered by polling the underlying socket. Presumably
    this is queried by httpcore to detect server-closed idle connections;
    verify against the httpcore version in use. Every other key is delegated
    to the writer's transport.
    """
    transport = self._writer.transport
    if info != "is_readable":
        return transport.get_extra_info(info)
    return is_readable(transport.get_extra_info("socket"))
46
+
24
47
 
25
48
  class ASGIRequestBodyStream:
26
49
  def __init__(self, receive: ASGIReceive):
mrok/proxy/worker.py ADDED
@@ -0,0 +1,64 @@
1
+ import asyncio
2
+ import contextlib
3
+ import logging
4
+ from pathlib import Path
5
+
6
+ from uvicorn.importer import import_from_string
7
+
8
+ from mrok.conf import get_settings
9
+ from mrok.logging import setup_logging
10
+ from mrok.proxy.asgi import ASGIAppWrapper
11
+ from mrok.proxy.event_publisher import EventPublisher
12
+ from mrok.proxy.models import Identity
13
+ from mrok.proxy.ziticorn import BackendConfig, Server
14
+ from mrok.types.proxy import ASGIApp
15
+
16
+ logger = logging.getLogger("mrok.proxy")
17
+
18
+
19
class Worker:
    """One proxy worker: load the identity, wrap the ASGI app, and serve it.

    When ``events_enabled`` is true, an ``EventPublisher`` supplies the extra
    lifespan and the capture/metrics middleware for the wrapped app.
    """

    def __init__(
        self,
        worker_id: str,
        app: ASGIApp | str,
        identity_file: str | Path,
        *,
        events_enabled: bool = True,
        event_publisher_port: int = 50000,
        metrics_interval: float = 5.0,
    ):
        self._worker_id = worker_id
        self._identity_file = identity_file
        self._identity = Identity.load_from_file(self._identity_file)
        self._app = app

        self._events_enabled = events_enabled
        self._event_publisher = None
        if events_enabled:
            self._event_publisher = EventPublisher(
                worker_id=worker_id,
                meta=self._identity.mrok,
                event_publisher_port=event_publisher_port,
                metrics_interval=metrics_interval,
            )

    def setup_app(self):
        """Resolve the app (importing it if given as a string) and wrap it."""
        target = import_from_string(self._app) if isinstance(self._app, str) else self._app
        extra_lifespan = self._event_publisher.lifespan if self._events_enabled else None
        wrapped = ASGIAppWrapper(target, lifespan=extra_lifespan)

        if self._events_enabled:
            self._event_publisher.setup_middleware(wrapped)
        return wrapped

    def run(self):
        """Configure logging, build the app, and run the server until it stops."""
        setup_logging(get_settings())
        wrapped = self.setup_app()

        server = Server(BackendConfig(wrapped, self._identity_file))
        # Interrupts and cancellation are treated as a normal shutdown.
        with contextlib.suppress(KeyboardInterrupt, asyncio.CancelledError):
            server.run()