mrok 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. mrok/agent/devtools/__init__.py +0 -0
  2. mrok/agent/devtools/__main__.py +34 -0
  3. mrok/agent/devtools/inspector/__init__.py +0 -0
  4. mrok/agent/devtools/inspector/__main__.py +25 -0
  5. mrok/agent/devtools/inspector/app.py +556 -0
  6. mrok/agent/devtools/inspector/server.py +18 -0
  7. mrok/agent/sidecar/app.py +9 -10
  8. mrok/agent/sidecar/main.py +35 -16
  9. mrok/agent/ziticorn.py +27 -18
  10. mrok/cli/commands/__init__.py +2 -1
  11. mrok/cli/commands/agent/__init__.py +2 -0
  12. mrok/cli/commands/agent/dev/__init__.py +7 -0
  13. mrok/cli/commands/agent/dev/console.py +25 -0
  14. mrok/cli/commands/agent/dev/web.py +37 -0
  15. mrok/cli/commands/agent/run/asgi.py +35 -16
  16. mrok/cli/commands/agent/run/sidecar.py +29 -13
  17. mrok/cli/commands/agent/utils.py +5 -0
  18. mrok/cli/commands/controller/run.py +1 -5
  19. mrok/cli/commands/proxy/__init__.py +6 -0
  20. mrok/cli/commands/proxy/run.py +49 -0
  21. mrok/cli/utils.py +5 -0
  22. mrok/conf.py +6 -0
  23. mrok/controller/auth.py +2 -2
  24. mrok/datastructures.py +159 -0
  25. mrok/http/config.py +3 -6
  26. mrok/http/constants.py +22 -0
  27. mrok/http/forwarder.py +62 -23
  28. mrok/http/lifespan.py +29 -0
  29. mrok/http/middlewares.py +143 -0
  30. mrok/http/types.py +43 -0
  31. mrok/http/utils.py +90 -0
  32. mrok/logging.py +22 -0
  33. mrok/master.py +269 -0
  34. mrok/metrics.py +139 -0
  35. mrok/proxy/__init__.py +3 -0
  36. mrok/proxy/app.py +73 -0
  37. mrok/proxy/dataclasses.py +12 -0
  38. mrok/proxy/main.py +58 -0
  39. mrok/proxy/streams.py +124 -0
  40. mrok/proxy/types.py +12 -0
  41. mrok/proxy/ziti.py +173 -0
  42. {mrok-0.3.0.dist-info → mrok-0.4.0.dist-info}/METADATA +7 -1
  43. {mrok-0.3.0.dist-info → mrok-0.4.0.dist-info}/RECORD +46 -20
  44. mrok/http/master.py +0 -132
  45. {mrok-0.3.0.dist-info → mrok-0.4.0.dist-info}/WHEEL +0 -0
  46. {mrok-0.3.0.dist-info → mrok-0.4.0.dist-info}/entry_points.txt +0 -0
  47. {mrok-0.3.0.dist-info → mrok-0.4.0.dist-info}/licenses/LICENSE.txt +0 -0
mrok/master.py ADDED
@@ -0,0 +1,269 @@
1
+ import asyncio
2
+ import contextlib
3
+ import json
4
+ import logging
5
+ import os
6
+ import signal
7
+ import threading
8
+ import time
9
+ from abc import ABC, abstractmethod
10
+ from pathlib import Path
11
+ from uuid import uuid4
12
+
13
+ import zmq
14
+ import zmq.asyncio
15
+ from watchfiles import watch
16
+ from watchfiles.filters import PythonFilter
17
+ from watchfiles.run import CombinedProcess, start_process
18
+
19
+ from mrok.conf import get_settings
20
+ from mrok.datastructures import Event, HTTPResponse, Meta, Status
21
+ from mrok.http.config import MrokBackendConfig
22
+ from mrok.http.lifespan import LifespanWrapper
23
+ from mrok.http.middlewares import CaptureMiddleware, MetricsMiddleware
24
+ from mrok.http.server import MrokServer
25
+ from mrok.http.types import ASGIApp
26
+ from mrok.logging import setup_logging
27
+ from mrok.metrics import WorkerMetricsCollector
28
+
29
# Logger shared by the agent master, the events router and the worker processes.
logger = logging.getLogger("mrok.agent")

# Seconds stop() waits for the monitor thread to finish (Thread.join timeout).
MONITOR_THREAD_JOIN_TIMEOUT = 5
# Seconds the monitor thread waits between two worker liveness checks.
MONITOR_THREAD_CHECK_DELAY = 1
# Seconds the monitor thread backs off after an unexpected error.
MONITOR_THREAD_ERROR_DELAY = 3
34
+
35
+
36
def print_path(path):
    """Return *path* wrapped in double quotes, relative to the CWD if possible.

    When ``path`` is not inside the current working directory
    (``Path.relative_to`` raises ``ValueError``), it is rendered as given.
    """
    cwd = Path.cwd()
    try:
        shown = path.relative_to(cwd)
    except ValueError:
        shown = path
    return f'"{shown}"'
41
+
42
+
43
def start_events_router(events_pub_port: int, events_sub_port: int):
    """Run the ZeroMQ XSUB/XPUB forwarder for agent events (blocking).

    Publishers connect to ``events_pub_port`` (the XSUB side) and subscribers
    to ``events_sub_port`` (the XPUB side). ``zmq.proxy`` relays messages
    between them until interrupted. Once the proxy loop exits, both sockets
    are closed and the context is terminated.
    """
    setup_logging(get_settings())
    ctx = zmq.Context()
    xsub = ctx.socket(zmq.XSUB)
    xsub.bind(f"tcp://localhost:{events_pub_port}")
    xpub = ctx.socket(zmq.XPUB)
    xpub.bind(f"tcp://localhost:{events_sub_port}")

    try:
        # Ctrl+C / SIGINT is the normal way to end the router.
        with contextlib.suppress(KeyboardInterrupt):
            logger.info(f"Events router process started: {os.getpid()}")
            zmq.proxy(xsub, xpub)
    finally:
        xsub.close()
        xpub.close()
        ctx.term()
60
+
61
+
62
def start_uvicorn_worker(
    worker_id: str,
    app: ASGIApp,
    identity_file: str,
    events_pub_port: int,
    metrics_interval: float = 5.0,
):
    """Entry point of one agent worker process.

    Loads the ``mrok`` metadata from ``identity_file`` and connects a ZeroMQ
    PUB socket to the events router on ``events_pub_port``. It then serves
    ``app`` with ``MrokServer``, wrapped in capture, metrics and lifespan
    middlewares.

    While the server runs:

    - a background task publishes a ``status`` event with a metrics snapshot
      every ``metrics_interval`` seconds;
    - every completed response is published as a ``response`` event.
    """
    import sys

    # Make the user's project importable from inside the worker process.
    sys.path.insert(0, os.getcwd())
    setup_logging(get_settings())
    # Use a context manager so the identity file handle is closed promptly.
    with open(identity_file, encoding="utf-8") as fp:
        identity = json.load(fp)
    meta = Meta(**identity["mrok"])
    ctx = zmq.asyncio.Context()
    pub = ctx.socket(zmq.PUB)
    pub.connect(f"tcp://localhost:{events_pub_port}")
    metrics = WorkerMetricsCollector(worker_id)

    task = None

    async def status_sender():
        while True:
            # A failing snapshot or publish must not stop status reporting for
            # the rest of the worker's life: log it and retry next interval.
            # (CancelledError is not an Exception subclass, so cancellation
            # still stops the loop.)
            try:
                snap = await metrics.snapshot()
                logger.info(f"New metrics snapshot taken: {snap}")
                event = Event(type="status", data=Status(meta=meta, metrics=snap))
                await pub.send_string(event.model_dump_json())
            except Exception:
                logger.exception("Failed to publish worker status")
            await asyncio.sleep(metrics_interval)

    async def on_startup():  # noqa
        nonlocal task
        task = asyncio.create_task(status_sender())

    async def on_shutdown():  # noqa
        if task:
            task.cancel()
            # Wait for the cancellation to complete so the task is not left
            # pending when the event loop shuts down.
            with contextlib.suppress(asyncio.CancelledError):
                await task

    async def on_response_complete(response: HTTPResponse):
        event = Event(type="response", data=response)
        await pub.send_string(event.model_dump_json())

    config = MrokBackendConfig(
        LifespanWrapper(
            MetricsMiddleware(
                CaptureMiddleware(
                    app,
                    on_response_complete,
                ),
                metrics,
            ),
            on_startup=on_startup,
            on_shutdown=on_shutdown,
        ),
        identity_file,
    )
    server = MrokServer(config)
    with contextlib.suppress(KeyboardInterrupt, asyncio.CancelledError):
        server.run()
119
+
120
+
121
class MasterBase(ABC):
    """Supervise the agent's worker processes and its events router.

    ``run()`` starts three things:

    - a ZeroMQ events-router process;
    - ``workers`` worker processes, each serving ``get_asgi_app()``;
    - a daemon monitor thread that respawns workers which exit unexpectedly.

    With ``reload=True``, Python files under the current working directory
    are watched and any change restarts all workers. SIGINT/SIGTERM set
    ``stop_event``, which makes ``run()`` tear everything down.
    """

    def __init__(
        self,
        identity_file: str,
        workers: int,
        reload: bool,
        events_pub_port: int,
        events_sub_port: int,
        metrics_interval: float = 5.0,
    ):
        self.identity_file = identity_file
        self.workers = workers
        self.reload = reload
        self.events_pub_port = events_pub_port
        self.events_sub_port = events_sub_port
        self.metrics_interval = metrics_interval
        # One stable id per worker slot; respawns and restarts reuse the slot id.
        self.worker_identifiers = [str(uuid4()) for _ in range(workers)]
        self.worker_processes: dict[str, CombinedProcess] = {}
        self.zmq_pubsub_router_process = None
        self.monitor_thread = threading.Thread(target=self.monitor_workers, daemon=True)
        # Set to request shutdown. It is also passed to the file watcher as
        # its stop event.
        self.stop_event = threading.Event()
        # Set while restart() replaces the workers; the monitor skips its
        # checks while it is set.
        self.pause_event = threading.Event()
        # Serializes changes to ``worker_processes`` between the main thread
        # (restart/stop) and the monitor thread. It is re-entrant because
        # restart() holds it while calling stop_workers()/start_workers().
        self._workers_lock = threading.RLock()
        self.watch_filter = PythonFilter(ignore_paths=None)
        self.watcher = watch(
            Path.cwd(),
            watch_filter=self.watch_filter,
            stop_event=self.stop_event,
            yield_on_timeout=True,
        )
        self.setup_signals_handler()

    @abstractmethod
    def get_asgi_app(self):
        """Return the ASGI application each worker process should serve."""

    def setup_signals_handler(self):
        """Route SIGINT and SIGTERM to ``handle_signal``."""
        for sig in (signal.SIGINT, signal.SIGTERM):
            signal.signal(sig, self.handle_signal)

    def handle_signal(self, *args, **kwargs):
        """Request a graceful shutdown of the master."""
        self.stop_event.set()

    def start_worker(self, worker_id: str):
        """Start a single worker process and return its process handle."""
        p = start_process(
            start_uvicorn_worker,
            "function",
            (
                worker_id,
                self.get_asgi_app(),
                self.identity_file,
                self.events_pub_port,
                self.metrics_interval,
            ),
            None,
        )
        logger.info(f"Worker {worker_id} [{p.pid}] started")
        return p

    def start(self):
        """Start the events router, the workers and the monitor thread."""
        self.start_events_router()
        self.start_workers()
        self.monitor_thread.start()

    def stop(self):
        """Stop the monitor thread, then all workers, then the events router."""
        # Make the monitor thread exit and stop respawning workers. This also
        # covers run() exiting because of an exception rather than a signal.
        self.stop_event.set()
        if self.monitor_thread.is_alive():
            logger.debug("Wait for monitor worker to exit")
            self.monitor_thread.join(timeout=MONITOR_THREAD_JOIN_TIMEOUT)
        self.stop_workers()
        self.stop_events_router()

    def start_events_router(self):
        """Start the ZeroMQ events-router process."""
        self.zmq_pubsub_router_process = start_process(
            start_events_router,
            "function",
            (
                self.events_pub_port,
                self.events_sub_port,
            ),
            None,
        )

    def stop_events_router(self):
        """Stop the events-router process, if it was started."""
        if self.zmq_pubsub_router_process is None:
            return
        self.zmq_pubsub_router_process.stop(sigint_timeout=5, sigkill_timeout=1)

    def start_workers(self):
        """Start one process per worker slot and record it by worker id."""
        with self._workers_lock:
            for i in range(self.workers):
                worker_id = self.worker_identifiers[i]
                self.worker_processes[worker_id] = self.start_worker(worker_id)

    def stop_workers(self):
        """Stop every live worker process and forget all of them."""
        with self._workers_lock:
            for process in self.worker_processes.values():
                if process.is_alive():
                    process.stop(sigint_timeout=5, sigkill_timeout=1)
            self.worker_processes.clear()

    def restart(self):
        """Replace all workers with fresh processes (used on code reload)."""
        self.pause_event.set()
        try:
            # Holding the lock keeps the monitor from "respawning" workers
            # that are being stopped on purpose.
            with self._workers_lock:
                self.stop_workers()
                self.start_workers()
        finally:
            self.pause_event.clear()

    def monitor_workers(self):
        """Respawn workers that died unexpectedly until ``stop_event`` is set.

        Runs in ``monitor_thread``. Checks are skipped while ``pause_event``
        is set, i.e. while a restart is in progress.
        """
        while not self.stop_event.is_set():
            try:
                if not self.pause_event.is_set():
                    with self._workers_lock:
                        self._respawn_dead_workers()
                # Wait between checks, but wake immediately on shutdown.
                self.stop_event.wait(MONITOR_THREAD_CHECK_DELAY)
            except Exception as e:
                logger.error(f"Error in worker monitoring: {e}")
                self.stop_event.wait(MONITOR_THREAD_ERROR_DELAY)

    def _respawn_dead_workers(self):
        """Replace every worker process that is no longer alive.

        The caller must hold ``_workers_lock``.
        """
        # Iterate over a copy: entries are replaced while looping.
        for worker_id, process in list(self.worker_processes.items()):
            if self.stop_event.is_set():
                return
            if not process.is_alive():
                logger.warning(f"Worker {worker_id} [{process.pid}] died unexpectedly")
                process.stop(sigint_timeout=1, sigkill_timeout=1)
                new_process = self.start_worker(worker_id)
                self.worker_processes[worker_id] = new_process
                logger.info(
                    f"Restarted worker {worker_id} [{process.pid}] -> [{new_process.pid}]"
                )

    def __iter__(self):
        return self

    def __next__(self):
        """Block on the file watcher and return the changed paths.

        Returns ``None`` when the watcher yields an empty set of changes
        (``yield_on_timeout=True``).
        """
        changes = next(self.watcher)
        if changes:
            return list({Path(change[1]) for change in changes})
        return None

    def run(self):
        """Start everything and block until a stop is requested, then clean up."""
        setup_logging(get_settings())
        logger.info(f"Master process started: {os.getpid()}")
        try:
            # Inside the try so a partially failed start is still cleaned up.
            self.start()
            if self.reload:
                for files_changed in self:
                    if files_changed:
                        logger.warning(
                            f"{', '.join(map(print_path, files_changed))} changed, reloading...",
                        )
                        self.restart()
            else:
                self.stop_event.wait()
        finally:
            self.stop()
mrok/metrics.py ADDED
@@ -0,0 +1,139 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ import os
6
+ import time
7
+
8
+ import psutil
9
+ from hdrh.histogram import HdrHistogram
10
+
11
+ from mrok.datastructures import (
12
+ DataTransferMetrics,
13
+ ProcessMetrics,
14
+ RequestsMetrics,
15
+ ResponseTimeMetrics,
16
+ WorkerMetrics,
17
+ )
18
+
19
# NOTE(review): this module is also used by agent workers (mrok.master imports
# WorkerMetricsCollector); confirm that "mrok.proxy" is the intended logger name.
logger = logging.getLogger("mrok.proxy")
20
+
21
+
22
def _collect_process_usage(interval: float) -> ProcessMetrics:
    """Sample CPU and memory usage of the current process (blocking).

    ``cpu_percent`` is read once, then again after sleeping ``interval``
    seconds when ``interval`` is positive; only the second reading is
    reported. Any error from an individual psutil call yields ``0.0`` for
    that figure. If the process cannot be found, both figures are ``0.0``.
    """
    try:
        current = psutil.Process(os.getpid())
    except psutil.NoSuchProcess:
        return ProcessMetrics(cpu=0.0, mem=0.0)

    def _safe(probe) -> float:
        try:
            return probe()
        except Exception:
            return 0.0

    # The first reading is discarded; presumably it primes psutil's
    # non-blocking CPU measurement for the reading after the sleep.
    _safe(lambda: current.cpu_percent(None))
    if interval and interval > 0:
        time.sleep(interval)
    cpu = _safe(lambda: current.cpu_percent(None))
    mem = _safe(current.memory_percent)
    return ProcessMetrics(cpu=cpu, mem=mem)
50
+
51
+
52
async def get_process_and_children_usage(interval: float = 0.1) -> ProcessMetrics:
    """Collect CPU/memory usage without blocking the event loop.

    The blocking sampler, which sleeps ``interval`` seconds, runs in a worker
    thread via ``asyncio.to_thread``. Despite the name, the sampler measures
    only the current process (``psutil.Process(os.getpid())``).
    """
    usage = await asyncio.to_thread(_collect_process_usage, interval)
    return usage
54
+
55
+
56
class WorkerMetricsCollector:
    """Accumulate request, traffic and latency metrics for one worker.

    The ``on_request_*`` / ``on_response_*`` hooks update the counters, and
    ``snapshot()`` reads them. An ``asyncio.Lock`` keeps updates and snapshots
    consistent. Latencies are recorded in milliseconds into an HDR histogram
    configured with ``lowest``, ``highest`` and ``sigfigs``.
    """

    def __init__(self, worker_id: str, lowest=1, highest=60000, sigfigs=3):
        self.worker_id = worker_id
        # Request-level counters, cumulative over the collector's lifetime.
        self.total_requests = 0
        self.successful_requests = 0
        self.failed_requests = 0
        self.bytes_in = 0
        self.bytes_out = 0

        # RPS window: requests seen since the last snapshot. A monotonic
        # clock keeps wall-clock adjustments from skewing the rate.
        self._tick_last = time.monotonic()
        self._tick_requests = 0

        # Latency histogram; values are recorded in milliseconds.
        self.hist = HdrHistogram(lowest, highest, sigfigs)

        self._lock = asyncio.Lock()

    async def on_request_start(self, scope):
        """Return the start timestamp to pass back to ``on_request_end``."""
        return time.perf_counter()

    async def on_request_body(self, length):
        """Add ``length`` received request-body bytes."""
        async with self._lock:
            self.bytes_in += length

    async def on_response_start(self, status_code):
        pass  # reserved

    async def on_response_chunk(self, length):
        """Add ``length`` sent response bytes."""
        async with self._lock:
            self.bytes_out += length

    async def on_request_end(self, start_time, status_code):
        """Record a finished request; a status code of 500 or above counts as failed."""
        elapsed_ms = (time.perf_counter() - start_time) * 1000

        async with self._lock:
            self.total_requests += 1
            self._tick_requests += 1

            if status_code < 500:
                self.successful_requests += 1
            else:
                self.failed_requests += 1

            self.hist.record_value(elapsed_ms)

    async def snapshot(self) -> WorkerMetrics:
        """Return current metrics and reset the RPS window.

        ``rps`` is the number of requests since the previous snapshot divided
        by the elapsed seconds. Exceptions are logged and re-raised.
        """
        try:
            # Sample process usage *before* taking the lock. The sampler sleeps
            # for its measurement window, and holding the lock meanwhile would
            # stall every request's metrics hook for that long.
            process = await get_process_and_children_usage()
            async with self._lock:
                now = time.monotonic()
                delta = now - self._tick_last
                rps = int(self._tick_requests / delta) if delta > 0 else 0
                data = WorkerMetrics(
                    worker_id=self.worker_id,
                    process=process,
                    requests=RequestsMetrics(
                        rps=rps,
                        total=self.total_requests,
                        successful=self.successful_requests,
                        failed=self.failed_requests,
                    ),
                    data_transfer=DataTransferMetrics(
                        bytes_in=self.bytes_in,
                        bytes_out=self.bytes_out,
                    ),
                    response_time=ResponseTimeMetrics(
                        avg=self.hist.get_mean_value(),
                        min=self.hist.get_min_value(),
                        max=self.hist.get_max_value(),
                        p50=self.hist.get_value_at_percentile(50),
                        p90=self.hist.get_value_at_percentile(90),
                        p99=self.hist.get_value_at_percentile(99),
                    ),
                )

                self._tick_last = now
                self._tick_requests = 0

                return data
        except Exception:
            logger.exception("Exception calculating snapshot")
            raise
mrok/proxy/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from mrok.proxy.main import run
2
+
3
# Public API of the package: the proxy entry point re-exported from mrok.proxy.main.
__all__ = ["run"]
mrok/proxy/app.py ADDED
@@ -0,0 +1,73 @@
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+
5
+ from mrok.conf import get_settings
6
+ from mrok.http.forwarder import ForwardAppBase
7
+ from mrok.http.types import Scope, StreamReader, StreamWriter
8
+ from mrok.logging import setup_logging
9
+ from mrok.proxy.ziti import ZitiConnectionManager
10
+
11
# Module logger, under the "mrok.proxy" namespace.
logger = logging.getLogger("mrok.proxy")
12
+
13
+
14
class ProxyError(Exception):
    """Raised when a request cannot be mapped to an OpenZiti target name."""
16
+
17
+
18
class ProxyApp(ForwardAppBase):
    """Forward HTTP requests to OpenZiti services selected by hostname.

    The target service/terminator name is the request host with the port and
    the configured proxy wildcard domain removed. For example, ``svc.example.com``
    maps to ``svc`` when ``settings.proxy.domain`` is ``example.com``.
    Backend streams come from a ``ZitiConnectionManager``.
    """

    def __init__(
        self,
        identity_file: str | Path,
        *,
        read_chunk_size: int = 65536,
        ziti_connection_ttl_seconds: float = 60,
        ziti_conn_cache_purge_interval_seconds: float = 10,
    ) -> None:
        super().__init__(read_chunk_size=read_chunk_size)
        self._identity_file = identity_file
        settings = get_settings()
        # Normalize the domain to a ".suffix" form so it can be matched and stripped.
        self._proxy_wildcard_domain = (
            settings.proxy.domain
            if settings.proxy.domain[0] == "."
            else f".{settings.proxy.domain}"
        )
        self._conn_manager = ZitiConnectionManager(
            identity_file,
            ttl_seconds=ziti_connection_ttl_seconds,
            purge_interval=ziti_conn_cache_purge_interval_seconds,
        )

    def get_target_name(self, headers: dict[str, str]) -> str:
        """Return the OpenZiti service/terminator name targeted by a request.

        The host is taken from ``x-forwarded-host`` (its first entry, as set by
        a fronting reverse proxy), falling back to ``host``. Any port is
        dropped and the proxy wildcard domain suffix is stripped.

        Raises:
            ProxyError: no host header was sent, the host is not under the
                proxy wildcard domain, or nothing precedes that domain.
        """
        # X-Forwarded-For carries client IP addresses, not the requested host,
        # so it cannot be used to pick the target service.
        header_value = headers.get("x-forwarded-host") or headers.get("host")
        if header_value:
            # X-Forwarded-Host may list one host per proxy hop; the first one
            # is the host the client originally requested.
            header_value = header_value.split(",", 1)[0].strip()
        if not header_value:
            raise ProxyError(
                "Cannot determine the target OpenZiti service/terminator name, "
                "neither Host nor X-Forwarded-Host headers have been sent in the request.",
            )
        if ":" in header_value:
            header_value, _ = header_value.split(":", 1)
        if not header_value.endswith(self._proxy_wildcard_domain):
            raise ProxyError(
                f"Unexpected value for Host or X-Forwarded-Host header: `{header_value}`."
            )

        target_name = header_value[: -len(self._proxy_wildcard_domain)]
        if not target_name:
            raise ProxyError(
                "Empty OpenZiti service/terminator name in Host or "
                f"X-Forwarded-Host header: `{header_value}`."
            )
        return target_name

    async def startup(self):
        """Per-worker startup: configure logging and start the connection manager."""
        setup_logging(get_settings())
        await self._conn_manager.start()
        logger.info(f"Proxy app startup completed: {os.getpid()}")

    async def shutdown(self):
        """Per-worker shutdown: stop the connection manager."""
        await self._conn_manager.stop()
        logger.info(f"Proxy app shutdown completed: {os.getpid()}")

    async def select_backend(
        self,
        scope: Scope,
        headers: dict[str, str],
    ) -> tuple[StreamReader, StreamWriter] | tuple[None, None]:
        """Return the (reader, writer) pair for the request's target service.

        Raises ``ProxyError`` when no target name can be derived from the
        headers. NOTE(review): ``(None, None)`` presumably means the
        connection manager could not provide a connection — confirm.
        """
        target_name = self.get_target_name(headers)

        return await self._conn_manager.get(target_name)
mrok/proxy/dataclasses.py ADDED
@@ -0,0 +1,12 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ from mrok.http.types import StreamReader, StreamWriter
6
+
7
+
8
@dataclass(init=True, repr=True, eq=True)
class CachedStreamEntry:
    """A cached (reader, writer) stream pair plus its last-access time."""

    reader: StreamReader  # read side of the cached connection
    writer: StreamWriter  # write side of the cached connection
    last_access: float  # time of the last access; the clock is chosen by callers
mrok/proxy/main.py ADDED
@@ -0,0 +1,58 @@
1
+ from collections.abc import Callable
2
+ from pathlib import Path
3
+ from typing import Any
4
+
5
+ from gunicorn.app.base import BaseApplication
6
+ from uvicorn_worker import UvicornWorker
7
+
8
+ from mrok.conf import get_settings
9
+ from mrok.http.lifespan import LifespanWrapper
10
+ from mrok.logging import get_logging_config
11
+ from mrok.proxy.app import ProxyApp
12
+
13
+
14
class MrokUvicornWorker(UvicornWorker):
    """Uvicorn worker for gunicorn: asyncio event loop, auto HTTP impl, lifespan on."""

    CONFIG_KWARGS: dict[str, Any] = dict(loop="asyncio", http="auto", lifespan="on")
16
+
17
+
18
class StandaloneApplication(BaseApplication):  # pragma: no cover
    """Embedded gunicorn application that serves a pre-built callable.

    ``options`` are gunicorn settings. Keys gunicorn does not know, and
    options whose value is ``None``, are ignored.
    """

    def __init__(self, application: Callable, options: dict[str, Any] | None = None):
        # Both attributes must exist before BaseApplication.__init__ runs,
        # because it calls load_config().
        self.options = options or {}
        self.application = application
        super().__init__()

    def load_config(self):
        known_settings = self.cfg.settings
        for key, value in self.options.items():
            if key not in known_settings or value is None:
                continue
            self.cfg.set(key.lower(), value)

    def load(self):
        return self.application
35
+
36
+
37
def run(
    identity_file: str | Path,
    host: str,
    port: int,
    workers: int,
):
    """Serve the mrok proxy with gunicorn and uvicorn workers (blocking).

    The ``ProxyApp`` is wrapped so its ``startup``/``shutdown`` coroutines run
    as ASGI lifespan handlers. Gunicorn then binds to ``host:port`` with
    ``workers`` ``MrokUvicornWorker`` processes and code reload disabled.
    """
    proxy_app = ProxyApp(identity_file)
    app_with_lifespan = LifespanWrapper(proxy_app, proxy_app.startup, proxy_app.shutdown)
    gunicorn_options = dict(
        bind=f"{host}:{port}",
        workers=workers,
        worker_class="mrok.proxy.main.MrokUvicornWorker",
        logconfig_dict=get_logging_config(get_settings()),
        reload=False,
    )
    StandaloneApplication(app_with_lifespan, gunicorn_options).run()
mrok/proxy/streams.py ADDED
@@ -0,0 +1,124 @@
1
+ import asyncio
2
+
3
+ from mrok.proxy.types import ConnectionCache, ConnectionKey
4
+
5
+
6
class CachedStreamReader:
    """Reader wrapper that invalidates its cached connection when a read fails.

    Delegates to an ``asyncio.StreamReader``. If a read raises one of
    ``_INVALIDATING_ERRORS`` (cancellation included), it schedules
    ``manager.invalidate(key)`` as a background task and re-raises the
    original error.
    """

    # Errors after which the underlying connection must not be reused.
    _INVALIDATING_ERRORS = (
        asyncio.CancelledError,
        asyncio.IncompleteReadError,
        asyncio.LimitOverrunError,
        BrokenPipeError,
        ConnectionAbortedError,
        ConnectionResetError,
        RuntimeError,
        TimeoutError,
        UnicodeDecodeError,
    )

    # Strong references to pending invalidation tasks. The event loop keeps
    # only weak references to tasks, so an unreferenced task may be
    # garbage-collected before it runs.
    _pending_invalidations: set[asyncio.Task] = set()

    def __init__(
        self,
        reader: asyncio.StreamReader,
        key: ConnectionKey,
        manager: ConnectionCache,
    ):
        self._reader = reader
        self._key = key
        self._manager = manager

    def _schedule_invalidation(self) -> None:
        """Invalidate this connection in the cache without awaiting it."""
        task = asyncio.create_task(self._manager.invalidate(self._key))
        self._pending_invalidations.add(task)
        task.add_done_callback(self._pending_invalidations.discard)

    async def read(self, n: int = -1) -> bytes:
        try:
            return await self._reader.read(n)
        except self._INVALIDATING_ERRORS:
            self._schedule_invalidation()
            raise

    async def readexactly(self, n: int) -> bytes:
        try:
            return await self._reader.readexactly(n)
        except self._INVALIDATING_ERRORS:
            self._schedule_invalidation()
            raise

    async def readline(self) -> bytes:
        try:
            return await self._reader.readline()
        except self._INVALIDATING_ERRORS:
            self._schedule_invalidation()
            raise

    def at_eof(self) -> bool:
        return self._reader.at_eof()

    @property
    def underlying(self) -> asyncio.StreamReader:
        """The wrapped ``asyncio.StreamReader``."""
        return self._reader
74
+
75
+
76
class CachedStreamWriter:
    """Writer wrapper that invalidates its cached connection on failure.

    Delegates to an ``asyncio.StreamWriter``. When ``write``, ``drain`` or
    ``wait_closed`` raises one of its listed errors, it schedules
    ``manager.invalidate(key)`` as a background task and re-raises the
    original error. ``close()`` never invalidates.
    """

    _WRITE_ERRORS = (RuntimeError, TypeError)
    _DRAIN_ERRORS = (
        asyncio.CancelledError,
        BrokenPipeError,
        ConnectionAbortedError,
        ConnectionResetError,
        RuntimeError,
        TimeoutError,
    )
    _WAIT_CLOSED_ERRORS = (ConnectionResetError, BrokenPipeError)

    # Strong references to pending invalidation tasks. The event loop keeps
    # only weak references to tasks, so an unreferenced task may be
    # garbage-collected before it runs.
    _pending_invalidations: set[asyncio.Task] = set()

    def __init__(
        self,
        writer: asyncio.StreamWriter,
        key: ConnectionKey,
        manager: ConnectionCache,
    ):
        self._writer = writer
        self._key = key
        self._manager = manager

    def _schedule_invalidation(self) -> None:
        """Invalidate this connection in the cache without awaiting it."""
        task = asyncio.create_task(self._manager.invalidate(self._key))
        self._pending_invalidations.add(task)
        task.add_done_callback(self._pending_invalidations.discard)

    def write(self, data: bytes) -> None:
        try:
            return self._writer.write(data)
        except self._WRITE_ERRORS:
            self._schedule_invalidation()
            raise

    async def drain(self) -> None:
        try:
            return await self._writer.drain()
        except self._DRAIN_ERRORS:
            self._schedule_invalidation()
            raise

    def close(self) -> None:
        return self._writer.close()

    async def wait_closed(self) -> None:
        try:
            return await self._writer.wait_closed()
        except self._WAIT_CLOSED_ERRORS:
            self._schedule_invalidation()
            raise

    @property
    def transport(self):
        return self._writer.transport

    @property
    def underlying(self) -> asyncio.StreamWriter:
        """The wrapped ``asyncio.StreamWriter``."""
        return self._writer
mrok/proxy/types.py ADDED
@@ -0,0 +1,12 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Protocol
4
+
5
+ from mrok.http.types import StreamReader, StreamWriter
6
+
7
# Cache key for a connection: a name plus an optional second string component.
# NOTE(review): the meaning of the optional component is not visible here — confirm.
ConnectionKey = tuple[str, str | None]
# A (reader, writer) stream pair for one cached connection.
CachedStream = tuple[StreamReader, StreamWriter]
9
+
10
+
11
class ConnectionCache(Protocol):
    """Structural type of a connection cache whose entries can be invalidated."""

    async def invalidate(self, key: ConnectionKey) -> None:
        """Invalidate the cached connection identified by *key*."""
        ...