mrok 0.3.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. mrok/agent/devtools/__init__.py +0 -0
  2. mrok/agent/devtools/__main__.py +34 -0
  3. mrok/agent/devtools/inspector/__init__.py +0 -0
  4. mrok/agent/devtools/inspector/__main__.py +25 -0
  5. mrok/agent/devtools/inspector/app.py +556 -0
  6. mrok/agent/devtools/inspector/server.py +18 -0
  7. mrok/agent/sidecar/app.py +9 -10
  8. mrok/agent/sidecar/main.py +35 -16
  9. mrok/agent/ziticorn.py +27 -18
  10. mrok/cli/commands/__init__.py +2 -1
  11. mrok/cli/commands/agent/__init__.py +2 -0
  12. mrok/cli/commands/agent/dev/__init__.py +7 -0
  13. mrok/cli/commands/agent/dev/console.py +25 -0
  14. mrok/cli/commands/agent/dev/web.py +37 -0
  15. mrok/cli/commands/agent/run/asgi.py +35 -16
  16. mrok/cli/commands/agent/run/sidecar.py +29 -13
  17. mrok/cli/commands/agent/utils.py +5 -0
  18. mrok/cli/commands/controller/run.py +1 -5
  19. mrok/cli/commands/proxy/__init__.py +6 -0
  20. mrok/cli/commands/proxy/run.py +49 -0
  21. mrok/cli/utils.py +5 -0
  22. mrok/conf.py +6 -0
  23. mrok/controller/auth.py +2 -2
  24. mrok/datastructures.py +159 -0
  25. mrok/http/config.py +3 -6
  26. mrok/http/constants.py +22 -0
  27. mrok/http/forwarder.py +62 -23
  28. mrok/http/lifespan.py +29 -0
  29. mrok/http/middlewares.py +143 -0
  30. mrok/http/types.py +43 -0
  31. mrok/http/utils.py +90 -0
  32. mrok/logging.py +22 -0
  33. mrok/master.py +272 -0
  34. mrok/metrics.py +139 -0
  35. mrok/proxy/__init__.py +3 -0
  36. mrok/proxy/app.py +77 -0
  37. mrok/proxy/dataclasses.py +12 -0
  38. mrok/proxy/main.py +58 -0
  39. mrok/proxy/streams.py +124 -0
  40. mrok/proxy/types.py +12 -0
  41. mrok/proxy/ziti.py +173 -0
  42. {mrok-0.3.0.dist-info → mrok-0.4.1.dist-info}/METADATA +7 -1
  43. {mrok-0.3.0.dist-info → mrok-0.4.1.dist-info}/RECORD +46 -20
  44. mrok/http/master.py +0 -132
  45. {mrok-0.3.0.dist-info → mrok-0.4.1.dist-info}/WHEEL +0 -0
  46. {mrok-0.3.0.dist-info → mrok-0.4.1.dist-info}/entry_points.txt +0 -0
  47. {mrok-0.3.0.dist-info → mrok-0.4.1.dist-info}/licenses/LICENSE.txt +0 -0
mrok/master.py ADDED
@@ -0,0 +1,272 @@
1
+ import asyncio
2
+ import contextlib
3
+ import json
4
+ import logging
5
+ import os
6
+ import signal
7
+ import threading
8
+ import time
9
+ from abc import ABC, abstractmethod
10
+ from pathlib import Path
11
+ from uuid import uuid4
12
+
13
+ import zmq
14
+ import zmq.asyncio
15
+ from uvicorn.importer import import_from_string
16
+ from watchfiles import watch
17
+ from watchfiles.filters import PythonFilter
18
+ from watchfiles.run import CombinedProcess, start_process
19
+
20
+ from mrok.conf import get_settings
21
+ from mrok.datastructures import Event, HTTPResponse, Meta, Status
22
+ from mrok.http.config import MrokBackendConfig
23
+ from mrok.http.lifespan import LifespanWrapper
24
+ from mrok.http.middlewares import CaptureMiddleware, MetricsMiddleware
25
+ from mrok.http.server import MrokServer
26
+ from mrok.http.types import ASGIApp
27
+ from mrok.logging import setup_logging
28
+ from mrok.metrics import WorkerMetricsCollector
29
+
30
+ logger = logging.getLogger("mrok.agent")
31
+
32
+ MONITOR_THREAD_JOIN_TIMEOUT = 5
33
+ MONITOR_THREAD_CHECK_DELAY = 1
34
+ MONITOR_THREAD_ERROR_DELAY = 3
35
+
36
+
37
def print_path(path):
    """Render *path* quoted, relative to the current working directory when possible."""
    try:
        shown = path.relative_to(Path.cwd())
    except ValueError:
        # Path lies outside the CWD: fall back to the absolute form.
        return f'"{path}"'
    return f'"{shown}"'
42
+
43
+
44
def start_events_router(events_pub_port: int, events_sub_port: int):
    """Entry point of the events-router process.

    Bridges worker publishers (XSUB frontend on *events_pub_port*) to
    subscribers (XPUB backend on *events_sub_port*) with a blocking
    ``zmq.proxy`` until interrupted.
    """
    setup_logging(get_settings())
    context = zmq.Context()
    frontend = context.socket(zmq.XSUB)
    backend = context.socket(zmq.XPUB)
    frontend.bind(f"tcp://localhost:{events_pub_port}")
    backend.bind(f"tcp://localhost:{events_sub_port}")

    try:
        logger.info(f"Events router process started: {os.getpid()}")
        zmq.proxy(frontend, backend)
    except KeyboardInterrupt:
        pass
    finally:
        # Always release sockets before terminating the context.
        for sock in (frontend, backend):
            sock.close()
        context.term()
61
+
62
+
63
def start_uvicorn_worker(
    worker_id: str,
    app: ASGIApp,
    identity_file: str,
    events_pub_port: int,
    metrics_interval: float = 5.0,
):
    """Entry point of a single worker process.

    Wraps *app* with capture and metrics middlewares plus a lifespan wrapper,
    publishes periodic status events and per-response events over ZMQ, and
    serves the result with ``MrokServer`` until interrupted.

    Args:
        worker_id: stable identifier assigned by the master process.
        app: ASGI application object or a "module:attr" import string.
        identity_file: JSON file with the ziti identity and a "mrok" meta section.
        events_pub_port: TCP port of the events router XSUB frontend.
        metrics_interval: seconds between status/metrics publications.
    """
    import sys

    # Spawned worker processes do not inherit the project root on sys.path.
    sys.path.insert(0, os.getcwd())
    if isinstance(app, str):
        app = import_from_string(app)

    setup_logging(get_settings())
    # Fix: use a context manager so the identity file handle is closed
    # (previously json.load(open(...)) leaked the descriptor).
    with open(identity_file) as identity_fp:
        identity = json.load(identity_fp)
    meta = Meta(**identity["mrok"])
    ctx = zmq.asyncio.Context()
    pub = ctx.socket(zmq.PUB)
    pub.connect(f"tcp://localhost:{events_pub_port}")
    metrics = WorkerMetricsCollector(worker_id)

    task = None

    async def status_sender():
        # Periodically publish a metrics snapshot tagged with this worker's meta.
        while True:
            snap = await metrics.snapshot()
            event = Event(type="status", data=Status(meta=meta, metrics=snap))
            await pub.send_string(event.model_dump_json())
            await asyncio.sleep(metrics_interval)

    async def on_startup():  # noqa
        nonlocal task
        task = asyncio.create_task(status_sender())

    async def on_shutdown():  # noqa
        if task:
            task.cancel()

    async def on_response_complete(response: HTTPResponse):
        # Forward every completed HTTP response to subscribers (e.g. the inspector).
        event = Event(type="response", data=response)
        await pub.send_string(event.model_dump_json())

    config = MrokBackendConfig(
        LifespanWrapper(
            MetricsMiddleware(
                CaptureMiddleware(
                    app,
                    on_response_complete,
                ),
                metrics,
            ),
            on_startup=on_startup,
            on_shutdown=on_shutdown,
        ),
        identity_file,
    )
    server = MrokServer(config)
    with contextlib.suppress(KeyboardInterrupt, asyncio.CancelledError):
        server.run()
122
+
123
+
124
class MasterBase(ABC):
    """Master process: supervises worker processes and a ZMQ events router.

    Responsibilities:
      * spawn the XSUB/XPUB router process bridging worker events to subscribers,
      * spawn ``workers`` worker processes, each under a stable uuid identity,
      * restart crashed workers from a background monitor thread,
      * optionally watch the CWD for Python file changes and reload all workers.

    Subclasses provide the ASGI application via :meth:`get_asgi_app`.
    """

    def __init__(
        self,
        identity_file: str,
        workers: int,
        reload: bool,
        events_pub_port: int,
        events_sub_port: int,
        metrics_interval: float = 5.0,
    ):
        self.identity_file = identity_file
        self.workers = workers
        self.reload = reload
        self.events_pub_port = events_pub_port
        self.events_sub_port = events_sub_port
        self.metrics_interval = metrics_interval
        # Stable ids so a crashed worker is restarted under the same identity.
        self.worker_identifiers = [str(uuid4()) for _ in range(workers)]
        self.worker_processes: dict[str, CombinedProcess] = {}
        self.zmq_pubsub_router_process = None
        self.monitor_thread = threading.Thread(target=self.monitor_workers, daemon=True)
        self.stop_event = threading.Event()
        # Set while a deliberate restart is in progress so the monitor thread
        # does not resurrect workers that are being stopped on purpose.
        self.pause_event = threading.Event()
        self.watch_filter = PythonFilter(ignore_paths=None)
        self.watcher = watch(
            Path.cwd(),
            watch_filter=self.watch_filter,
            stop_event=self.stop_event,
            yield_on_timeout=True,
        )
        self.setup_signals_handler()

    @abstractmethod
    def get_asgi_app(self):
        """Return the ASGI app (object or import string) each worker should run."""

    def setup_signals_handler(self):
        """Route SIGINT/SIGTERM to a clean shutdown via ``stop_event``."""
        for sig in (signal.SIGINT, signal.SIGTERM):
            signal.signal(sig, self.handle_signal)

    def handle_signal(self, *args, **kwargs):
        self.stop_event.set()

    def start_worker(self, worker_id: str):
        """Start a single worker process"""

        p = start_process(
            start_uvicorn_worker,
            "function",
            (
                worker_id,
                self.get_asgi_app(),
                self.identity_file,
                self.events_pub_port,
                self.metrics_interval,
            ),
            None,
        )
        logger.info(f"Worker {worker_id} [{p.pid}] started")
        return p

    def start(self):
        """Start the events router, all workers, then the monitor thread."""
        self.start_events_router()
        self.start_workers()
        self.monitor_thread.start()

    def stop(self):
        """Stop the monitor thread, the workers and the events router, in order."""
        if self.monitor_thread.is_alive():
            logger.debug("Wait for monitor worker to exit")
            self.monitor_thread.join(timeout=MONITOR_THREAD_JOIN_TIMEOUT)
        self.stop_workers()
        self.stop_events_router()

    def start_events_router(self):
        """Spawn the XSUB/XPUB forwarder in its own process."""
        self.zmq_pubsub_router_process = start_process(
            start_events_router,
            "function",
            (
                self.events_pub_port,
                self.events_sub_port,
            ),
            None,
        )

    def stop_events_router(self):
        # Fix: guard against stop() running before start() (attribute is None).
        if self.zmq_pubsub_router_process is not None:
            self.zmq_pubsub_router_process.stop(sigint_timeout=5, sigkill_timeout=1)

    def start_workers(self):
        for i in range(self.workers):
            worker_id = self.worker_identifiers[i]
            p = self.start_worker(worker_id)
            self.worker_processes[worker_id] = p

    def stop_workers(self):
        for process in self.worker_processes.values():
            if process.is_alive():
                process.stop(sigint_timeout=5, sigkill_timeout=1)
        self.worker_processes.clear()

    def restart(self):
        """Stop and start all workers; monitoring is paused meanwhile."""
        self.pause_event.set()
        self.stop_workers()
        self.start_workers()
        self.pause_event.clear()

    def monitor_workers(self):
        """Daemon-thread loop: restart workers that died unexpectedly."""
        while not self.stop_event.is_set():
            try:
                # Fix: the previous code called ``self.pause_event.wait()``,
                # which blocks until the event is SET — i.e. monitoring only
                # ran while a restart was in progress and never otherwise.
                # Instead, skip the health check while a restart is ongoing.
                if self.pause_event.is_set():
                    time.sleep(MONITOR_THREAD_CHECK_DELAY)
                    continue
                for worker_id, process in self.worker_processes.items():
                    if not process.is_alive():
                        logger.warning(f"Worker {worker_id} [{process.pid}] died unexpectedly")
                        process.stop(sigint_timeout=1, sigkill_timeout=1)
                        new_process = self.start_worker(worker_id)
                        self.worker_processes[worker_id] = new_process
                        logger.info(
                            f"Restarted worker {worker_id} [{process.pid}] -> [{new_process.pid}]"
                        )

                time.sleep(MONITOR_THREAD_CHECK_DELAY)

            except Exception as e:
                logger.error(f"Error in worker monitoring: {e}")
                time.sleep(MONITOR_THREAD_ERROR_DELAY)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch of changed paths from the watcher, or None."""
        changes = next(self.watcher)
        if changes:
            return list({Path(change[1]) for change in changes})
        return None

    def run(self):
        """Blocking entry point: start everything, then wait (or watch-and-reload)."""
        setup_logging(get_settings())
        logger.info(f"Master process started: {os.getpid()}")
        self.start()
        try:
            if self.reload:
                for files_changed in self:
                    if files_changed:
                        logger.warning(
                            f"{', '.join(map(print_path, files_changed))} changed, reloading...",
                        )
                        self.restart()
            else:
                self.stop_event.wait()
        finally:
            self.stop()
mrok/metrics.py ADDED
@@ -0,0 +1,139 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ import os
6
+ import time
7
+
8
+ import psutil
9
+ from hdrh.histogram import HdrHistogram
10
+
11
+ from mrok.datastructures import (
12
+ DataTransferMetrics,
13
+ ProcessMetrics,
14
+ RequestsMetrics,
15
+ ResponseTimeMetrics,
16
+ WorkerMetrics,
17
+ )
18
+
19
+ logger = logging.getLogger("mrok.proxy")
20
+
21
+
22
def _collect_process_usage(interval: float) -> ProcessMetrics:
    """Sample CPU and memory percentage of the current process (blocking).

    The first ``cpu_percent(None)`` call only primes psutil's internal
    counters; the meaningful value comes from the second call after the
    optional sleep of *interval* seconds.
    """
    try:
        proc = psutil.Process(os.getpid())
    except psutil.NoSuchProcess:
        return ProcessMetrics(cpu=0.0, mem=0.0)

    try:
        cpu = proc.cpu_percent(None)
    except Exception:
        cpu = 0.0

    if interval and interval > 0:
        time.sleep(interval)

    try:
        cpu = proc.cpu_percent(None)
    except Exception:
        cpu = 0.0

    try:
        mem = proc.memory_percent()
    except Exception:
        mem = 0.0

    return ProcessMetrics(cpu=cpu, mem=mem)
50
+
51
+
52
async def get_process_and_children_usage(interval: float = 0.1) -> ProcessMetrics:
    """Run the blocking psutil sampling in a worker thread and return its result."""
    return await asyncio.to_thread(_collect_process_usage, interval)
54
+
55
+
56
class WorkerMetricsCollector:
    """Accumulates per-worker request, transfer and latency metrics.

    Counters are guarded by an asyncio lock; latencies (milliseconds) are
    recorded into an HDR histogram bounded by [lowest, highest] with the
    given number of significant figures.
    """

    def __init__(self, worker_id: str, lowest=1, highest=60000, sigfigs=3):
        # Request-level counters
        self.worker_id = worker_id
        self.total_requests = 0
        self.successful_requests = 0
        self.failed_requests = 0
        self.bytes_in = 0
        self.bytes_out = 0

        # RPS window: requests counted since the last snapshot.
        self._tick_last = time.time()
        self._tick_requests = 0

        # latency histogram
        self.hist = HdrHistogram(lowest, highest, sigfigs)

        # Guards all mutable counters above.
        self._lock = asyncio.Lock()

    async def on_request_start(self, scope):
        """Return a perf_counter timestamp to pass back to on_request_end."""
        return time.perf_counter()

    async def on_request_body(self, length):
        async with self._lock:
            self.bytes_in += length

    async def on_response_start(self, status_code):
        pass  # reserved

    async def on_response_chunk(self, length):
        async with self._lock:
            self.bytes_out += length

    async def on_request_end(self, start_time, status_code):
        """Record one finished request; 5xx responses count as failed."""
        elapsed_ms = (time.perf_counter() - start_time) * 1000

        async with self._lock:
            self.total_requests += 1
            self._tick_requests += 1

            if status_code < 500:
                self.successful_requests += 1
            else:
                self.failed_requests += 1

            self.hist.record_value(elapsed_ms)

    async def snapshot(self) -> WorkerMetrics:
        """Build a WorkerMetrics snapshot and reset the RPS window."""
        try:
            # Fix: sample process CPU/memory BEFORE taking the lock. The
            # sampler sleeps ~0.1s and previously held the lock for that
            # whole time, stalling all concurrent request accounting.
            process = await get_process_and_children_usage()
            async with self._lock:
                now = time.time()
                delta = now - self._tick_last
                rps = int(self._tick_requests / delta) if delta > 0 else 0
                data = WorkerMetrics(
                    worker_id=self.worker_id,
                    process=process,
                    requests=RequestsMetrics(
                        rps=rps,
                        total=self.total_requests,
                        successful=self.successful_requests,
                        failed=self.failed_requests,
                    ),
                    data_transfer=DataTransferMetrics(
                        bytes_in=self.bytes_in,
                        bytes_out=self.bytes_out,
                    ),
                    response_time=ResponseTimeMetrics(
                        avg=self.hist.get_mean_value(),
                        min=self.hist.get_min_value(),
                        max=self.hist.get_max_value(),
                        p50=self.hist.get_value_at_percentile(50),
                        p90=self.hist.get_value_at_percentile(90),
                        p99=self.hist.get_value_at_percentile(99),
                    ),
                )

                self._tick_last = now
                self._tick_requests = 0

                return data
        except Exception:
            logger.exception("Exception calculating snapshot")
            raise
mrok/proxy/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from mrok.proxy.main import run
2
+
3
+ __all__ = ["run"]
mrok/proxy/app.py ADDED
@@ -0,0 +1,77 @@
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+
5
+ from mrok.conf import get_settings
6
+ from mrok.http.forwarder import ForwardAppBase
7
+ from mrok.http.types import Scope, StreamReader, StreamWriter
8
+ from mrok.logging import setup_logging
9
+ from mrok.proxy.ziti import ZitiConnectionManager
10
+
11
+ logger = logging.getLogger("mrok.proxy")
12
+
13
+
14
class ProxyError(Exception):
    """Raised when a proxied request cannot be mapped to a backend target."""

    pass
16
+
17
+
18
class ProxyApp(ForwardAppBase):
    """Reverse-proxy ASGI app forwarding requests to OpenZiti services.

    The backend service is derived from the request's host-like headers: the
    hostname must end with the configured wildcard proxy domain, and the
    remaining prefix is used as the ziti target (service) name.
    """

    def __init__(
        self,
        identity_file: str | Path,
        *,
        read_chunk_size: int = 65536,
        ziti_connection_ttl_seconds: float = 60,
        ziti_conn_cache_purge_interval_seconds: float = 10,
    ) -> None:
        super().__init__(read_chunk_size=read_chunk_size)
        self._identity_file = identity_file
        settings = get_settings()
        # Normalize the wildcard domain so it always starts with a dot.
        self._proxy_wildcard_domain = (
            settings.proxy.domain
            if settings.proxy.domain[0] == "."
            else f".{settings.proxy.domain}"
        )
        self._conn_manager = ZitiConnectionManager(
            identity_file,
            ttl_seconds=ziti_connection_ttl_seconds,
            purge_interval=ziti_conn_cache_purge_interval_seconds,
        )

    def get_target_from_header(self, name: str, headers: dict[str, str]) -> str:
        """Extract the target name from header *name*.

        Raises:
            ProxyError: when the header is missing or its hostname does not
                end with the wildcard proxy domain.
        """
        header_value = headers.get(name)
        if not header_value:
            raise ProxyError(
                f"Header {name} not found!",
            )
        # Strip an optional :port suffix before matching the domain.
        if ":" in header_value:
            header_value, _ = header_value.split(":", 1)
        if not header_value.endswith(self._proxy_wildcard_domain):
            raise ProxyError(f"Unexpected value for {name} header: `{header_value}`.")

        return header_value[: -len(self._proxy_wildcard_domain)]

    def get_target_name(self, headers: dict[str, str]) -> str:
        """Resolve the backend target, preferring the forwarded header over Host."""
        try:
            # Fix: the header name was misspelled as "x-forwared-for".
            # NOTE(review): the value is matched against the wildcard domain
            # (a hostname), so "x-forwarded-host" may be the truly intended
            # header — confirm against the upstream proxy configuration.
            return self.get_target_from_header("x-forwarded-for", headers)
        except ProxyError as pe:
            logger.warning(pe)
            return self.get_target_from_header("host", headers)

    async def startup(self):
        """Lifespan startup: configure logging and start the connection manager."""
        setup_logging(get_settings())
        await self._conn_manager.start()
        logger.info(f"Proxy app startup completed: {os.getpid()}")

    async def shutdown(self):
        """Lifespan shutdown: stop the connection manager."""
        await self._conn_manager.stop()
        logger.info(f"Proxy app shutdown completed: {os.getpid()}")

    async def select_backend(
        self,
        scope: Scope,
        headers: dict[str, str],
    ) -> tuple[StreamReader, StreamWriter] | tuple[None, None]:
        """Return a (reader, writer) pair connected to the service for this request."""
        target_name = self.get_target_name(headers)

        return await self._conn_manager.get(target_name)
@@ -0,0 +1,12 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ from mrok.http.types import StreamReader, StreamWriter
6
+
7
+
8
@dataclass
class CachedStreamEntry:
    """A pooled backend connection: its stream pair plus the last access time."""

    # Reader/writer pair for one cached backend connection.
    reader: StreamReader
    writer: StreamWriter
    # Timestamp of the last use — presumably used for TTL eviction; which
    # clock (time.time vs time.monotonic) is set by the cache owner, confirm there.
    last_access: float
mrok/proxy/main.py ADDED
@@ -0,0 +1,58 @@
1
+ from collections.abc import Callable
2
+ from pathlib import Path
3
+ from typing import Any
4
+
5
+ from gunicorn.app.base import BaseApplication
6
+ from uvicorn_worker import UvicornWorker
7
+
8
+ from mrok.conf import get_settings
9
+ from mrok.http.lifespan import LifespanWrapper
10
+ from mrok.logging import get_logging_config
11
+ from mrok.proxy.app import ProxyApp
12
+
13
+
14
class MrokUvicornWorker(UvicornWorker):
    """Uvicorn worker preconfigured for mrok: asyncio loop, lifespan events enabled."""

    CONFIG_KWARGS: dict[str, Any] = {"loop": "asyncio", "http": "auto", "lifespan": "on"}
16
+
17
+
18
class StandaloneApplication(BaseApplication):  # pragma: no cover
    """Embedded gunicorn application serving a pre-built callable with given options."""

    def __init__(self, application: Callable, options: dict[str, Any] | None = None):
        self.options = options or {}
        self.application = application
        super().__init__()

    def load_config(self):
        """Copy recognized, non-None options into gunicorn's configuration."""
        known_settings = self.cfg.settings
        for option_name, option_value in self.options.items():
            if option_name in known_settings and option_value is not None:
                self.cfg.set(option_name.lower(), option_value)

    def load(self):
        """Return the already-instantiated application callable."""
        return self.application
35
+
36
+
37
def run(
    identity_file: str | Path,
    host: str,
    port: int,
    workers: int,
):
    """Serve the proxy app on ``host:port`` via gunicorn with uvicorn workers."""
    proxy_app = ProxyApp(identity_file)
    asgi_app = LifespanWrapper(proxy_app, proxy_app.startup, proxy_app.shutdown)

    gunicorn_options = {
        "bind": f"{host}:{port}",
        "workers": workers,
        "worker_class": "mrok.proxy.main.MrokUvicornWorker",
        "logconfig_dict": get_logging_config(get_settings()),
        "reload": False,
    }

    StandaloneApplication(asgi_app, gunicorn_options).run()
mrok/proxy/streams.py ADDED
@@ -0,0 +1,124 @@
1
+ import asyncio
2
+
3
+ from mrok.proxy.types import ConnectionCache, ConnectionKey
4
+
5
+
6
class CachedStreamReader:
    """asyncio.StreamReader wrapper that evicts its cache entry on read failure.

    Any exception listed in ``_INVALIDATE_EXCS`` means the pooled connection
    can no longer be trusted: the cache entry is invalidated (fire-and-forget)
    and the original exception is re-raised.
    """

    # Fix: the same 9-exception tuple was duplicated verbatim in every
    # read method; hoisted to a single class-level constant.
    _INVALIDATE_EXCS = (
        asyncio.CancelledError,
        asyncio.IncompleteReadError,
        asyncio.LimitOverrunError,
        BrokenPipeError,
        ConnectionAbortedError,
        ConnectionResetError,
        RuntimeError,
        TimeoutError,
        UnicodeDecodeError,
    )

    def __init__(
        self,
        reader: asyncio.StreamReader,
        key: "ConnectionKey",
        manager: "ConnectionCache",
    ):
        self._reader = reader
        self._key = key
        self._manager = manager

    def _schedule_invalidate(self) -> None:
        # Fire-and-forget: eviction must not block nor mask the original error.
        asyncio.create_task(self._manager.invalidate(self._key))

    async def read(self, n: int = -1) -> bytes:
        try:
            return await self._reader.read(n)
        except self._INVALIDATE_EXCS:
            self._schedule_invalidate()
            raise

    async def readexactly(self, n: int) -> bytes:
        try:
            return await self._reader.readexactly(n)
        except self._INVALIDATE_EXCS:
            self._schedule_invalidate()
            raise

    async def readline(self) -> bytes:
        try:
            return await self._reader.readline()
        except self._INVALIDATE_EXCS:
            self._schedule_invalidate()
            raise

    def at_eof(self) -> bool:
        return self._reader.at_eof()

    @property
    def underlying(self) -> asyncio.StreamReader:
        """The wrapped asyncio.StreamReader."""
        return self._reader
74
+
75
+
76
class CachedStreamWriter:
    """asyncio.StreamWriter wrapper that evicts its cache entry on write failure.

    On any exception indicating the pooled connection is broken, the cache
    entry is invalidated (fire-and-forget) and the exception is re-raised.
    """

    # Fix: hoisted the exception tuple previously duplicated in drain();
    # write() and wait_closed() keep their original narrower sets.
    _INVALIDATE_EXCS = (
        asyncio.CancelledError,
        BrokenPipeError,
        ConnectionAbortedError,
        ConnectionResetError,
        RuntimeError,
        TimeoutError,
    )

    def __init__(
        self,
        writer: asyncio.StreamWriter,
        key: "ConnectionKey",
        manager: "ConnectionCache",
    ):
        self._writer = writer
        self._key = key
        self._manager = manager

    def _schedule_invalidate(self) -> None:
        # Fire-and-forget: eviction must not block nor mask the original error.
        asyncio.create_task(self._manager.invalidate(self._key))

    def write(self, data: bytes) -> None:
        try:
            return self._writer.write(data)
        except (RuntimeError, TypeError):
            self._schedule_invalidate()
            raise

    async def drain(self) -> None:
        try:
            return await self._writer.drain()
        except self._INVALIDATE_EXCS:
            self._schedule_invalidate()
            raise

    def close(self) -> None:
        return self._writer.close()

    async def wait_closed(self) -> None:
        try:
            return await self._writer.wait_closed()
        except (ConnectionResetError, BrokenPipeError):
            self._schedule_invalidate()
            raise

    @property
    def transport(self):
        return self._writer.transport

    @property
    def underlying(self) -> asyncio.StreamWriter:
        """The wrapped asyncio.StreamWriter."""
        return self._writer
mrok/proxy/types.py ADDED
@@ -0,0 +1,12 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Protocol
4
+
5
+ from mrok.http.types import StreamReader, StreamWriter
6
+
7
+ ConnectionKey = tuple[str, str | None]
8
+ CachedStream = tuple[StreamReader, StreamWriter]
9
+
10
+
11
class ConnectionCache(Protocol):
    """Structural interface: anything able to evict a cached connection by key."""

    async def invalidate(self, key: ConnectionKey) -> None: ...