mrok 0.4.4__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff compares the contents of two package versions as published to a supported public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective registries.
mrok/agent/sidecar/app.py CHANGED
@@ -1,9 +1,12 @@
 import asyncio
 import logging
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
 from pathlib import Path
 
 from mrok.http.forwarder import ForwardAppBase
-from mrok.http.types import Scope, StreamReader, StreamWriter
+from mrok.http.pool import ConnectionPool
+from mrok.http.types import Scope, StreamPair
 
 logger = logging.getLogger("mrok.agent")
 
@@ -18,12 +21,31 @@ class ForwardApp(ForwardAppBase):
             read_chunk_size=read_chunk_size,
         )
         self._target_address = target_address
+        self._pool = ConnectionPool(
+            pool_name=str(self._target_address),
+            factory=self.connect,
+            initial_connections=5,
+            max_size=100,
+            idle_timeout=20.0,
+            reaper_interval=5.0,
+        )
+
+    async def connect(self) -> StreamPair:
+        if isinstance(self._target_address, tuple):
+            return await asyncio.open_connection(*self._target_address)
+        return await asyncio.open_unix_connection(str(self._target_address))
 
+    async def startup(self):
+        await self._pool.start()
+
+    async def shutdown(self):
+        await self._pool.stop()
+
+    @asynccontextmanager
     async def select_backend(
         self,
         scope: Scope,
         headers: dict[str, str],
-    ) -> tuple[StreamReader, StreamWriter] | tuple[None, None]:
-        if isinstance(self._target_address, tuple):
-            return await asyncio.open_connection(*self._target_address)
-        return await asyncio.open_unix_connection(str(self._target_address))
+    ) -> AsyncGenerator[StreamPair, None]:
+        async with self._pool.acquire() as (reader, writer):
+            yield reader, writer
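
Note: with this change `select_backend` is decorated with `@asynccontextmanager`, so the caller consumes it with `async with` and the pooled connection is released automatically when the block exits. A minimal sketch of that consumption, assuming a hypothetical helper name and request line:

async def forward_once(app: ForwardApp, scope, headers) -> bytes:
    # 'forward_once' is illustrative; ForwardAppBase.__call__ does this internally.
    async with app.select_backend(scope, headers) as (reader, writer):
        writer.write(b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n")
        await writer.drain()
        return await reader.readline()  # e.g. b"HTTP/1.1 200 OK\r\n"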
mrok/cli/commands/admin/register/extensions.py CHANGED
@@ -1,5 +1,4 @@
 import asyncio
-import re
 from typing import Annotated
 
 import typer
@@ -7,11 +6,10 @@ from rich import print
 
 from mrok.cli.commands.admin.utils import parse_tags
 from mrok.conf import Settings
+from mrok.constants import RE_EXTENSION_ID
 from mrok.ziti.api import ZitiManagementAPI
 from mrok.ziti.services import register_service
 
-RE_EXTENSION_ID = re.compile(r"(?i)EXT-\d{4}-\d{4}")
-
 
 async def do_register(settings: Settings, extension_id: str, tags: list[str] | None):
     async with ZitiManagementAPI(settings) as api:
@@ -20,7 +18,7 @@ async def do_register(settings: Settings, extension_id: str, tags: list[str] | None):
 
 def validate_extension_id(extension_id: str) -> str:
     if not RE_EXTENSION_ID.fullmatch(extension_id):
-        raise typer.BadParameter("ext_id must match EXT-xxxx-yyyy (case-insensitive)")
+        raise typer.BadParameter("it must match EXT-xxxx-yyyy (case-insensitive)")
     return extension_id
 
 
mrok/cli/commands/admin/register/instances.py CHANGED
@@ -1,6 +1,5 @@
 import asyncio
 import json
-import re
 from pathlib import Path
 from typing import Annotated
 
@@ -8,11 +7,10 @@ import typer
 
 from mrok.cli.commands.admin.utils import parse_tags
 from mrok.conf import Settings
+from mrok.constants import RE_EXTENSION_ID, RE_INSTANCE_ID
 from mrok.ziti.api import ZitiClientAPI, ZitiManagementAPI
 from mrok.ziti.identities import register_identity
 
-RE_EXTENSION_ID = re.compile(r"(?i)EXT-\d{4}-\d{4}")
-
 
 async def do_register(
     settings: Settings, extension_id: str, instance_id: str, tags: list[str] | None
@@ -25,10 +23,16 @@ async def do_register(
 
 def validate_extension_id(extension_id: str):
     if not RE_EXTENSION_ID.fullmatch(extension_id):
-        raise typer.BadParameter("ext_id must match EXT-xxxx-yyyy (case-insensitive)")
+        raise typer.BadParameter("it must match EXT-xxxx-yyyy (case-insensitive)")
     return extension_id
 
 
+def validate_instance_id(instance_id: str):
+    if not RE_INSTANCE_ID.fullmatch(instance_id):
+        raise typer.BadParameter("it must match INS-xxxx-yyyy-zzzz (case-insensitive)")
+    return instance_id
+
+
 def register(app: typer.Typer) -> None:
     @app.command("instance")
     def register_instance(
@@ -36,7 +40,9 @@ def register(app: typer.Typer) -> None:
         extension_id: str = typer.Argument(
             ..., callback=validate_extension_id, help="Extension ID in format EXT-xxxx-yyyy"
         ),
-        instance_id: str = typer.Argument(..., help="Instance ID"),
+        instance_id: str = typer.Argument(
+            ..., callback=validate_instance_id, help="Instance ID in format INS-xxxx-yyyy-zzzz"
+        ),
         output: Path = typer.Argument(
             ...,
             file_okay=True,
mrok/constants.py ADDED
@@ -0,0 +1,6 @@
+import re
+
+RE_EXTENSION_ID = re.compile(r"(?i)EXT-\d{4}-\d{4}")
+RE_INSTANCE_ID = re.compile(r"(?i)INS-\d{4}-\d{4}-\d{4}")
+
+RE_SUBDOMAIN = re.compile(r"(?i)^(?:EXT-\d{4}-\d{4}|INS-\d{4}-\d{4}-\d{4})$")
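
Centralizing these patterns in mrok/constants.py gives the CLI validators and the proxy a single definition. A quick sketch of how they behave (using only what the module above defines):

from mrok.constants import RE_EXTENSION_ID, RE_SUBDOMAIN

# The validators use fullmatch(), so partial matches are rejected,
# and the (?i) flag makes matching case-insensitive.
assert RE_EXTENSION_ID.fullmatch("ext-1234-5678") is not None
assert RE_EXTENSION_ID.fullmatch("EXT-1234-56789") is None
assert RE_SUBDOMAIN.fullmatch("INS-0001-0002-0003") is not None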
mrok/http/forwarder.py CHANGED
@@ -1,14 +1,27 @@
 import abc
 import asyncio
 import logging
+from contextlib import AbstractAsyncContextManager
 
-from mrok.http.types import ASGIReceive, ASGISend, Scope, StreamReader, StreamWriter
+from mrok.http.types import ASGIReceive, ASGISend, Scope, StreamPair
 
 logger = logging.getLogger("mrok.proxy")
 
 
-class BackendNotFoundError(Exception):
-    pass
+class BackendSelectionError(Exception):
+    def __init__(self, status: int = 500, message: str = "Internal Server Error"):
+        self.status = status
+        self.message = message
+
+
+class InvalidBackendError(BackendSelectionError):
+    def __init__(self):
+        super().__init__(status=502, message="Bad Gateway")
+
+
+class BackendUnavailableError(BackendSelectionError):
+    def __init__(self):
+        super().__init__(status=503, message="Service Unavailable")
 
 
 class ForwardAppBase(abc.ABC):
@@ -60,20 +73,18 @@ class ForwardAppBase(abc.ABC):
             await send({"type": "lifespan.shutdown.complete"})
             return
 
-    @abc.abstractmethod
-    async def select_backend(
-        self,
-        scope: Scope,
-        headers: dict[str, str],
-    ) -> tuple[StreamReader, StreamWriter] | tuple[None, None]:
-        """Return (reader, writer) connected to the target backend."""
-
     async def startup(self):
         return
 
     async def shutdown(self):
         return
 
+    @abc.abstractmethod
+    def select_backend(
+        self, scope: Scope, headers: dict[str, str]
+    ) -> AbstractAsyncContextManager[StreamPair]:
+        raise NotImplementedError()
+
     async def __call__(self, scope: Scope, receive: ASGIReceive, send: ASGISend) -> None:
         """ASGI callable entry point.
 
@@ -95,37 +106,40 @@ class ForwardAppBase(abc.ABC):
 
         headers = list(scope.get("headers", []))
         headers = self.ensure_host_header(headers, scope)
-        reader, writer = await self.select_backend(
-            scope, {k[0].decode().lower(): k[1].decode() for k in headers}
-        )
-
-        if not (reader and writer):
-            await send({"type": "http.response.start", "status": 502, "headers": []})
-            await send({"type": "http.response.body", "body": b"Bad Gateway"})
-            return
-
-        use_chunked = self.ensure_chunked_if_needed(headers)
+        try:
+            async with self.select_backend(
+                scope, {k[0].decode().lower(): k[1].decode() for k in headers}
+            ) as (reader, writer):
+                if not (reader and writer):
+                    await send({"type": "http.response.start", "status": 502, "headers": []})
+                    await send({"type": "http.response.body", "body": b"Bad Gateway"})
+                    return
 
-        await self.write_request_line_and_headers(writer, method, path_qs, headers)
+                use_chunked = self.ensure_chunked_if_needed(headers)
 
-        await self.stream_request_body(receive, writer, use_chunked)
+                await self.write_request_line_and_headers(writer, method, path_qs, headers)
 
-        status_line = await reader.readline()
-        if not status_line:
-            await send({"type": "http.response.start", "status": 502, "headers": []})
-            await send({"type": "http.response.body", "body": b"Bad Gateway"})
-            writer.close()
-            await writer.wait_closed()
-            return
+                await self.stream_request_body(receive, writer, use_chunked)
 
-        status, headers_out, raw_headers = await self.read_status_and_headers(reader, status_line)
+                status_line = await reader.readline()
+                if not status_line:
+                    await send({"type": "http.response.start", "status": 502, "headers": []})
+                    await send({"type": "http.response.body", "body": b"Bad Gateway"})
+                    return
 
-        await send({"type": "http.response.start", "status": status, "headers": headers_out})
+                status, headers_out, raw_headers = await self.read_status_and_headers(
+                    reader, status_line
+                )
 
-        await self.stream_response_body(reader, send, raw_headers)
+                await send(
+                    {"type": "http.response.start", "status": status, "headers": headers_out}
+                )
 
-        writer.close()
-        await writer.wait_closed()
+                await self.stream_response_body(reader, send, raw_headers)
+        except BackendSelectionError as bse:
+            await send({"type": "http.response.start", "status": bse.status, "headers": []})
+            await send({"type": "http.response.body", "body": bse.message.encode()})
+            return
 
     def format_path(self, scope: Scope) -> str:
         raw_path = scope.get("raw_path")
@@ -159,7 +173,7 @@ class ForwardAppBase(abc.ABC):
 
     async def write_request_line_and_headers(
         self,
-        writer: StreamWriter,
+        writer: asyncio.StreamWriter,
         method: str,
         path_qs: str,
         headers: list[tuple[bytes, bytes]],
@@ -173,7 +187,7 @@ class ForwardAppBase(abc.ABC):
         await writer.drain()
 
     async def stream_request_body(
-        self, receive: ASGIReceive, writer: StreamWriter, use_chunked: bool
+        self, receive: ASGIReceive, writer: asyncio.StreamWriter, use_chunked: bool
     ) -> None:
         if use_chunked:
             await self.stream_request_chunked(receive, writer)
@@ -181,7 +195,9 @@ class ForwardAppBase(abc.ABC):
 
         await self.stream_request_until_end(receive, writer)
 
-    async def stream_request_chunked(self, receive: ASGIReceive, writer: StreamWriter) -> None:
+    async def stream_request_chunked(
+        self, receive: ASGIReceive, writer: asyncio.StreamWriter
+    ) -> None:
        """Send request body to backend using HTTP/1.1 chunked encoding."""
         while True:
             event = await receive()
@@ -195,13 +211,14 @@ class ForwardAppBase(abc.ABC):
                 if not event.get("more_body", False):
                     break
             elif event["type"] == "http.disconnect":
-                writer.close()
                 return
 
         writer.write(b"0\r\n\r\n")
         await writer.drain()
 
-    async def stream_request_until_end(self, receive: ASGIReceive, writer: StreamWriter) -> None:
+    async def stream_request_until_end(
+        self, receive: ASGIReceive, writer: asyncio.StreamWriter
+    ) -> None:
         """Send request body to backend when content length/transfer-encoding
         already provided (no chunking).
         """
@@ -215,11 +232,10 @@ class ForwardAppBase(abc.ABC):
                 if not event.get("more_body", False):
                     break
             elif event["type"] == "http.disconnect":
-                writer.close()
                 return
 
     async def read_status_and_headers(
-        self, reader: StreamReader, first_line: bytes
+        self, reader: asyncio.StreamReader, first_line: bytes
     ) -> tuple[int, list[tuple[bytes, bytes]], dict[bytes, bytes]]:
         parts = first_line.decode(errors="ignore").split(" ", 2)
         status = int(parts[1]) if len(parts) >= 2 and parts[1].isdigit() else 502
@@ -256,14 +272,14 @@ class ForwardAppBase(abc.ABC):
         except Exception:
             return None
 
-    async def drain_trailers(self, reader: StreamReader) -> None:
+    async def drain_trailers(self, reader: asyncio.StreamReader) -> None:
         """Consume trailer header lines until an empty line is encountered."""
         while True:
             trailer = await reader.readline()
             if trailer in (b"\r\n", b"\n", b""):
                 break
 
-    async def stream_response_chunked(self, reader: StreamReader, send: ASGISend) -> None:
+    async def stream_response_chunked(self, reader: asyncio.StreamReader, send: ASGISend) -> None:
         """Read chunked-encoded response from reader, decode and forward to ASGI send."""
         while True:
             size_line = await reader.readline()
@@ -292,7 +308,7 @@ class ForwardAppBase(abc.ABC):
         await send({"type": "http.response.body", "body": b"", "more_body": False})
 
     async def stream_response_with_content_length(
-        self, reader: StreamReader, send: ASGISend, content_length: int
+        self, reader: asyncio.StreamReader, send: ASGISend, content_length: int
     ) -> None:
         """Read exactly content_length bytes and forward to ASGI send events."""
         remaining = content_length
@@ -311,7 +327,7 @@ class ForwardAppBase(abc.ABC):
         if not sent_final:
             await send({"type": "http.response.body", "body": b"", "more_body": False})
 
-    async def stream_response_until_eof(self, reader: StreamReader, send: ASGISend) -> None:
+    async def stream_response_until_eof(self, reader: asyncio.StreamReader, send: ASGISend) -> None:
         """Read from reader until EOF and forward chunks to ASGI send events."""
         while True:
             chunk = await reader.read(self._read_chunk_size)
@@ -321,7 +337,7 @@ class ForwardAppBase(abc.ABC):
         await send({"type": "http.response.body", "body": b"", "more_body": False})
 
     async def stream_response_body(
-        self, reader: StreamReader, send: ASGISend, raw_headers: dict[bytes, bytes]
+        self, reader: asyncio.StreamReader, send: ASGISend, raw_headers: dict[bytes, bytes]
     ) -> None:
         te = raw_headers.get(b"transfer-encoding", b"").lower()
         cl = raw_headers.get(b"content-length")
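
Subclasses now implement `select_backend` as an async context manager and signal selection failures by raising a `BackendSelectionError` subclass, which `__call__` translates into a 502 or 503 response. A minimal illustrative subclass, assuming a fixed backend address (the class name and address are hypothetical, not part of the package):

import asyncio
from contextlib import asynccontextmanager

from mrok.http.forwarder import BackendUnavailableError, ForwardAppBase


class StaticBackendApp(ForwardAppBase):  # hypothetical example subclass
    @asynccontextmanager
    async def select_backend(self, scope, headers):
        try:
            # Real subclasses pick a target from the request headers.
            reader, writer = await asyncio.open_connection("127.0.0.1", 8080)
        except OSError:
            raise BackendUnavailableError()  # surfaced to the client as 503
        try:
            yield reader, writer
        finally:
            writer.close()
            await writer.wait_closed()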
mrok/http/pool.py ADDED
@@ -0,0 +1,239 @@
+import asyncio
+import contextlib
+import logging
+import time
+from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine
+from contextlib import asynccontextmanager
+from typing import Any
+
+from cachetools import TTLCache
+
+from mrok.http.types import StreamPair
+
+PoolItem = tuple[asyncio.StreamReader, asyncio.StreamWriter, float]
+
+logger = logging.getLogger("mrok.proxy")
+
+
+class ConnectionPool:
+    def __init__(
+        self,
+        pool_name: str,
+        factory: Callable[[], Awaitable[StreamPair]],
+        *,
+        initial_connections: int = 0,
+        max_size: int = 10,
+        idle_timeout: float = 30.0,
+        reaper_interval: float = 5.0,
+    ) -> None:
+        if initial_connections < 0:
+            raise ValueError("initial_connections must be >= 0")
+        if max_size < 1:
+            raise ValueError("max_size must be >= 1")
+        if initial_connections > max_size:
+            raise ValueError("initial_connections cannot exceed max_size")
+        self.pool_name = pool_name
+        self.factory = factory
+        self.initial_connections = initial_connections
+        self.max_size = max_size
+        self.idle_timeout = idle_timeout
+        self.reaper_interval = reaper_interval
+
+        self._pool: list[PoolItem] = []
+        self._in_use = 0
+        self._lock = asyncio.Lock()
+        self._cond = asyncio.Condition()
+        self._stop_event = asyncio.Event()
+
+        self._started = False
+        self._reaper_task: asyncio.Task | None = None
+
+    async def start(self) -> None:
+        if self._started:
+            return
+        self._reaper_task = asyncio.create_task(self._reaper())
+        await self._prewarm()
+        self._started = True
+
+    async def stop(self) -> None:
+        self._stop_event.set()
+        if self._reaper_task is not None:
+            self._reaper_task.cancel()
+            with contextlib.suppress(Exception):
+                await self._reaper_task
+
+        to_close: list[asyncio.StreamWriter] = []
+        async with self._lock:
+            to_close = [writer for _, writer, _ in self._pool]
+            self._pool.clear()
+        for w in to_close:
+            with contextlib.suppress(Exception):
+                w.close()
+                await w.wait_closed()
+
+        async with self._cond:
+            self._cond.notify_all()
+
+    @asynccontextmanager
+    async def acquire(self) -> AsyncGenerator[StreamPair]:
+        if not self._started:
+            await self.start()
+        reader, writer = await self._acquire()
+        logger.info(
+            f"Acquire stats for pool {self.pool_name}: "
+            f"in_use={self._in_use}, size={len(self._pool)}"
+        )
+        try:
+            yield (reader, writer)
+        finally:
+            await self._release(reader, writer)
+
+    async def _prewarm(self) -> None:
+        conns: list[PoolItem] = []
+        needed = max(0, self.initial_connections - (self._in_use + len(self._pool)))
+        for _ in range(needed):
+            reader, writer = await self.factory()
+            conns.append((reader, writer, time.time()))
+        if conns:
+            async with self._lock:
+                self._pool.extend(conns)
+            # notify any waiters
+            async with self._cond:
+                self._cond.notify_all()
+
+    async def _acquire(self) -> StreamPair:  # type: ignore
+        to_close: list[asyncio.StreamWriter] = []
+        create_new = False
+        while True:
+            need_prewarm = False
+            async with self._cond:
+                now = time.time()
+                if not self._pool:
+                    need_prewarm = True
+                while self._pool:
+                    reader, writer, ts = self._pool.pop()
+                    if now - ts <= self.idle_timeout and not writer.is_closing():
+                        self._in_use += 1
+                        return reader, writer
+                    to_close.append(writer)
+
+                total = self._in_use + len(self._pool)
+                if total < self.max_size:
+                    self._in_use += 1
+                    create_new = True
+                    break
+                await self._cond.wait()
+
+            if need_prewarm:
+                await self._prewarm()
+                continue
+
+        for w in to_close:
+            with contextlib.suppress(Exception):
+                w.close()
+                await w.wait_closed()
+
+        if create_new:
+            try:
+                reader, writer = await self.factory()
+            except Exception:
+                async with self._cond:
+                    if self._in_use > 0:
+                        self._in_use -= 1
+                    self._cond.notify()
+                raise
+            return reader, writer
+
+    async def _release(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
+        async with self._cond:
+            if self._in_use > 0:
+                self._in_use -= 1
+
+            if not writer.is_closing():
+                self._pool.append((reader, writer, time.time()))
+
+            self._cond.notify()
+            logger.info(
+                f"Release stats for pool {self.pool_name}: "
+                f"in_use={self._in_use}, size={len(self._pool)}"
+            )
+
+    async def _reaper(self) -> None:
+        try:
+            while not self._stop_event.is_set():
+                await asyncio.sleep(self.reaper_interval)
+                to_close: list[asyncio.StreamWriter] = []
+                now = time.time()
+                async with self._lock:
+                    new_pool: list[PoolItem] = []
+                    for reader, writer, ts in self._pool:
+                        if now - ts > self.idle_timeout or writer.is_closing():
+                            to_close.append(writer)
+                        else:
+                            new_pool.append((reader, writer, ts))
+                    self._pool = new_pool
+
+                for w in to_close:
+                    with contextlib.suppress(Exception):
+                        w.close()
+                        await w.wait_closed()
+
+                async with self._cond:
+                    self._cond.notify_all()
+        except asyncio.CancelledError:
+            pass
+
+
+class SlidingTTLCache(TTLCache):
+    def __init__(
+        self,
+        *,
+        maxsize: float,
+        ttl: float,
+        on_evict: Callable[[Any], Coroutine[Any, Any, None]] | None,
+    ) -> None:
+        super().__init__(maxsize=maxsize, ttl=ttl)
+        self.on_evict = on_evict
+
+    def __getitem__(self, key: Any) -> Any:
+        value = super().__getitem__(key)
+        super().__setitem__(key, value)
+        return value
+
+    def popitem(self) -> Any:
+        key, value = super().popitem()
+        if self.on_evict:
+            asyncio.create_task(self.on_evict(value))
+        return key, value
+
+
+class PoolManager:
+    def __init__(
+        self,
+        pool_factory: Callable[[Any], Awaitable[ConnectionPool]],
+        idle_timeout: int = 300,
+    ):
+        self.factory = pool_factory
+        self.cache = SlidingTTLCache(
+            maxsize=float("inf"),
+            ttl=idle_timeout,
+            on_evict=self._close_pool,
+        )
+
+    async def _close_pool(self, pool: ConnectionPool):
+        with contextlib.suppress(Exception):
+            await pool.stop()
+
+    async def get_pool(self, key) -> ConnectionPool:
+        try:
+            return self.cache[key]
+        except KeyError:
+            pool = await self.factory(key)
+            await pool.start()
+            self.cache[key] = pool
+            return pool
+
+    async def shutdown(self) -> None:
+        pools = list(self.cache.values())
+        for pool in pools:
+            await self._close_pool(pool)
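
ConnectionPool hands out (StreamReader, StreamWriter) pairs through an async context manager, lazily creating connections up to max_size and letting the reaper task retire pairs idle longer than idle_timeout. A minimal usage sketch (the echo-server address is illustrative):

import asyncio

from mrok.http.pool import ConnectionPool


async def main() -> None:
    async def connect():
        # The factory must return a StreamPair, i.e. (reader, writer).
        return await asyncio.open_connection("127.0.0.1", 9000)

    pool = ConnectionPool("demo", connect, initial_connections=2, max_size=10)
    await pool.start()
    async with pool.acquire() as (reader, writer):
        writer.write(b"ping\n")
        await writer.drain()
        print(await reader.readline())
    await pool.stop()


asyncio.run(main())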
mrok/http/types.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import asyncio
 from collections.abc import Awaitable, Callable, MutableMapping
-from typing import Any, Protocol
+from typing import Any
 
 from mrok.datastructures import HTTPRequest, HTTPResponse
 
@@ -15,29 +15,4 @@ ASGIApp = Callable[[Scope, ASGIReceive, ASGISend], Awaitable[None]]
 RequestCompleteCallback = Callable[[HTTPRequest], Awaitable | None]
 ResponseCompleteCallback = Callable[[HTTPResponse], Awaitable | None]
 
-
-class StreamReaderWrapper(Protocol):
-    async def read(self, n: int = -1) -> bytes: ...
-    async def readexactly(self, n: int) -> bytes: ...
-    async def readline(self) -> bytes: ...
-    def at_eof(self) -> bool: ...
-
-    @property
-    def underlying(self) -> asyncio.StreamReader: ...
-
-
-class StreamWriterWrapper(Protocol):
-    def write(self, data: bytes) -> None: ...
-    async def drain(self) -> None: ...
-    def close(self) -> None: ...
-    async def wait_closed(self) -> None: ...
-
-    @property
-    def transport(self): ...
-
-    @property
-    def underlying(self) -> asyncio.StreamWriter: ...
-
-
-StreamReader = StreamReaderWrapper | asyncio.StreamReader
-StreamWriter = StreamWriterWrapper | asyncio.StreamWriter
+StreamPair = tuple[asyncio.StreamReader, asyncio.StreamWriter]
mrok/proxy/app.py CHANGED
@@ -1,12 +1,18 @@
 import asyncio
 import logging
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
 from pathlib import Path
 
+import openziti
+from openziti.context import ZitiContext
+
 from mrok.conf import get_settings
-from mrok.http.forwarder import ForwardAppBase
-from mrok.http.types import Scope, StreamReader, StreamWriter
+from mrok.constants import RE_SUBDOMAIN
+from mrok.http.forwarder import BackendUnavailableError, ForwardAppBase, InvalidBackendError
+from mrok.http.pool import ConnectionPool, PoolManager
+from mrok.http.types import Scope, StreamPair
 from mrok.logging import setup_logging
-from mrok.proxy.ziti import ZitiSocketCache
 
 logger = logging.getLogger("mrok.proxy")
 
@@ -30,7 +36,8 @@ class ProxyApp(ForwardAppBase):
             if settings.proxy.domain[0] == "."
             else f".{settings.proxy.domain}"
         )
-        self._ziti_socket_cache = ZitiSocketCache(self._identity_file)
+        self._ziti_ctx: ZitiContext | None = None
+        self._pool_manager = PoolManager(self.build_connection_pool)
 
     def get_target_from_header(self, headers: dict[str, str], name: str) -> str | None:
         header_value = headers.get(name, "")
@@ -47,18 +54,48 @@ class ProxyApp(ForwardAppBase):
             raise ProxyError("Neither Host nor X-Forwarded-Host contain a valid target name")
         return target
 
+    def _get_ziti_ctx(self) -> ZitiContext:
+        if self._ziti_ctx is None:
+            ctx, err = openziti.load(str(self._identity_file), timeout=10_000)
+            if err != 0:
+                raise Exception(f"Cannot create a Ziti context from the identity file: {err}")
+            self._ziti_ctx = ctx
+        return self._ziti_ctx
+
     async def startup(self):
         setup_logging(get_settings())
+        self._get_ziti_ctx()
 
     async def shutdown(self):
-        await self._ziti_socket_cache.stop()
+        await self._pool_manager.shutdown()
+
+    async def build_connection_pool(self, key: str) -> ConnectionPool:
+        async def connect():
+            sock = self._get_ziti_ctx().connect(key)
+            reader, writer = await asyncio.open_connection(sock=sock)
+            return reader, writer
+
+        return ConnectionPool(
+            pool_name=key,
+            factory=connect,
+            initial_connections=5,
+            max_size=100,
+            idle_timeout=20.0,
+            reaper_interval=5.0,
+        )
 
+    @asynccontextmanager
     async def select_backend(
         self,
         scope: Scope,
         headers: dict[str, str],
-    ) -> tuple[StreamReader, StreamWriter] | tuple[None, None]:
+    ) -> AsyncGenerator[StreamPair]:
         target_name = self.get_target_name(headers)
-        sock = await self._ziti_socket_cache.get_or_create(target_name)
-        reader, writer = await asyncio.open_connection(sock=sock)
-        return reader, writer
+        if not target_name or not RE_SUBDOMAIN.fullmatch(target_name):
+            raise InvalidBackendError()
+        pool = await self._pool_manager.get_pool(target_name)
+        try:
+            async with pool.acquire() as (reader, writer):
+                yield reader, writer
+        except Exception:
+            raise BackendUnavailableError()
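
The proxy now keeps one ConnectionPool per Ziti service: PoolManager creates a pool the first time a validated subdomain is requested, and the sliding TTL cache stops pools that go unused for idle_timeout (300 s by default). A sketch of the lookup flow under those assumptions ('app' stands for an already-constructed ProxyApp; the target name is illustrative):

from mrok.http.pool import PoolManager

manager = PoolManager(app.build_connection_pool)

async def handle(target_name: str) -> None:
    pool = await manager.get_pool(target_name)  # created and started on first use
    async with pool.acquire() as (reader, writer):
        ...  # forward the HTTP exchange over the pooled Ziti connection

# e.g. await handle("EXT-1234-5678"); a later call reuses the cached pool.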
mrok-0.4.4.dist-info/METADATA → mrok-0.4.6.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mrok
-Version: 0.4.4
+Version: 0.4.6
 Summary: MPT Extensions OpenZiti Orchestrator
 Author: SoftwareOne AG
 License: Apache License
@@ -206,8 +206,8 @@ License: Apache License
 limitations under the License.
 License-File: LICENSE.txt
 Requires-Python: <4,>=3.12
-Requires-Dist: aiocache<0.13.0,>=0.12.3
 Requires-Dist: asn1crypto<2.0.0,>=1.5.1
+Requires-Dist: cachetools<7.0.0,>=6.2.2
 Requires-Dist: cryptography<46.0.0,>=45.0.7
 Requires-Dist: dynaconf<4.0.0,>=3.2.11
 Requires-Dist: fastapi-pagination<0.15.0,>=0.14.1
mrok-0.4.4.dist-info/RECORD → mrok-0.4.6.dist-info/RECORD RENAMED
@@ -1,5 +1,6 @@
 mrok/__init__.py,sha256=D1PUs3KtMCqG4bFLceVNG62L3RN53NS95uSCNXpgvzs,181
 mrok/conf.py,sha256=_5Z-A5LyojQeY8J7W8C0QidsmrPl99r9qKYEoMf4kcI,840
+mrok/constants.py,sha256=65OlmploxfND686E4mt9LR9MqYn8I5k-L0H-R5KsLG8,201
 mrok/datastructures.py,sha256=gp8KF2JoNOxIRzYStVZLKL_XVDbcIVSIDnmpQo4FNt0,4067
 mrok/errors.py,sha256=ruNMDFr2_0ezCGXuCG1OswCEv-bHOIzMMd02J_0ABcs,37
 mrok/logging.py,sha256=ZMWn0w4fJ-F_g-L37H_GM14BSXAIF2mFF_ougX5S7mg,2856
@@ -14,7 +15,7 @@ mrok/agent/devtools/inspector/__main__.py,sha256=HeYcRf1bjXPji2LKMPCcTU61afrRH2P
 mrok/agent/devtools/inspector/app.py,sha256=_pzxemMqIunE5EdMq5amjqpOGsMWIOw17GgiCtRAi6Q,16464
 mrok/agent/devtools/inspector/server.py,sha256=C4uD6_1psSHMjJLUDCMPGvKdQYKaEwYTw27NAbwuuA0,636
 mrok/agent/sidecar/__init__.py,sha256=DrjJGhqFyxsVODW06KI20Wpr6HsD2lD6qFCKUXc7GIE,59
-mrok/agent/sidecar/app.py,sha256=1p6qWkXVq78zcJ2dhCYlw8CqfwPsgEtu07Lp5csK3Iw,874
+mrok/agent/sidecar/app.py,sha256=YOQLwPPqcElbF2kU15bcw-ePzZM09eVJQZ6Z5NYg9u8,1509
 mrok/agent/sidecar/main.py,sha256=h31wynUCcFmRckvqLHtH97w1QgMv4fzcmYjhRPUobxY,1076
 mrok/cli/__init__.py,sha256=mtFEa8IeS1x6Gm4dUYoSnAxyEzNqbUVSmWxtuZUMR84,61
 mrok/cli/main.py,sha256=DFcYPwDskXi8SKAgEsuP4GMFzaniIf_6bZaSDWvYKDk,2724
@@ -28,8 +29,8 @@ mrok/cli/commands/admin/list/__init__.py,sha256=kjCMcpn1gopcrQaaHxfFh8Kyngldepnl
 mrok/cli/commands/admin/list/extensions.py,sha256=16fhDB5ucL8su2WQnSaQ1E6MhgC4vkP9-nuHAcPpzyE,4405
 mrok/cli/commands/admin/list/instances.py,sha256=kaqeyidwUxgYqfaHXqp2m76rm5h2ErBsYyZcNeaBRwY,5912
 mrok/cli/commands/admin/register/__init__.py,sha256=5Jb_bc2L47MEpQIrOcquzduTFWQ01Jd1U1MpqaR-Ekw,209
-mrok/cli/commands/admin/register/extensions.py,sha256=p1qX5gSQX1IGpOQjO2MJzbc09v1ebdFuPo94QzJErKk,1485
-mrok/cli/commands/admin/register/instances.py,sha256=XB6uAchc7Rm8uAu7o3-oHaN_rS8CCIBf0QKWZGW86fI,1940
+mrok/cli/commands/admin/register/extensions.py,sha256=dxciVA_S31rZSm0A7lkecn2mI9TMlWDhcJTgwgNXbM4,1460
+mrok/cli/commands/admin/register/instances.py,sha256=raF57jPUTryWdvNqGCosth1C-8jjv9IbA0UuNbDel3A,2220
 mrok/cli/commands/admin/unregister/__init__.py,sha256=-GjjCPX1pISbWmJK6GpKO3ijGsDQb21URjU1hNu99O4,215
 mrok/cli/commands/admin/unregister/extensions.py,sha256=GR3Iwzeksk_R0GkgmCSG7iHRcUrI7ABqDi25Gbes64Y,1016
 mrok/cli/commands/admin/unregister/instances.py,sha256=-28wL8pTXTWHVHtw93y8-dqi-Dlf0OZOnlBCKOyGo80,1138
@@ -63,17 +64,17 @@ mrok/controller/routes/instances.py,sha256=v-fn_F6JHbDZ4YUNCIZzClgHp6aC1Eu5HB7k7
 mrok/http/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 mrok/http/config.py,sha256=k73-4hBo6jag1RpyZagJLtTCL6EQoebZaX8Vv-CMN_k,2050
 mrok/http/constants.py,sha256=ao5gI2HFBWmrdd2Yc6XFK_RGaHk-omxI4AqvfIiGes8,409
-mrok/http/forwarder.py,sha256=DakD9hrrCWAzB1B_4SgQaQaEHcDYLLI9WaYs5F0O36I,12977
+mrok/http/forwarder.py,sha256=vAf2nh6Fmr07JdRJkK4dPHKJilP9PnsYZcroqsnilB8,13751
 mrok/http/lifespan.py,sha256=UdbOqjWZsHzJJjX0CTd2hY96Jpk5QWtdHJEzPG6Z4hQ,1288
 mrok/http/middlewares.py,sha256=SGo4EwhTId2uJx1aMuqGbNy7MXgZlDEdZI0buzBYVv0,5011
+mrok/http/pool.py,sha256=Q-pRwgYPusqEKQCwZsRQ2mnGaDfyWknWpvydUu5KtEU,7696
 mrok/http/protocol.py,sha256=ap8jbLUvgbAH81ZJZCBkQiYR7mkV_eL3rpfwEkoE8sU,392
 mrok/http/server.py,sha256=Mj7C85fc-DXp-WTBWaOd7ag808oliLmFBH5bf-G2FHg,370
-mrok/http/types.py,sha256=XpNrvbfpANKvmjOBYtLF1FmDHoJF3z_MIMQHXoJlvmE,1302
+mrok/http/types.py,sha256=A82zloEqW8KdKahdNrbW5fhlJNUo2enLNRVMWIJTatA,632
 mrok/http/utils.py,sha256=sOixYu3R9-nNoMFYdifrreYvcFRIHYVtb6AAmtVzaLE,2125
 mrok/proxy/__init__.py,sha256=vWXyImroqM1Eq8e_oFPBup8VJ3reyp8SVjFTbLzRkI8,51
-mrok/proxy/app.py,sha256=yulfBdTdxesVxF1h2lli_5zjd5wP-jTx17FRdbkaV7A,2163
+mrok/proxy/app.py,sha256=VvMRmYLwsItjCcecy6ccrkk564LnArIermHTRVDxh9U,3469
 mrok/proxy/main.py,sha256=ZXpticE6J4FABaslDB_8J5qklPsf3e7xIFSZmcPAAjQ,1588
-mrok/proxy/ziti.py,sha256=rKgIXpOvtBeVopZkQlNUZa3Fdci9jgiog_i6egb17ps,3318
 mrok/ziti/__init__.py,sha256=20OWMiexRhOovZOX19zlX87-V78QyWnEnSZfyAftUdE,263
 mrok/ziti/api.py,sha256=KvGiT9d4oSgC3JbFWLDQyuHcLX2HuZJoJ8nHmWtCDkY,16154
 mrok/ziti/bootstrap.py,sha256=QIDhlkIxPW2QRuumFq2D1WDbD003P5f3z24pAUsyeBI,2696
@@ -82,8 +83,8 @@ mrok/ziti/errors.py,sha256=yYCbVDwktnR0AYduqtynIjo73K3HOhIrwA_vQimvEd4,368
 mrok/ziti/identities.py,sha256=1BcwfqAJHMBhc3vRaf0aLaIkoHskj5Xe2Lsq2lO9Vs8,6735
 mrok/ziti/pki.py,sha256=o2tySqHC8-7bvFuI2Tqxg9vX6H6ZSxWxfP_9x29e19M,1954
 mrok/ziti/services.py,sha256=zR1PEBYwXVou20iJK4euh0ZZFAo9UB8PZk8f6SDmiUE,3194
-mrok-0.4.4.dist-info/METADATA,sha256=nzjalRGet1yhkJf1L4t022A-NTDG-xQ9a5cWZfbDkdg,15836
-mrok-0.4.4.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-mrok-0.4.4.dist-info/entry_points.txt,sha256=tloXwvU1uJicBJR2h-8HoVclPgwJWDwuREMHN8Zq-nU,38
-mrok-0.4.4.dist-info/licenses/LICENSE.txt,sha256=6PaICaoA3yNsZKLv5G6OKqSfLSoX7MakYqTDgJoTCBs,11346
-mrok-0.4.4.dist-info/RECORD,,
+mrok-0.4.6.dist-info/METADATA,sha256=Io64noW9WGLw9asC4xjeuLS7Wh8bFefufJmTjUK8Syo,15836
+mrok-0.4.6.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+mrok-0.4.6.dist-info/entry_points.txt,sha256=tloXwvU1uJicBJR2h-8HoVclPgwJWDwuREMHN8Zq-nU,38
+mrok-0.4.6.dist-info/licenses/LICENSE.txt,sha256=6PaICaoA3yNsZKLv5G6OKqSfLSoX7MakYqTDgJoTCBs,11346
+mrok-0.4.6.dist-info/RECORD,,
mrok/proxy/ziti.py DELETED
@@ -1,102 +0,0 @@
-import asyncio
-import contextlib
-import logging
-from asyncio import Task
-from pathlib import Path
-
-import openziti
-from aiocache import Cache
-from openziti.context import ZitiContext
-from openziti.zitisock import ZitiSocket
-
-logger = logging.getLogger("mrok.proxy")
-
-
-class ZitiSocketCache:
-    def __init__(
-        self,
-        identity_file: str | Path,
-        ziti_ctx_timeout_ms: int = 10_000,
-        ttl_seconds: float = 60.0,
-        cleanup_interval: float = 10.0,
-    ) -> None:
-        self._identity_file = identity_file
-        self._ziti_ctx_timeout_ms = ziti_ctx_timeout_ms
-        self._ttl_seconds = ttl_seconds
-        self._cleanup_interval = cleanup_interval
-
-        self._ziti_ctx: ZitiContext | None = None
-        self._cache = Cache(Cache.MEMORY)
-        self._active_sockets: dict[str, ZitiSocket] = {}
-        self._cleanup_task: Task | None = None
-
-    def _get_ziti_ctx(self) -> ZitiContext:
-        if self._ziti_ctx is None:
-            ctx, err = openziti.load(str(self._identity_file), timeout=self._ziti_ctx_timeout_ms)
-            if err != 0:
-                raise Exception(f"Cannot create a Ziti context from the identity file: {err}")
-            self._ziti_ctx = ctx
-        return self._ziti_ctx
-
-    async def _create_socket(self, key: str):
-        return self._get_ziti_ctx().connect(key)
-
-    async def get_or_create(self, key: str):
-        sock = await self._cache.get(key)
-
-        if sock:
-            await self._cache.expire(key, self._ttl_seconds)
-            self._active_sockets[key] = sock
-            logger.debug(f"Ziti socket found for service {key}")
-            return sock
-
-        sock = await self._create_socket(key)
-        await self._cache.set(key, sock, self._ttl_seconds)
-        self._active_sockets[key] = sock
-        logger.info(f"New Ziti socket created for service {key}")
-        return sock
-
-    # async def invalidate(self, key: str):
-    #     sock = await self._cache.get(key)
-    #     if sock:
-    #         await self._close_socket(sock)
-
-    #     await self._cache.delete(key)
-    #     self._active_sockets.pop(key, None)
-
-    async def start(self):
-        self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
-        # Warmup ziti context
-        self._get_ziti_ctx()
-
-    async def stop(self):
-        self._cleanup_task.cancel()
-        with contextlib.suppress(Exception):
-            await self._cleanup_task
-
-        for sock in list(self._active_sockets.values()):
-            await self._close_socket(sock)
-
-        self._active_sockets.clear()
-        await self._cache.clear()
-
-    @staticmethod
-    async def _close_socket(sock: ZitiSocket):
-        with contextlib.suppress(Exception):
-            sock.close()
-
-    async def _periodic_cleanup(self):
-        try:
-            while True:
-                await asyncio.sleep(self._cleanup_interval)
-                await self._cleanup_once()
-        except asyncio.CancelledError:
-            return
-
-    async def _cleanup_once(self):
-        expired = {key for key in self._active_sockets.keys() if not self._cache.exists(key)}
-        for key in expired:
-            logger.debug(f"Cleaning up expired socket connection {key}")
-            sock = self._active_sockets.pop(key, None)
-            if sock:
-                await self._close_socket(sock)