mrok 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mrok/agent/devtools/__init__.py +0 -0
- mrok/agent/devtools/__main__.py +34 -0
- mrok/agent/devtools/inspector/__init__.py +0 -0
- mrok/agent/devtools/inspector/__main__.py +25 -0
- mrok/agent/devtools/inspector/app.py +556 -0
- mrok/agent/devtools/inspector/server.py +18 -0
- mrok/agent/sidecar/app.py +9 -10
- mrok/agent/sidecar/main.py +35 -16
- mrok/agent/ziticorn.py +27 -18
- mrok/cli/commands/__init__.py +2 -1
- mrok/cli/commands/agent/__init__.py +2 -0
- mrok/cli/commands/agent/dev/__init__.py +7 -0
- mrok/cli/commands/agent/dev/console.py +25 -0
- mrok/cli/commands/agent/dev/web.py +37 -0
- mrok/cli/commands/agent/run/asgi.py +35 -16
- mrok/cli/commands/agent/run/sidecar.py +29 -13
- mrok/cli/commands/agent/utils.py +5 -0
- mrok/cli/commands/controller/run.py +1 -5
- mrok/cli/commands/proxy/__init__.py +6 -0
- mrok/cli/commands/proxy/run.py +49 -0
- mrok/cli/utils.py +5 -0
- mrok/conf.py +6 -0
- mrok/controller/auth.py +2 -2
- mrok/datastructures.py +159 -0
- mrok/http/config.py +3 -6
- mrok/http/constants.py +22 -0
- mrok/http/forwarder.py +62 -23
- mrok/http/lifespan.py +29 -0
- mrok/http/middlewares.py +143 -0
- mrok/http/types.py +43 -0
- mrok/http/utils.py +90 -0
- mrok/logging.py +22 -0
- mrok/master.py +269 -0
- mrok/metrics.py +139 -0
- mrok/proxy/__init__.py +3 -0
- mrok/proxy/app.py +73 -0
- mrok/proxy/dataclasses.py +12 -0
- mrok/proxy/main.py +58 -0
- mrok/proxy/streams.py +124 -0
- mrok/proxy/types.py +12 -0
- mrok/proxy/ziti.py +173 -0
- {mrok-0.3.0.dist-info → mrok-0.4.0.dist-info}/METADATA +7 -1
- {mrok-0.3.0.dist-info → mrok-0.4.0.dist-info}/RECORD +46 -20
- mrok/http/master.py +0 -132
- {mrok-0.3.0.dist-info → mrok-0.4.0.dist-info}/WHEEL +0 -0
- {mrok-0.3.0.dist-info → mrok-0.4.0.dist-info}/entry_points.txt +0 -0
- {mrok-0.3.0.dist-info → mrok-0.4.0.dist-info}/licenses/LICENSE.txt +0 -0
mrok/master.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import contextlib
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import signal
|
|
7
|
+
import threading
|
|
8
|
+
import time
|
|
9
|
+
from abc import ABC, abstractmethod
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from uuid import uuid4
|
|
12
|
+
|
|
13
|
+
import zmq
|
|
14
|
+
import zmq.asyncio
|
|
15
|
+
from watchfiles import watch
|
|
16
|
+
from watchfiles.filters import PythonFilter
|
|
17
|
+
from watchfiles.run import CombinedProcess, start_process
|
|
18
|
+
|
|
19
|
+
from mrok.conf import get_settings
|
|
20
|
+
from mrok.datastructures import Event, HTTPResponse, Meta, Status
|
|
21
|
+
from mrok.http.config import MrokBackendConfig
|
|
22
|
+
from mrok.http.lifespan import LifespanWrapper
|
|
23
|
+
from mrok.http.middlewares import CaptureMiddleware, MetricsMiddleware
|
|
24
|
+
from mrok.http.server import MrokServer
|
|
25
|
+
from mrok.http.types import ASGIApp
|
|
26
|
+
from mrok.logging import setup_logging
|
|
27
|
+
from mrok.metrics import WorkerMetricsCollector
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger("mrok.agent")

# Timeout (seconds) MasterBase.stop() waits when joining the monitor thread.
MONITOR_THREAD_JOIN_TIMEOUT = 5
# Pause (seconds) between two passes of the worker liveness check in monitor_workers().
MONITOR_THREAD_CHECK_DELAY = 1
# Back-off (seconds) in monitor_workers() after an unexpected error.
MONITOR_THREAD_ERROR_DELAY = 3
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def print_path(path):
    """Return ``path`` wrapped in double quotes, for use in log messages.

    Paths located under the current working directory are rendered relative
    to it; any other path (including relative ones) is rendered as given.
    """
    cwd = Path.cwd()
    try:
        shown = path.relative_to(cwd)
    except ValueError:
        shown = path
    return f'"{shown}"'
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def start_events_router(events_pub_port: int, events_sub_port: int):
    """Run a blocking ZeroMQ XSUB/XPUB forwarder; meant as a dedicated process target.

    Publishers connect to ``events_pub_port`` (bound here as XSUB) and
    subscribers are expected to connect to ``events_sub_port`` (bound here as
    XPUB). Returns once the proxy is interrupted with KeyboardInterrupt.

    Args:
        events_pub_port: Local TCP port publishers connect to.
        events_sub_port: Local TCP port subscribers connect to.
    """
    setup_logging(get_settings())
    context = zmq.Context()
    frontend = None
    backend = None
    try:
        # Sockets are created inside the try so that a failing bind (e.g. the
        # port is already in use) still releases whatever was already opened.
        frontend = context.socket(zmq.XSUB)
        frontend.bind(f"tcp://localhost:{events_pub_port}")
        backend = context.socket(zmq.XPUB)
        backend.bind(f"tcp://localhost:{events_sub_port}")

        logger.info(f"Events router process started: {os.getpid()}")
        zmq.proxy(frontend, backend)
    except KeyboardInterrupt:
        pass
    finally:
        # linger=0 drops undeliverable messages so context.term() cannot hang
        # the shutdown waiting for them.
        for sock in (frontend, backend):
            if sock is not None:
                sock.close(linger=0)
        context.term()
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def start_uvicorn_worker(
    worker_id: str,
    app: ASGIApp,
    identity_file: str,
    events_pub_port: int,
    metrics_interval: float = 5.0,
):
    """Run a single worker process serving ``app`` through ``MrokServer``.

    The app is wrapped so that every completed response is published as a
    ``"response"`` event and, while the server runs, a ``"status"`` event with
    a metrics snapshot is published every ``metrics_interval`` seconds. Events
    are sent on a ZeroMQ PUB socket connected to ``events_pub_port`` on
    localhost.

    Args:
        worker_id: Identifier passed to the metrics collector.
        app: The ASGI application to serve.
        identity_file: Path to the JSON identity file; its ``"mrok"`` entry is
            used to build the ``Meta`` attached to status events.
        events_pub_port: Port of the events router's publisher side.
        metrics_interval: Seconds between two status events.
    """
    import sys

    # Put the current working directory first on sys.path so modules from it
    # are importable in this process.
    sys.path.insert(0, os.getcwd())
    setup_logging(get_settings())
    with open(identity_file) as fp:
        identity = json.load(fp)
    meta = Meta(**identity["mrok"])
    ctx = zmq.asyncio.Context()
    pub = ctx.socket(zmq.PUB)
    pub.connect(f"tcp://localhost:{events_pub_port}")
    metrics = WorkerMetricsCollector(worker_id)

    task = None

    async def status_sender():
        while True:
            try:
                snap = await metrics.snapshot()
                logger.info(f"New metrics snapshot taken: {snap}")
                event = Event(type="status", data=Status(meta=meta, metrics=snap))
                await pub.send_string(event.model_dump_json())
            except Exception:
                # One failed snapshot/publish must not end status reporting for
                # the rest of the worker's lifetime.
                logger.exception("Failed to publish worker status")
            await asyncio.sleep(metrics_interval)

    async def on_startup():  # noqa
        nonlocal task
        task = asyncio.create_task(status_sender())

    async def on_shutdown():  # noqa
        if task is not None:
            task.cancel()
            # Let the sender actually finish before the event loop goes away.
            # asyncio.wait does not raise the task's CancelledError but still
            # propagates a cancellation of on_shutdown itself.
            await asyncio.wait({task})

    async def on_response_complete(response: HTTPResponse):
        event = Event(type="response", data=response)
        await pub.send_string(event.model_dump_json())

    config = MrokBackendConfig(
        LifespanWrapper(
            MetricsMiddleware(
                CaptureMiddleware(
                    app,
                    on_response_complete,
                ),
                metrics,
            ),
            on_startup=on_startup,
            on_shutdown=on_shutdown,
        ),
        identity_file,
    )
    server = MrokServer(config)
    with contextlib.suppress(KeyboardInterrupt, asyncio.CancelledError):
        server.run()
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class MasterBase(ABC):
    """Supervisor for a pool of worker processes and a ZeroMQ events router.

    ``start`` launches the events router process, ``workers`` worker processes
    (each running ``start_uvicorn_worker`` with the app returned by
    ``get_asgi_app``) and a daemon thread that replaces workers that die.
    ``run`` blocks until SIGINT/SIGTERM; with ``reload`` enabled it also
    restarts every worker whenever a Python file under the current working
    directory changes.

    Subclasses must implement ``get_asgi_app``.
    """

    def __init__(
        self,
        identity_file: str,
        workers: int,
        reload: bool,
        events_pub_port: int,
        events_sub_port: int,
        metrics_interval: float = 5.0,
    ):
        self.identity_file = identity_file
        self.workers = workers
        self.reload = reload
        self.events_pub_port = events_pub_port
        self.events_sub_port = events_sub_port
        self.metrics_interval = metrics_interval
        # Stable ids: a replacement worker reuses the id of the one it replaces.
        self.worker_identifiers = [str(uuid4()) for _ in range(workers)]
        self.worker_processes: dict[str, CombinedProcess] = {}
        self.zmq_pubsub_router_process = None
        self.monitor_thread = threading.Thread(target=self.monitor_workers, daemon=True)
        # Set to request shutdown; it also ends the file watcher below.
        self.stop_event = threading.Event()
        # Set while a restart is in progress: the monitor must not "revive"
        # workers that are being stopped on purpose.
        self.pause_event = threading.Event()
        # Serializes changes to worker_processes between the main thread
        # (restart/stop) and the monitor thread.
        self._workers_lock = threading.RLock()
        self.watch_filter = PythonFilter(ignore_paths=None)
        self.watcher = watch(
            Path.cwd(),
            watch_filter=self.watch_filter,
            stop_event=self.stop_event,
            yield_on_timeout=True,
        )
        self.setup_signals_handler()

    @abstractmethod
    def get_asgi_app(self):
        """Return the ASGI application each worker process should serve."""

    def setup_signals_handler(self):
        """Route SIGINT and SIGTERM to ``handle_signal``."""
        for sig in (signal.SIGINT, signal.SIGTERM):
            signal.signal(sig, self.handle_signal)

    def handle_signal(self, *args, **kwargs):
        """Signal handler: request a graceful shutdown."""
        self.stop_event.set()

    def start_worker(self, worker_id: str):
        """Start a single worker process and return its handle."""
        p = start_process(
            start_uvicorn_worker,
            "function",
            (
                worker_id,
                self.get_asgi_app(),
                self.identity_file,
                self.events_pub_port,
                self.metrics_interval,
            ),
            None,
        )
        logger.info(f"Worker {worker_id} [{p.pid}] started")
        return p

    def start(self):
        """Start the events router, the workers and the monitor thread."""
        self.start_events_router()
        self.start_workers()
        self.monitor_thread.start()

    def stop(self):
        """Stop the monitor thread, then the workers, then the events router."""
        # Also covers run() exiting through an exception rather than a signal:
        # without this the monitor would keep reviving workers we stop below.
        self.stop_event.set()
        if self.monitor_thread.is_alive():
            logger.debug("Wait for monitor worker to exit")
            self.monitor_thread.join(timeout=MONITOR_THREAD_JOIN_TIMEOUT)
        self.stop_workers()
        self.stop_events_router()

    def start_events_router(self):
        """Start the ZeroMQ XSUB/XPUB router in its own process."""
        self.zmq_pubsub_router_process = start_process(
            start_events_router,
            "function",
            (
                self.events_pub_port,
                self.events_sub_port,
            ),
            None,
        )

    def stop_events_router(self):
        """Stop the events router process, if it was started."""
        if self.zmq_pubsub_router_process is not None:
            self.zmq_pubsub_router_process.stop(sigint_timeout=5, sigkill_timeout=1)

    def start_workers(self):
        """Start one process per worker identifier."""
        with self._workers_lock:
            for worker_id in self.worker_identifiers:
                self.worker_processes[worker_id] = self.start_worker(worker_id)

    def stop_workers(self):
        """Stop every live worker process and forget all of them."""
        with self._workers_lock:
            for process in self.worker_processes.values():
                if process.is_alive():
                    process.stop(sigint_timeout=5, sigkill_timeout=1)
            self.worker_processes.clear()

    def restart(self):
        """Replace every worker with a fresh process (used on reload)."""
        self.pause_event.set()
        try:
            with self._workers_lock:
                self.stop_workers()
                self.start_workers()
        finally:
            self.pause_event.clear()

    def monitor_workers(self):
        """Monitor-thread loop: replace workers that exited unexpectedly.

        Runs until ``stop_event`` is set. Passes are skipped while a restart is
        in progress (``pause_event`` set).
        """
        while not self.stop_event.is_set():
            try:
                # Previously this called pause_event.wait(), which blocks until a
                # restart is *running*: monitoring never happened in normal
                # operation and raced with restart() when it did.
                if not self.pause_event.is_set():
                    self._replace_dead_workers()
                time.sleep(MONITOR_THREAD_CHECK_DELAY)
            except Exception as e:
                logger.exception(f"Error in worker monitoring: {e}")
                time.sleep(MONITOR_THREAD_ERROR_DELAY)

    def _replace_dead_workers(self):
        """Restart, under the same id, every tracked worker that is no longer alive."""
        with self._workers_lock:
            for worker_id, process in list(self.worker_processes.items()):
                # Bail out if a shutdown or restart began meanwhile.
                if self.stop_event.is_set() or self.pause_event.is_set():
                    return
                if process.is_alive():
                    continue
                logger.warning(f"Worker {worker_id} [{process.pid}] died unexpectedly")
                process.stop(sigint_timeout=1, sigkill_timeout=1)
                new_process = self.start_worker(worker_id)
                self.worker_processes[worker_id] = new_process
                logger.info(
                    f"Restarted worker {worker_id} [{process.pid}] -> [{new_process.pid}]"
                )

    def __iter__(self):
        return self

    def __next__(self):
        """Return the deduplicated list of changed paths, or None on a quiet timeout.

        Raises StopIteration once the watcher ends (``stop_event`` set).
        """
        changes = next(self.watcher)
        if changes:
            return list({Path(change[1]) for change in changes})
        return None

    def run(self):
        """Start everything, block until shutdown is requested, then clean up."""
        setup_logging(get_settings())
        logger.info(f"Master process started: {os.getpid()}")
        try:
            # start() is inside the try so a partially completed start-up
            # (e.g. a worker failing to spawn) is still torn down.
            self.start()
            if self.reload:
                for files_changed in self:
                    if files_changed:
                        logger.warning(
                            f"{', '.join(map(print_path, files_changed))} changed, reloading...",
                        )
                        self.restart()
            else:
                self.stop_event.wait()
        finally:
            self.stop()
|
mrok/metrics.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import time
|
|
7
|
+
|
|
8
|
+
import psutil
|
|
9
|
+
from hdrh.histogram import HdrHistogram
|
|
10
|
+
|
|
11
|
+
from mrok.datastructures import (
|
|
12
|
+
DataTransferMetrics,
|
|
13
|
+
ProcessMetrics,
|
|
14
|
+
RequestsMetrics,
|
|
15
|
+
ResponseTimeMetrics,
|
|
16
|
+
WorkerMetrics,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
# Module logger.
# NOTE(review): this module logs under "mrok.proxy", while the agent worker
# code that uses WorkerMetricsCollector logs under "mrok.agent" — confirm the
# logger name is intentional.
logger = logging.getLogger("mrok.proxy")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _collect_process_usage(interval: float) -> ProcessMetrics:
    """Sample CPU and memory usage of the current process; this call blocks.

    CPU usage comes from psutil's ``cpu_percent``. A first call primes the
    counter, the thread sleeps for ``interval`` seconds (only when
    ``interval`` is truthy and positive), and a second call reads the usage
    accumulated in between. Memory is psutil's ``memory_percent``.
    Any exception while reading a value yields ``0.0`` for that value; if the
    process cannot be found, both values are ``0.0``.
    """
    try:
        process = psutil.Process(os.getpid())
    except psutil.NoSuchProcess:
        return ProcessMetrics(cpu=0.0, mem=0.0)

    def _safe(reader) -> float:
        try:
            return reader()
        except Exception:
            return 0.0

    # Priming reading; it is always superseded by the one after the pause.
    _safe(lambda: process.cpu_percent(None))
    if interval and interval > 0:
        time.sleep(interval)
    cpu_usage = _safe(lambda: process.cpu_percent(None))
    mem_usage = _safe(process.memory_percent)
    return ProcessMetrics(cpu=cpu_usage, mem=mem_usage)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
async def get_process_and_children_usage(interval: float = 0.1) -> ProcessMetrics:
    """Sample the current process's CPU/memory usage without blocking the event loop.

    The blocking sampler ``_collect_process_usage`` (which may sleep for
    ``interval`` seconds) runs in a worker thread via ``asyncio.to_thread``.

    NOTE(review): despite the name, only the current process is sampled;
    child processes are not included.
    """
    usage = await asyncio.to_thread(_collect_process_usage, interval)
    return usage
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class WorkerMetricsCollector:
    """Accumulate per-worker request counts, transferred bytes and latencies.

    The ``on_*`` coroutines are hooks fed by the request pipeline; ``snapshot``
    turns the accumulated state into a ``WorkerMetrics`` model. All counters
    are cumulative except the requests-per-second figure, which covers the
    time elapsed since the previous snapshot.

    Args:
        worker_id: Identifier copied into every snapshot.
        lowest, highest, sigfigs: ``HdrHistogram`` parameters for the
            response-time histogram (values are recorded in milliseconds).
    """

    def __init__(self, worker_id: str, lowest=1, highest=60000, sigfigs=3):
        self.worker_id = worker_id

        # Request-level counters
        self.total_requests = 0
        self.successful_requests = 0
        self.failed_requests = 0
        self.bytes_in = 0
        self.bytes_out = 0

        # RPS window: requests counted since _tick_last (monotonic clock, so
        # wall-clock adjustments cannot distort or zero the rate).
        self._tick_last = time.monotonic()
        self._tick_requests = 0

        # Latency histogram, in whole milliseconds.
        self.hist = HdrHistogram(lowest, highest, sigfigs)

        # Serializes updates from concurrent request coroutines.
        self._lock = asyncio.Lock()

    async def on_request_start(self, scope):
        """Return the start timestamp to pass back to ``on_request_end``."""
        return time.perf_counter()

    async def on_request_body(self, length):
        """Account ``length`` bytes of received request body."""
        async with self._lock:
            self.bytes_in += length

    async def on_response_start(self, status_code):
        pass  # reserved

    async def on_response_chunk(self, length):
        """Account ``length`` bytes of sent response body."""
        async with self._lock:
            self.bytes_out += length

    async def on_request_end(self, start_time, status_code):
        """Count a finished request and record its latency.

        ``start_time`` is the value returned by ``on_request_start``. Status
        codes below 500 count as successful, everything else as failed.
        """
        elapsed_ms = (time.perf_counter() - start_time) * 1000
        # hdrh's record_value computes bucket indexes with integer bit
        # operations, so record a whole number of milliseconds, not a float.
        latency_ms = max(int(round(elapsed_ms)), 0)

        async with self._lock:
            self.total_requests += 1
            self._tick_requests += 1

            if status_code < 500:
                self.successful_requests += 1
            else:
                self.failed_requests += 1

            self.hist.record_value(latency_ms)

    async def snapshot(self) -> WorkerMetrics:
        """Return the current metrics and start a new requests-per-second window.

        Any exception is logged and re-raised.
        """
        try:
            # Sample process usage *before* taking the lock: the sampler sleeps
            # (0.1 s by default), and every request hook would otherwise stall
            # behind it.
            process_usage = await get_process_and_children_usage()
            async with self._lock:
                now = time.monotonic()
                delta = now - self._tick_last
                rps = int(self._tick_requests / delta) if delta > 0 else 0
                data = WorkerMetrics(
                    worker_id=self.worker_id,
                    process=process_usage,
                    requests=RequestsMetrics(
                        rps=rps,
                        total=self.total_requests,
                        successful=self.successful_requests,
                        failed=self.failed_requests,
                    ),
                    data_transfer=DataTransferMetrics(
                        bytes_in=self.bytes_in,
                        bytes_out=self.bytes_out,
                    ),
                    response_time=ResponseTimeMetrics(
                        avg=self.hist.get_mean_value(),
                        min=self.hist.get_min_value(),
                        max=self.hist.get_max_value(),
                        p50=self.hist.get_value_at_percentile(50),
                        p90=self.hist.get_value_at_percentile(90),
                        p99=self.hist.get_value_at_percentile(99),
                    ),
                )

                self._tick_last = now
                self._tick_requests = 0

                return data
        except Exception:
            logger.exception("Exception calculating snapshot")
            raise
|
mrok/proxy/__init__.py
ADDED
mrok/proxy/app.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
from mrok.conf import get_settings
|
|
6
|
+
from mrok.http.forwarder import ForwardAppBase
|
|
7
|
+
from mrok.http.types import Scope, StreamReader, StreamWriter
|
|
8
|
+
from mrok.logging import setup_logging
|
|
9
|
+
from mrok.proxy.ziti import ZitiConnectionManager
|
|
10
|
+
|
|
11
|
+
# Module logger for the proxy application.
logger = logging.getLogger("mrok.proxy")
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ProxyError(Exception):
    """Raised when a request cannot be mapped to a target OpenZiti service."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ProxyApp(ForwardAppBase):
    """Forward HTTP requests to OpenZiti services selected by host name.

    A request whose host is ``<name><domain>``, where ``<domain>`` is the
    configured ``settings.proxy.domain``, is forwarded to the OpenZiti
    service/terminator ``<name>``. Connections are obtained from a
    ``ZitiConnectionManager``.
    """

    def __init__(
        self,
        identity_file: str | Path,
        *,
        read_chunk_size: int = 65536,
        ziti_connection_ttl_seconds: float = 60,
        ziti_conn_cache_purge_interval_seconds: float = 10,
    ) -> None:
        super().__init__(read_chunk_size=read_chunk_size)
        self._identity_file = identity_file
        settings = get_settings()
        domain = settings.proxy.domain
        # Always store the suffix with a leading dot. startswith() (instead of
        # domain[0]) also copes with an empty configured domain.
        self._proxy_wildcard_domain = domain if domain.startswith(".") else f".{domain}"
        self._conn_manager = ZitiConnectionManager(
            identity_file,
            ttl_seconds=ziti_connection_ttl_seconds,
            purge_interval=ziti_conn_cache_purge_interval_seconds,
        )

    def get_target_name(self, headers: dict[str, str]) -> str:
        """Extract the target service name from the request host.

        The host comes from ``X-Forwarded-Host`` when present (its first entry
        if a proxy chain sent a comma-separated list), otherwise from ``Host``.
        The port is dropped and the wildcard domain suffix is stripped; the
        suffix is matched case-insensitively, and the name keeps its case.

        Args:
            headers: Request headers keyed by lowercase name.

        Returns:
            The OpenZiti service/terminator name.

        Raises:
            ProxyError: if no host is available, the host is not under the
                proxy domain, or nothing remains once the suffix is removed.
        """
        # X-Forwarded-For carries client IP addresses, not the requested host.
        # The host forwarded by an upstream proxy is in X-Forwarded-Host.
        header_value = headers.get("x-forwarded-host") or headers.get("host")
        if not header_value:
            raise ProxyError(
                "Cannot determine the target OpenZiti service/terminator name, "
                "neither Host nor X-Forwarded-Host headers have been sent in the request.",
            )
        # The first entry of a forwarded list is the host the client asked for.
        host = header_value.split(",", 1)[0].strip()
        host, _, _ = host.partition(":")
        suffix = self._proxy_wildcard_domain
        # Host names are case-insensitive.
        if not host.lower().endswith(suffix.lower()):
            raise ProxyError(
                f"Unexpected value for Host or X-Forwarded-Host header: `{host}`."
            )

        target_name = host[: -len(suffix)]
        if not target_name:
            raise ProxyError(
                "Cannot determine the target OpenZiti service/terminator name "
                f"from host `{host}`."
            )
        return target_name

    async def startup(self):
        """Configure logging and start the connection manager."""
        setup_logging(get_settings())
        await self._conn_manager.start()
        logger.info(f"Proxy app startup completed: {os.getpid()}")

    async def shutdown(self):
        """Stop the connection manager."""
        await self._conn_manager.stop()
        logger.info(f"Proxy app shutdown completed: {os.getpid()}")

    async def select_backend(
        self,
        scope: Scope,
        headers: dict[str, str],
    ) -> tuple[StreamReader, StreamWriter] | tuple[None, None]:
        """Return the stream pair for the service named by the request host."""
        target_name = self.get_target_name(headers)

        return await self._conn_manager.get(target_name)
|
mrok/proxy/main.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from gunicorn.app.base import BaseApplication
|
|
6
|
+
from uvicorn_worker import UvicornWorker
|
|
7
|
+
|
|
8
|
+
from mrok.conf import get_settings
|
|
9
|
+
from mrok.http.lifespan import LifespanWrapper
|
|
10
|
+
from mrok.logging import get_logging_config
|
|
11
|
+
from mrok.proxy.app import ProxyApp
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MrokUvicornWorker(UvicornWorker):
    """Gunicorn worker class used by the proxy.

    Passes ``loop="asyncio"``, ``http="auto"`` and ``lifespan="on"`` as
    uvicorn configuration.
    """

    CONFIG_KWARGS: dict[str, Any] = {
        "loop": "asyncio",
        "http": "auto",
        "lifespan": "on",
    }
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class StandaloneApplication(BaseApplication):  # pragma: no cover
    """Embedded gunicorn application serving an already-built callable.

    Args:
        application: The callable returned by ``load``.
        options: Gunicorn settings. Entries whose key is not a known gunicorn
            setting, or whose value is ``None``, are ignored.
    """

    def __init__(self, application: Callable, options: dict[str, Any] | None = None):
        # Both attributes must exist before BaseApplication.__init__ runs,
        # because it calls load_config().
        self.options = options or {}
        self.application = application
        super().__init__()

    def load_config(self):
        """Copy the supported, non-None options into gunicorn's config."""
        known_settings = self.cfg.settings
        for name, setting_value in self.options.items():
            if setting_value is None or name not in known_settings:
                continue
            self.cfg.set(name.lower(), setting_value)

    def load(self):
        """Return the application to serve."""
        return self.application
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def run(
    identity_file: str | Path,
    host: str,
    port: int,
    workers: int,
):
    """Serve the mrok proxy with gunicorn and uvicorn workers.

    Args:
        identity_file: Identity file handed to ``ProxyApp``.
        host: Address to bind.
        port: TCP port to bind.
        workers: Number of gunicorn worker processes.
    """
    proxy_app = ProxyApp(identity_file)
    # Tie the proxy's startup/shutdown coroutines to the ASGI lifespan.
    asgi_app = LifespanWrapper(proxy_app, proxy_app.startup, proxy_app.shutdown)

    logging_config = get_logging_config(get_settings())
    options = dict(
        bind=f"{host}:{port}",
        workers=workers,
        worker_class="mrok.proxy.main.MrokUvicornWorker",
        logconfig_dict=logging_config,
        reload=False,
    )
    StandaloneApplication(asgi_app, options).run()
|
mrok/proxy/streams.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
|
|
3
|
+
from mrok.proxy.types import ConnectionCache, ConnectionKey
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CachedStreamReader:
    """``asyncio.StreamReader`` wrapper that invalidates its cached connection on failure.

    Reads are delegated to the wrapped reader. If a read fails or is
    cancelled, invalidation of ``key`` in ``manager`` is scheduled as a
    background task and the original exception is re-raised.
    """

    # Errors after which the underlying connection is no longer trusted.
    _INVALIDATING_ERRORS = (
        asyncio.CancelledError,
        asyncio.IncompleteReadError,
        asyncio.LimitOverrunError,
        BrokenPipeError,
        ConnectionAbortedError,
        ConnectionResetError,
        RuntimeError,
        TimeoutError,
        UnicodeDecodeError,
    )

    # Strong references to in-flight invalidation tasks. The event loop only
    # keeps weak references to tasks, so a task nobody references can be
    # garbage-collected before it finishes.
    _pending_invalidations: set = set()

    def __init__(
        self,
        reader: asyncio.StreamReader,
        key: ConnectionKey,
        manager: ConnectionCache,
    ):
        self._reader = reader
        self._key = key
        self._manager = manager

    def _invalidate(self) -> None:
        """Schedule invalidation of this connection and keep the task alive."""
        task = asyncio.create_task(self._manager.invalidate(self._key))
        self._pending_invalidations.add(task)
        task.add_done_callback(self._pending_invalidations.discard)

    async def read(self, n: int = -1) -> bytes:
        try:
            return await self._reader.read(n)
        except self._INVALIDATING_ERRORS:
            self._invalidate()
            raise

    async def readexactly(self, n: int) -> bytes:
        try:
            return await self._reader.readexactly(n)
        except self._INVALIDATING_ERRORS:
            self._invalidate()
            raise

    async def readline(self) -> bytes:
        try:
            return await self._reader.readline()
        except self._INVALIDATING_ERRORS:
            self._invalidate()
            raise

    def at_eof(self) -> bool:
        return self._reader.at_eof()

    @property
    def underlying(self) -> asyncio.StreamReader:
        return self._reader
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class CachedStreamWriter:
    """``asyncio.StreamWriter`` wrapper that invalidates its cached connection on failure.

    Operations are delegated to the wrapped writer. If ``write``, ``drain``
    or ``wait_closed`` fails (or ``drain`` is cancelled), invalidation of
    ``key`` in ``manager`` is scheduled as a background task and the original
    exception is re-raised.
    """

    _WRITE_ERRORS = (RuntimeError, TypeError)
    _DRAIN_ERRORS = (
        asyncio.CancelledError,
        BrokenPipeError,
        ConnectionAbortedError,
        ConnectionResetError,
        RuntimeError,
        TimeoutError,
    )
    _CLOSE_ERRORS = (ConnectionResetError, BrokenPipeError)

    # Strong references to in-flight invalidation tasks. The event loop only
    # keeps weak references to tasks, so a task nobody references can be
    # garbage-collected before it finishes.
    _pending_invalidations: set = set()

    def __init__(
        self,
        writer: asyncio.StreamWriter,
        key: ConnectionKey,
        manager: ConnectionCache,
    ):
        self._writer = writer
        self._key = key
        self._manager = manager

    def _invalidate(self) -> None:
        """Schedule invalidation of this connection and keep the task alive."""
        task = asyncio.create_task(self._manager.invalidate(self._key))
        self._pending_invalidations.add(task)
        task.add_done_callback(self._pending_invalidations.discard)

    def write(self, data: bytes) -> None:
        try:
            return self._writer.write(data)
        except self._WRITE_ERRORS:
            self._invalidate()
            raise

    async def drain(self) -> None:
        try:
            return await self._writer.drain()
        except self._DRAIN_ERRORS:
            self._invalidate()
            raise

    def close(self) -> None:
        return self._writer.close()

    async def wait_closed(self) -> None:
        try:
            return await self._writer.wait_closed()
        except self._CLOSE_ERRORS:
            self._invalidate()
            raise

    @property
    def transport(self):
        return self._writer.transport

    @property
    def underlying(self) -> asyncio.StreamWriter:
        return self._writer
|
mrok/proxy/types.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Protocol
|
|
4
|
+
|
|
5
|
+
from mrok.http.types import StreamReader, StreamWriter
|
|
6
|
+
|
|
7
|
+
# Key under which a connection is cached and later invalidated.
# NOTE(review): presumably (target/service name, optional qualifier) — confirm
# against the cache implementation.
ConnectionKey = tuple[str, str | None]
# A cached (reader, writer) stream pair.
CachedStream = tuple[StreamReader, StreamWriter]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ConnectionCache(Protocol):
    """Structural interface of a connection cache that supports invalidation by key."""

    async def invalidate(self, key: ConnectionKey) -> None:
        """Invalidate the connection cached under ``key``; the details are up to implementations."""
        ...
|