mrok 0.4.5__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. mrok/agent/devtools/inspector/app.py +2 -2
  2. mrok/agent/sidecar/app.py +64 -16
  3. mrok/agent/sidecar/main.py +5 -5
  4. mrok/agent/ziticorn.py +2 -2
  5. mrok/cli/commands/__init__.py +2 -2
  6. mrok/cli/commands/admin/register/extensions.py +2 -4
  7. mrok/cli/commands/admin/register/instances.py +11 -5
  8. mrok/cli/commands/{proxy → frontend}/__init__.py +1 -1
  9. mrok/cli/commands/{proxy → frontend}/run.py +4 -4
  10. mrok/constants.py +4 -0
  11. mrok/controller/openapi/examples.py +13 -0
  12. mrok/frontend/__init__.py +3 -0
  13. mrok/frontend/app.py +75 -0
  14. mrok/{proxy → frontend}/main.py +4 -10
  15. mrok/proxy/__init__.py +0 -3
  16. mrok/proxy/app.py +160 -57
  17. mrok/proxy/backend.py +43 -0
  18. mrok/{http → proxy}/config.py +3 -3
  19. mrok/{datastructures.py → proxy/datastructures.py} +43 -10
  20. mrok/proxy/exceptions.py +22 -0
  21. mrok/proxy/lifespan.py +10 -0
  22. mrok/{master.py → proxy/master.py} +35 -38
  23. mrok/{metrics.py → proxy/metrics.py} +37 -49
  24. mrok/{http → proxy}/middlewares.py +47 -26
  25. mrok/proxy/streams.py +45 -0
  26. mrok/proxy/types.py +15 -0
  27. mrok/{http → proxy}/utils.py +1 -1
  28. {mrok-0.4.5.dist-info → mrok-0.5.0.dist-info}/METADATA +2 -5
  29. {mrok-0.4.5.dist-info → mrok-0.5.0.dist-info}/RECORD +35 -31
  30. mrok/http/__init__.py +0 -0
  31. mrok/http/forwarder.py +0 -338
  32. mrok/http/lifespan.py +0 -39
  33. mrok/http/types.py +0 -43
  34. /mrok/{http → proxy}/constants.py +0 -0
  35. /mrok/{http → proxy}/protocol.py +0 -0
  36. /mrok/{http → proxy}/server.py +0 -0
  37. {mrok-0.4.5.dist-info → mrok-0.5.0.dist-info}/WHEEL +0 -0
  38. {mrok-0.4.5.dist-info → mrok-0.5.0.dist-info}/entry_points.txt +0 -0
  39. {mrok-0.4.5.dist-info → mrok-0.5.0.dist-info}/licenses/LICENSE.txt +0 -0
mrok/proxy/backend.py ADDED
@@ -0,0 +1,43 @@
1
+ import asyncio
2
+ from collections.abc import Iterable
3
+ from pathlib import Path
4
+
5
+ import openziti
6
+ from httpcore import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
7
+ from openziti.context import ZitiContext
8
+
9
+ from mrok.proxy.exceptions import TargetUnavailableError
10
+ from mrok.proxy.streams import AIONetworkStream
11
+
12
+
13
+ class AIOZitiNetworkBackend(AsyncNetworkBackend):
14
+ def __init__(self, identity_file: str | Path) -> None:
15
+ self._identity_file = identity_file
16
+ self._ziti_ctx: ZitiContext | None = None
17
+
18
+ def _get_ziti_ctx(self) -> ZitiContext:
19
+ if self._ziti_ctx is None:
20
+ ctx, err = openziti.load(str(self._identity_file), timeout=10_000)
21
+ if err != 0:
22
+ raise Exception(f"Cannot create a Ziti context from the identity file: {err}")
23
+ self._ziti_ctx = ctx
24
+ return self._ziti_ctx
25
+
26
+ async def connect_tcp(
27
+ self,
28
+ host: str,
29
+ port: int,
30
+ timeout: float | None = None,
31
+ local_address: str | None = None,
32
+ socket_options: Iterable[SOCKET_OPTION] | None = None,
33
+ ) -> AsyncNetworkStream:
34
+ ctx = self._get_ziti_ctx()
35
+ try:
36
+ sock = ctx.connect(host)
37
+ reader, writer = await asyncio.open_connection(sock=sock)
38
+ return AIONetworkStream(reader, writer)
39
+ except Exception as e:
40
+ raise TargetUnavailableError() from e
41
+
42
+ async def sleep(self, seconds: float) -> None:
43
+ await asyncio.sleep(seconds)
mrok/{http → proxy}/config.py
@@ -8,12 +8,12 @@ from typing import Any
8
8
  import openziti
9
9
  from uvicorn import config
10
10
 
11
- from mrok.http.protocol import MrokHttpToolsProtocol
12
- from mrok.http.types import ASGIApp
11
+ from mrok.proxy.protocol import MrokHttpToolsProtocol
12
+ from mrok.proxy.types import ASGIApp
13
13
 
14
14
  logger = logging.getLogger("mrok.proxy")
15
15
 
16
- config.LIFESPAN["auto"] = "mrok.http.lifespan:MrokLifespan"
16
+ config.LIFESPAN["auto"] = "mrok.proxy.lifespan:MrokLifespan"
17
17
 
18
18
 
19
19
  class MrokBackendConfig(config.Config):
mrok/{datastructures.py → proxy/datastructures.py}
@@ -1,11 +1,52 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import json
4
+ from pathlib import Path
3
5
  from typing import Literal
4
6
 
5
- from pydantic import BaseModel, Field
7
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
6
8
  from pydantic_core import core_schema
7
9
 
8
10
 
11
+ class ZitiId(BaseModel):
12
+ key: str
13
+ cert: str
14
+ ca: str
15
+
16
+ @field_validator("key", "cert", "ca", mode="before")
17
+ @classmethod
18
+ def strip_pem_prefix(cls, value: str) -> str:
19
+ if isinstance(value, str) and value.startswith("pem:"):
20
+ return value[4:]
21
+ return value
22
+
23
+
24
+ class ZitiMrokMeta(BaseModel):
25
+ model_config = ConfigDict(extra="ignore")
26
+ identity: str
27
+ extension: str
28
+ instance: str
29
+ domain: str | None = None
30
+ tags: dict[str, str | bool | None] | None = None
31
+
32
+
33
+ class ZitiIdentity(BaseModel):
34
+ model_config = ConfigDict(extra="ignore")
35
+ zt_api: str = Field(validation_alias="ztAPI")
36
+ id: ZitiId
37
+ zt_apis: str | None = Field(default=None, validation_alias="ztAPIs")
38
+ config_types: str | None = Field(default=None, validation_alias="configTypes")
39
+ enable_ha: bool = Field(default=False, validation_alias="enableHa")
40
+ mrok: ZitiMrokMeta | None = None
41
+
42
+ @staticmethod
43
+ def load_from_file(path: str | Path) -> ZitiIdentity:
44
+ path = Path(path)
45
+ with path.open("r", encoding="utf-8") as f:
46
+ data = json.load(f)
47
+ return ZitiIdentity.model_validate(data)
48
+
49
+
9
50
  class FixedSizeByteBuffer:
10
51
  def __init__(self, max_size: int):
11
52
  self._max_size = max_size
@@ -140,17 +181,9 @@ class WorkerMetrics(BaseModel):
140
181
  process: ProcessMetrics
141
182
 
142
183
 
143
- class Meta(BaseModel):
144
- identity: str
145
- extension: str
146
- instance: str
147
- domain: str
148
- tags: dict[str, str] | None = None
149
-
150
-
151
184
  class Status(BaseModel):
152
185
  type: Literal["status"] = "status"
153
- meta: Meta
186
+ meta: ZitiMrokMeta
154
187
  metrics: WorkerMetrics
155
188
 
156
189
 
mrok/proxy/exceptions.py ADDED
@@ -0,0 +1,22 @@
1
+ from http import HTTPStatus
2
+
3
+ from httpcore import ConnectError
4
+
5
+
6
+ class ProxyError(Exception):
7
+ def __init__(self, http_status: HTTPStatus, message: str) -> None:
8
+ self.http_status: HTTPStatus = http_status
9
+ self.message: str = message
10
+
11
+
12
+ class InvalidTargetError(ProxyError):
13
+ def __init__(self):
14
+ super().__init__(HTTPStatus.BAD_GATEWAY, "Bad Gateway: invalid target extension.")
15
+
16
+
17
+ class TargetUnavailableError(ProxyError, ConnectError):
18
+ def __init__(self):
19
+ super().__init__(
20
+ HTTPStatus.SERVICE_UNAVAILABLE,
21
+ "Service Unavailable: the target extension is unavailable.",
22
+ )
mrok/proxy/lifespan.py ADDED
@@ -0,0 +1,10 @@
1
+ import logging
2
+
3
+ from uvicorn.config import Config
4
+ from uvicorn.lifespan.on import LifespanOn
5
+
6
+
7
+ class MrokLifespan(LifespanOn):
8
+ def __init__(self, config: Config) -> None:
9
+ super().__init__(config)
10
+ self.logger = logging.getLogger("mrok.proxy")
mrok/{master.py → proxy/master.py}
@@ -1,6 +1,5 @@
1
1
  import asyncio
2
2
  import contextlib
3
- import json
4
3
  import logging
5
4
  import os
6
5
  import signal
@@ -18,14 +17,13 @@ from watchfiles.filters import PythonFilter
18
17
  from watchfiles.run import CombinedProcess, start_process
19
18
 
20
19
  from mrok.conf import get_settings
21
- from mrok.datastructures import Event, HTTPResponse, Meta, Status
22
- from mrok.http.config import MrokBackendConfig
23
- from mrok.http.lifespan import LifespanWrapper
24
- from mrok.http.middlewares import CaptureMiddleware, MetricsMiddleware
25
- from mrok.http.server import MrokServer
26
- from mrok.http.types import ASGIApp
27
20
  from mrok.logging import setup_logging
28
- from mrok.metrics import WorkerMetricsCollector
21
+ from mrok.proxy.config import MrokBackendConfig
22
+ from mrok.proxy.datastructures import Event, HTTPResponse, Status, ZitiIdentity
23
+ from mrok.proxy.metrics import WorkerMetricsCollector
24
+ from mrok.proxy.middlewares import CaptureMiddleware, LifespanMiddleware, MetricsMiddleware
25
+ from mrok.proxy.server import MrokServer
26
+ from mrok.proxy.types import ASGIApp
29
27
 
30
28
  logger = logging.getLogger("mrok.agent")
31
29
 
@@ -52,7 +50,7 @@ def start_events_router(events_pub_port: int, events_sub_port: int):
52
50
  try:
53
51
  logger.info(f"Events router process started: {os.getpid()}")
54
52
  zmq.proxy(frontend, backend)
55
- except KeyboardInterrupt:
53
+ except KeyboardInterrupt: # pragma: no cover
56
54
  pass
57
55
  finally:
58
56
  frontend.close()
@@ -62,7 +60,7 @@ def start_events_router(events_pub_port: int, events_sub_port: int):
62
60
 
63
61
  def start_uvicorn_worker(
64
62
  worker_id: str,
65
- app: ASGIApp,
63
+ app: ASGIApp | str,
66
64
  identity_file: str,
67
65
  events_pub_port: int,
68
66
  metrics_interval: float = 5.0,
@@ -70,12 +68,10 @@ def start_uvicorn_worker(
70
68
  import sys
71
69
 
72
70
  sys.path.insert(0, os.getcwd())
73
- if isinstance(app, str):
74
- app = import_from_string(app)
71
+ asgi_app = app if not isinstance(app, str) else import_from_string(app)
75
72
 
76
73
  setup_logging(get_settings())
77
- identity = json.load(open(identity_file))
78
- meta = Meta(**identity["mrok"])
74
+ identity = ZitiIdentity.load_from_file(identity_file)
79
75
  ctx = zmq.asyncio.Context()
80
76
  pub = ctx.socket(zmq.PUB)
81
77
  pub.connect(f"tcp://localhost:{events_pub_port}")
@@ -83,30 +79,32 @@ def start_uvicorn_worker(
83
79
 
84
80
  task = None
85
81
 
86
- async def status_sender():
82
+ async def status_sender(): # pragma: no cover
87
83
  while True:
88
84
  snap = await metrics.snapshot()
89
- event = Event(type="status", data=Status(meta=meta, metrics=snap))
85
+ event = Event(type="status", data=Status(meta=identity.mrok, metrics=snap))
90
86
  await pub.send_string(event.model_dump_json())
91
87
  await asyncio.sleep(metrics_interval)
92
88
 
93
- async def on_startup(): # noqa
89
+ async def on_startup(): # pragma: no cover
94
90
  nonlocal task
91
+ await asyncio.sleep(0)
95
92
  task = asyncio.create_task(status_sender())
96
93
 
97
- async def on_shutdown(): # noqa
94
+ async def on_shutdown(): # pragma: no cover
95
+ await asyncio.sleep(0)
98
96
  if task:
99
97
  task.cancel()
100
98
 
101
- async def on_response_complete(response: HTTPResponse):
99
+ async def on_response_complete(response: HTTPResponse): # pragma: no cover
102
100
  event = Event(type="response", data=response)
103
101
  await pub.send_string(event.model_dump_json())
104
102
 
105
103
  config = MrokBackendConfig(
106
- LifespanWrapper(
104
+ LifespanMiddleware(
107
105
  MetricsMiddleware(
108
106
  CaptureMiddleware(
109
- app,
107
+ asgi_app,
110
108
  on_response_complete,
111
109
  ),
112
110
  metrics,
@@ -154,7 +152,7 @@ class MasterBase(ABC):
154
152
 
155
153
  @abstractmethod
156
154
  def get_asgi_app(self):
157
- pass
155
+ raise NotImplementedError()
158
156
 
159
157
  def setup_signals_handler(self):
160
158
  for sig in (signal.SIGINT, signal.SIGTERM):
@@ -181,18 +179,6 @@ class MasterBase(ABC):
181
179
  logger.info(f"Worker {worker_id} [{p.pid}] started")
182
180
  return p
183
181
 
184
- def start(self):
185
- self.start_events_router()
186
- self.start_workers()
187
- self.monitor_thread.start()
188
-
189
- def stop(self):
190
- if self.monitor_thread.is_alive():
191
- logger.debug("Wait for monitor worker to exit")
192
- self.monitor_thread.join(timeout=MONITOR_THREAD_JOIN_TIMEOUT)
193
- self.stop_workers()
194
- self.stop_events_router()
195
-
196
182
  def start_events_router(self):
197
183
  self.zmq_pubsub_router_process = start_process(
198
184
  start_events_router,
@@ -204,21 +190,32 @@ class MasterBase(ABC):
204
190
  None,
205
191
  )
206
192
 
207
- def stop_events_router(self):
208
- self.zmq_pubsub_router_process.stop(sigint_timeout=5, sigkill_timeout=1)
209
-
210
193
  def start_workers(self):
211
194
  for i in range(self.workers):
212
195
  worker_id = self.worker_identifiers[i]
213
196
  p = self.start_worker(worker_id)
214
197
  self.worker_processes[worker_id] = p
215
198
 
199
+ def start(self):
200
+ self.start_events_router()
201
+ self.start_workers()
202
+ self.monitor_thread.start()
203
+
204
+ def stop_events_router(self):
205
+ self.zmq_pubsub_router_process.stop(sigint_timeout=5, sigkill_timeout=1)
206
+
216
207
  def stop_workers(self):
217
208
  for process in self.worker_processes.values():
218
- if process.is_alive():
209
+ if process.is_alive(): # pragma: no branch
219
210
  process.stop(sigint_timeout=5, sigkill_timeout=1)
220
211
  self.worker_processes.clear()
221
212
 
213
+ def stop(self):
214
+ if self.monitor_thread.is_alive(): # pragma: no branch
215
+ self.monitor_thread.join(timeout=MONITOR_THREAD_JOIN_TIMEOUT)
216
+ self.stop_workers()
217
+ self.stop_events_router()
218
+
222
219
  def restart(self):
223
220
  self.pause_event.set()
224
221
  self.stop_workers()
mrok/{metrics.py → proxy/metrics.py}
@@ -8,7 +8,7 @@ import time
8
8
  import psutil
9
9
  from hdrh.histogram import HdrHistogram
10
10
 
11
- from mrok.datastructures import (
11
+ from mrok.proxy.datastructures import (
12
12
  DataTransferMetrics,
13
13
  ProcessMetrics,
14
14
  RequestsMetrics,
@@ -20,11 +20,7 @@ logger = logging.getLogger("mrok.proxy")
20
20
 
21
21
 
22
22
  def _collect_process_usage(interval: float) -> ProcessMetrics:
23
- try:
24
- proc = psutil.Process(os.getpid())
25
- except psutil.NoSuchProcess:
26
- return ProcessMetrics(cpu=0.0, mem=0.0)
27
-
23
+ proc = psutil.Process(os.getpid())
28
24
  total_cpu = 0.0
29
25
  total_mem = 0.0
30
26
 
@@ -33,29 +29,28 @@ def _collect_process_usage(interval: float) -> ProcessMetrics:
33
29
  except Exception:
34
30
  total_cpu = 0.0
35
31
 
36
- if interval and interval > 0:
32
+ if interval and interval > 0: # pragma: no branch
37
33
  time.sleep(interval)
38
34
 
39
35
  try:
40
36
  total_cpu = proc.cpu_percent(None)
41
- except Exception:
37
+ except Exception: # pragma: no cover
42
38
  total_cpu = 0.0
43
39
 
44
40
  try:
45
41
  total_mem = proc.memory_percent()
46
- except Exception:
42
+ except Exception: # pragma: no cover
47
43
  total_mem = 0.0
48
44
 
49
45
  return ProcessMetrics(cpu=total_cpu, mem=total_mem)
50
46
 
51
47
 
52
- async def get_process_and_children_usage(interval: float = 0.1) -> ProcessMetrics:
48
+ async def get_process_metrics(interval: float = 0.1) -> ProcessMetrics:
53
49
  return await asyncio.to_thread(_collect_process_usage, interval)
54
50
 
55
51
 
56
52
  class WorkerMetricsCollector:
57
53
  def __init__(self, worker_id: str, lowest=1, highest=60000, sigfigs=3):
58
- # Request-level counters
59
54
  self.worker_id = worker_id
60
55
  self.total_requests = 0
61
56
  self.successful_requests = 0
@@ -63,14 +58,11 @@ class WorkerMetricsCollector:
63
58
  self.bytes_in = 0
64
59
  self.bytes_out = 0
65
60
 
66
- # RPS
67
61
  self._tick_last = time.time()
68
62
  self._tick_requests = 0
69
63
 
70
- # latency histogram
71
64
  self.hist = HdrHistogram(lowest, highest, sigfigs)
72
65
 
73
- # async lock
74
66
  self._lock = asyncio.Lock()
75
67
 
76
68
  async def on_request_start(self, scope):
@@ -102,38 +94,34 @@ class WorkerMetricsCollector:
102
94
  self.hist.record_value(elapsed_ms)
103
95
 
104
96
  async def snapshot(self) -> WorkerMetrics:
105
- try:
106
- async with self._lock:
107
- now = time.time()
108
- delta = now - self._tick_last
109
- rps = int(self._tick_requests / delta) if delta > 0 else 0
110
- data = WorkerMetrics(
111
- worker_id=self.worker_id,
112
- process=await get_process_and_children_usage(),
113
- requests=RequestsMetrics(
114
- rps=rps,
115
- total=self.total_requests,
116
- successful=self.successful_requests,
117
- failed=self.failed_requests,
118
- ),
119
- data_transfer=DataTransferMetrics(
120
- bytes_in=self.bytes_in,
121
- bytes_out=self.bytes_out,
122
- ),
123
- response_time=ResponseTimeMetrics(
124
- avg=self.hist.get_mean_value(),
125
- min=self.hist.get_min_value(),
126
- max=self.hist.get_max_value(),
127
- p50=self.hist.get_value_at_percentile(50),
128
- p90=self.hist.get_value_at_percentile(90),
129
- p99=self.hist.get_value_at_percentile(99),
130
- ),
131
- )
132
-
133
- self._tick_last = now
134
- self._tick_requests = 0
135
-
136
- return data
137
- except Exception:
138
- logger.exception("Exception calculating snapshot")
139
- raise
97
+ async with self._lock:
98
+ now = time.time()
99
+ delta = now - self._tick_last
100
+ rps = int(self._tick_requests / delta) if delta > 0 else 0
101
+ data = WorkerMetrics(
102
+ worker_id=self.worker_id,
103
+ process=await get_process_metrics(),
104
+ requests=RequestsMetrics(
105
+ rps=rps,
106
+ total=self.total_requests,
107
+ successful=self.successful_requests,
108
+ failed=self.failed_requests,
109
+ ),
110
+ data_transfer=DataTransferMetrics(
111
+ bytes_in=self.bytes_in,
112
+ bytes_out=self.bytes_out,
113
+ ),
114
+ response_time=ResponseTimeMetrics(
115
+ avg=self.hist.get_mean_value(),
116
+ min=self.hist.get_min_value(),
117
+ max=self.hist.get_max_value(),
118
+ p50=self.hist.get_value_at_percentile(50),
119
+ p90=self.hist.get_value_at_percentile(90),
120
+ p99=self.hist.get_value_at_percentile(99),
121
+ ),
122
+ )
123
+
124
+ self._tick_last = now
125
+ self._tick_requests = 0
126
+
127
+ return data
mrok/{http → proxy}/middlewares.py
@@ -1,20 +1,23 @@
1
1
  import asyncio
2
- import inspect
3
2
  import logging
4
3
  import time
5
- from collections.abc import Callable, Coroutine
6
- from typing import Any
7
4
 
8
- from mrok.datastructures import FixedSizeByteBuffer, HTTPHeaders, HTTPRequest, HTTPResponse
9
- from mrok.http.constants import MAX_REQUEST_BODY_BYTES, MAX_RESPONSE_BODY_BYTES
10
- from mrok.http.types import ASGIApp, ASGIReceive, ASGISend, Message, Scope
11
- from mrok.http.utils import must_capture_request, must_capture_response
12
- from mrok.metrics import WorkerMetricsCollector
5
+ from mrok.proxy.constants import MAX_REQUEST_BODY_BYTES, MAX_RESPONSE_BODY_BYTES
6
+ from mrok.proxy.datastructures import FixedSizeByteBuffer, HTTPHeaders, HTTPRequest, HTTPResponse
7
+ from mrok.proxy.metrics import WorkerMetricsCollector
8
+ from mrok.proxy.types import (
9
+ ASGIApp,
10
+ ASGIReceive,
11
+ ASGISend,
12
+ LifespanCallback,
13
+ Message,
14
+ ResponseCompleteCallback,
15
+ Scope,
16
+ )
17
+ from mrok.proxy.utils import must_capture_request, must_capture_response
13
18
 
14
19
  logger = logging.getLogger("mrok.proxy")
15
20
 
16
- ResponseCompleteCallback = Callable[[HTTPResponse], Coroutine[Any, Any, None] | None]
17
-
18
21
 
19
22
  class CaptureMiddleware:
20
23
  def __init__(
@@ -92,22 +95,11 @@ class CaptureMiddleware:
92
95
  body=resp_buf.getvalue() if state["capture_resp_body"] else None,
93
96
  body_truncated=resp_buf.overflow if state["capture_resp_body"] else None,
94
97
  )
95
- await asyncio.create_task(self.handle_callback(response))
96
-
97
- async def handle_callback(self, response: HTTPResponse):
98
- try:
99
- if inspect.iscoroutinefunction(self._on_response_complete):
100
- await self._on_response_complete(response)
101
- else:
102
- await asyncio.get_running_loop().run_in_executor(
103
- None, self._on_response_complete, response
104
- )
105
- except Exception:
106
- logger.exception("Error invoking callback")
98
+ asyncio.create_task(self._on_response_complete(response))
107
99
 
108
100
 
109
101
  class MetricsMiddleware:
110
- def __init__(self, app, metrics: WorkerMetricsCollector):
102
+ def __init__(self, app: ASGIApp, metrics: WorkerMetricsCollector):
111
103
  self.app = app
112
104
  self.metrics = metrics
113
105
 
@@ -116,11 +108,11 @@ class MetricsMiddleware:
116
108
  return await self.app(scope, receive, send)
117
109
 
118
110
  start_time = await self.metrics.on_request_start(scope)
119
- status_code = 500 # default if app errors early
111
+ status_code = 500
120
112
 
121
113
  async def wrapped_receive():
122
114
  msg = await receive()
123
- if msg["type"] == "http.request" and msg.get("body"):
115
+ if msg["type"] == "http.request" and msg.get("body"): # pragma: no branch
124
116
  await self.metrics.on_request_body(len(msg["body"]))
125
117
  return msg
126
118
 
@@ -131,7 +123,7 @@ class MetricsMiddleware:
131
123
  status_code = msg["status"]
132
124
  await self.metrics.on_response_start(status_code)
133
125
 
134
- elif msg["type"] == "http.response.body":
126
+ elif msg["type"] == "http.response.body": # pragma: no branch
135
127
  body = msg.get("body", b"")
136
128
  await self.metrics.on_response_chunk(len(body))
137
129
 
@@ -141,3 +133,32 @@ class MetricsMiddleware:
141
133
  await self.app(scope, wrapped_receive, wrapped_send)
142
134
  finally:
143
135
  await self.metrics.on_request_end(start_time, status_code)
136
+
137
+
138
+ class LifespanMiddleware:
139
+ def __init__(
140
+ self,
141
+ app,
142
+ on_startup: LifespanCallback | None = None,
143
+ on_shutdown: LifespanCallback | None = None,
144
+ ):
145
+ self.app = app
146
+ self.on_startup = on_startup
147
+ self.on_shutdown = on_shutdown
148
+
149
+ async def __call__(self, scope, receive, send):
150
+ if scope["type"] == "lifespan":
151
+ while True:
152
+ event = await receive()
153
+ if event["type"] == "lifespan.startup":
154
+ if self.on_startup: # pragma: no branch
155
+ await self.on_startup()
156
+ await send({"type": "lifespan.startup.complete"})
157
+
158
+ elif event["type"] == "lifespan.shutdown":
159
+ if self.on_shutdown: # pragma: no branch
160
+ await self.on_shutdown()
161
+ await send({"type": "lifespan.shutdown.complete"})
162
+ break
163
+ else:
164
+ await self.app(scope, receive, send)
mrok/proxy/streams.py ADDED
@@ -0,0 +1,45 @@
1
+ import asyncio
2
+
3
+ from httpcore import AsyncNetworkStream
4
+
5
+ from mrok.proxy.types import ASGIReceive
6
+
7
+
8
+ class AIONetworkStream(AsyncNetworkStream):
9
+ def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
10
+ self._reader = reader
11
+ self._writer = writer
12
+
13
+ async def read(self, n: int, timeout: float | None = None) -> bytes:
14
+ return await asyncio.wait_for(self._reader.read(n), timeout)
15
+
16
+ async def write(self, data: bytes, timeout: float | None = None) -> None:
17
+ self._writer.write(data)
18
+ await asyncio.wait_for(self._writer.drain(), timeout)
19
+
20
+ async def aclose(self) -> None:
21
+ self._writer.close()
22
+ await self._writer.wait_closed()
23
+
24
+
25
+ class ASGIRequestBodyStream:
26
+ def __init__(self, receive: ASGIReceive):
27
+ self._receive = receive
28
+ self._more_body = True
29
+
30
+ def __aiter__(self):
31
+ return self
32
+
33
+ async def __anext__(self) -> bytes:
34
+ if not self._more_body:
35
+ raise StopAsyncIteration
36
+
37
+ msg = await self._receive()
38
+ if msg["type"] == "http.request":
39
+ chunk = msg.get("body", b"")
40
+ self._more_body = msg.get("more_body", False)
41
+ return chunk
42
+ elif msg["type"] == "http.disconnect":
43
+ raise Exception("Client disconnected.")
44
+
45
+ raise Exception("Unexpected asgi message.")
mrok/proxy/types.py ADDED
@@ -0,0 +1,15 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Awaitable, Callable, Coroutine, MutableMapping
4
+ from typing import Any, Never
5
+
6
+ from mrok.proxy.datastructures import HTTPResponse
7
+
8
+ Scope = MutableMapping[str, Any]
9
+ Message = MutableMapping[str, Any]
10
+
11
+ ASGIReceive = Callable[[], Awaitable[Message]]
12
+ ASGISend = Callable[[Message], Awaitable[None]]
13
+ ASGIApp = Callable[[Scope, ASGIReceive, ASGISend], Awaitable[None]]
14
+ LifespanCallback = Callable[[], Awaitable[None]]
15
+ ResponseCompleteCallback = Callable[[HTTPResponse], Coroutine[Any, Any, Never]]
mrok/{http → proxy}/utils.py
@@ -1,6 +1,6 @@
1
1
  from collections.abc import Mapping
2
2
 
3
- from mrok.http.constants import (
3
+ from mrok.proxy.constants import (
4
4
  BINARY_CONTENT_TYPES,
5
5
  BINARY_PREFIXES,
6
6
  MAX_REQUEST_BODY_BYTES,