mrok 0.2.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. mrok/agent/devtools/__init__.py +0 -0
  2. mrok/agent/devtools/__main__.py +34 -0
  3. mrok/agent/devtools/inspector/__init__.py +0 -0
  4. mrok/agent/devtools/inspector/__main__.py +25 -0
  5. mrok/agent/devtools/inspector/app.py +556 -0
  6. mrok/agent/devtools/inspector/server.py +18 -0
  7. mrok/agent/sidecar/app.py +9 -10
  8. mrok/agent/sidecar/main.py +35 -16
  9. mrok/agent/ziticorn.py +27 -18
  10. mrok/cli/commands/__init__.py +2 -1
  11. mrok/cli/commands/admin/list/instances.py +24 -4
  12. mrok/cli/commands/admin/register/extensions.py +2 -2
  13. mrok/cli/commands/admin/register/instances.py +3 -3
  14. mrok/cli/commands/admin/unregister/extensions.py +2 -2
  15. mrok/cli/commands/admin/unregister/instances.py +2 -2
  16. mrok/cli/commands/agent/__init__.py +2 -0
  17. mrok/cli/commands/agent/dev/__init__.py +7 -0
  18. mrok/cli/commands/agent/dev/console.py +25 -0
  19. mrok/cli/commands/agent/dev/web.py +37 -0
  20. mrok/cli/commands/agent/run/asgi.py +35 -16
  21. mrok/cli/commands/agent/run/sidecar.py +29 -13
  22. mrok/cli/commands/agent/utils.py +5 -0
  23. mrok/cli/commands/controller/run.py +1 -5
  24. mrok/cli/commands/proxy/__init__.py +6 -0
  25. mrok/cli/commands/proxy/run.py +49 -0
  26. mrok/cli/utils.py +5 -0
  27. mrok/conf.py +6 -0
  28. mrok/controller/auth.py +2 -2
  29. mrok/controller/routes/extensions.py +9 -7
  30. mrok/datastructures.py +159 -0
  31. mrok/http/config.py +3 -6
  32. mrok/http/constants.py +22 -0
  33. mrok/http/forwarder.py +62 -23
  34. mrok/http/lifespan.py +29 -0
  35. mrok/http/middlewares.py +143 -0
  36. mrok/http/types.py +43 -0
  37. mrok/http/utils.py +90 -0
  38. mrok/logging.py +22 -0
  39. mrok/master.py +269 -0
  40. mrok/metrics.py +139 -0
  41. mrok/proxy/__init__.py +3 -0
  42. mrok/proxy/app.py +73 -0
  43. mrok/proxy/dataclasses.py +12 -0
  44. mrok/proxy/main.py +58 -0
  45. mrok/proxy/streams.py +124 -0
  46. mrok/proxy/types.py +12 -0
  47. mrok/proxy/ziti.py +173 -0
  48. mrok/ziti/identities.py +50 -20
  49. mrok/ziti/services.py +8 -8
  50. {mrok-0.2.3.dist-info → mrok-0.4.0.dist-info}/METADATA +7 -1
  51. mrok-0.4.0.dist-info/RECORD +92 -0
  52. {mrok-0.2.3.dist-info → mrok-0.4.0.dist-info}/WHEEL +1 -1
  53. mrok/http/master.py +0 -132
  54. mrok-0.2.3.dist-info/RECORD +0 -66
  55. {mrok-0.2.3.dist-info → mrok-0.4.0.dist-info}/entry_points.txt +0 -0
  56. {mrok-0.2.3.dist-info → mrok-0.4.0.dist-info}/licenses/LICENSE.txt +0 -0
mrok/controller/routes/extensions.py CHANGED
@@ -14,8 +14,8 @@ from mrok.ziti.errors import (
      ServiceAlreadyRegisteredError,
      ServiceNotFoundError,
  )
- from mrok.ziti.identities import register_instance, unregister_instance
- from mrok.ziti.services import register_extension, unregister_extension
+ from mrok.ziti.identities import register_identity, unregister_identity
+ from mrok.ziti.services import register_service, unregister_service

  logger = logging.getLogger("mrok.controller")

@@ -83,7 +83,7 @@ async def create_extension(
      ],
  ):
      try:
-         service = await register_extension(settings, mgmt_api, data.extension.id, data.tags)
+         service = await register_service(settings, mgmt_api, data.extension.id, data.tags)
          return ExtensionRead(
              id=service["id"],
              name=service["name"],
@@ -149,7 +149,7 @@ async def delete_instance_by_id_or_extension_id(
      id_or_extension_id: str,
  ):
      try:
-         await unregister_extension(settings, mgmt_api, id_or_extension_id)
+         await unregister_service(settings, mgmt_api, id_or_extension_id)
      except ServiceNotFoundError:
          raise HTTPException(
              status_code=status.HTTP_404_NOT_FOUND,
@@ -203,6 +203,7 @@ async def get_extensions(
      tags=["Instances"],
  )
  async def create_extension_instances(
+     settings: AppSettings,
      mgmt_api: ZitiManagementAPI,
      client_api: ZitiClientAPI,
      id_or_extension_id: str,
@@ -223,8 +224,8 @@ async def create_extension_instances(
      ],
  ):
      service = await fetch_extension_or_404(mgmt_api, id_or_extension_id)
-     identity, identity_file = await register_instance(
-         mgmt_api, client_api, service["name"], data.instance.id, data.tags
+     identity, identity_file = await register_identity(
+         settings, mgmt_api, client_api, service["name"], data.instance.id, data.tags
      )
      return InstanceRead(
          id=identity["id"],
@@ -299,10 +300,11 @@ async def get_instance_by_id_or_instance_id(
      tags=["Instances"],
  )
  async def delete_instance_by_id_or_instance_id(
+     settings: AppSettings,
      mgmt_api: ZitiManagementAPI,
      id_or_extension_id: str,
      id_or_instance_id: str,
  ):
      identity = await fetch_instance_or_404(mgmt_api, id_or_extension_id, id_or_instance_id)
      instance_id, extension_id = identity["name"].split(".")
-     await unregister_instance(mgmt_api, extension_id, instance_id)
+     await unregister_identity(settings, mgmt_api, extension_id, instance_id)
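
Note: the controller routes now call the renamed Ziti helpers and pass the application settings through explicitly. A minimal sketch of the new call shapes, assuming `settings`, `mgmt_api` and `client_api` are resolved the same way as in the routes above; the extension and instance ids are placeholders:

    # Hypothetical values, for illustration only (not part of the diff)
    service = await register_service(settings, mgmt_api, "my-extension", {"env": "dev"})
    identity, identity_file = await register_identity(
        settings, mgmt_api, client_api, service["name"], "instance-1", None
    )
    await unregister_identity(settings, mgmt_api, "my-extension", "instance-1")
    await unregister_service(settings, mgmt_api, "my-extension")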
mrok/datastructures.py ADDED
@@ -0,0 +1,159 @@
+ from __future__ import annotations
+
+ from typing import Literal
+
+ from pydantic import BaseModel, Field
+ from pydantic_core import core_schema
+
+
+ class FixedSizeByteBuffer:
+     def __init__(self, max_size: int):
+         self._max_size = max_size
+         self._buf = bytearray()
+         self.overflow = False
+
+     def write(self, data: bytes) -> None:
+         if not data:
+             return
+
+         remaining = self._max_size - len(self._buf)
+         if remaining <= 0:
+             self.overflow = True
+             return
+
+         if len(data) > remaining:
+             self._buf.extend(data[:remaining])
+             self.overflow = True
+         else:
+             self._buf.extend(data)
+
+     def getvalue(self) -> bytes:
+         return bytes(self._buf)
+
+     def clear(self) -> None:
+         self._buf.clear()
+         self.overflow = False
+
+
+ class HTTPHeaders(dict):
+     def __init__(self, initial=None):
+         super().__init__()
+         if initial:
+             for k, v in initial.items():
+                 super().__setitem__(str(k).lower(), str(v))
+
+     def __getitem__(self, key: str) -> str:
+         return super().__getitem__(key.lower())
+
+     def __setitem__(self, key: str, value: str) -> None:
+         super().__setitem__(str(key).lower(), str(value))
+
+     def __delitem__(self, key: str) -> None:
+         super().__delitem__(key.lower())
+
+     def get(self, k, default=None):
+         return super().get(k.lower(), default)
+
+     @classmethod
+     def from_asgi(cls, items: list[tuple[bytes, bytes]]) -> HTTPHeaders:
+         d = {k.decode("latin-1"): v.decode("latin-1") for k, v in items}
+         return cls(d)
+
+     @classmethod
+     def __get_pydantic_core_schema__(cls, source, handler):
+         """Provide a pydantic-core schema so Pydantic treats this as a mapping of str->str.
+
+         We generate the schema for `dict[str, str]` using the provided handler and wrap
+         it with a validator that converts the validated dict into `HTTPHeaders`.
+         """
+         # handler may be a callable or an object with `generate_schema`; handle both
+         try:
+             dict_schema = handler.generate_schema(dict[str, str])
+         except AttributeError:
+             dict_schema = handler(dict[str, str])
+
+         def _wrap(v, validator):
+             # `validator` will validate input according to `dict_schema` and return a dict
+             validated = validator(input_value=v)
+             if isinstance(validated, HTTPHeaders):
+                 return validated
+             return cls(validated)
+
+         return core_schema.no_info_wrap_validator_function(
+             _wrap,
+             dict_schema,
+             serialization=core_schema.plain_serializer_function_ser_schema(lambda v: dict(v)),
+         )
+
+
+ class HTTPRequest(BaseModel):
+     method: str
+     url: str
+     headers: HTTPHeaders
+     query_string: bytes
+     start_time: float
+     body: bytes | None = None
+     body_truncated: bool | None = None
+
+
+ class HTTPResponse(BaseModel):
+     type: Literal["response"] = "response"
+     request: HTTPRequest
+     status: int
+     headers: HTTPHeaders
+     duration: float
+     body: bytes | None = None
+     body_truncated: bool | None = None
+
+
+ class ProcessMetrics(BaseModel):
+     cpu: float
+     mem: float
+
+
+ class DataTransferMetrics(BaseModel):
+     bytes_in: int
+     bytes_out: int
+
+
+ class RequestsMetrics(BaseModel):
+     rps: int
+     total: int
+     successful: int
+     failed: int
+
+
+ class ResponseTimeMetrics(BaseModel):
+     avg: float
+     min: int
+     max: int
+     p50: int
+     p90: int
+     p99: int
+
+
+ class WorkerMetrics(BaseModel):
+     worker_id: str
+     data_transfer: DataTransferMetrics
+     requests: RequestsMetrics
+     response_time: ResponseTimeMetrics
+     process: ProcessMetrics
+
+
+ class Meta(BaseModel):
+     identity: str
+     extension: str
+     instance: str
+     domain: str
+     tags: dict[str, str] | None = None
+
+
+ class Status(BaseModel):
+     type: Literal["status"] = "status"
+     meta: Meta
+     metrics: WorkerMetrics
+
+
+ class Event(BaseModel):
+     type: Literal["status", "response"]
+     data: Status | HTTPResponse = Field(discriminator="type")
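
Note: the new `FixedSizeByteBuffer` and `HTTPHeaders` types are consumed by the capture middleware further down in this diff. A short standalone sketch of their behaviour (values are illustrative only):

    from mrok.datastructures import FixedSizeByteBuffer, HTTPHeaders

    # The buffer keeps at most 8 bytes and flags the overflow instead of growing
    buf = FixedSizeByteBuffer(8)
    buf.write(b"hello ")
    buf.write(b"world")
    assert buf.getvalue() == b"hello wo" and buf.overflow is True

    # Header lookups are case-insensitive; keys are stored lower-cased
    headers = HTTPHeaders.from_asgi([(b"Content-Type", b"application/json")])
    assert headers["content-type"] == headers.get("Content-Type") == "application/json"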
mrok/http/config.py CHANGED
@@ -8,21 +8,18 @@ from typing import Any
  import openziti
  from uvicorn import config

- from mrok.conf import get_settings
  from mrok.http.protocol import MrokHttpToolsProtocol
- from mrok.logging import setup_logging
+ from mrok.http.types import ASGIApp

  logger = logging.getLogger("mrok.proxy")

  config.LIFESPAN["auto"] = "mrok.http.lifespan:MrokLifespan"

- ASGIApplication = config.ASGIApplication
-

  class MrokBackendConfig(config.Config):
      def __init__(
          self,
-         app: ASGIApplication | Callable[..., Any] | str,
+         app: ASGIApp | Callable[..., Any] | str,
          identity_file: str | Path,
          ziti_load_timeout_ms: int = 5000,
          backlog: int = 2048,
@@ -62,4 +59,4 @@ class MrokBackendConfig(config.Config):
          return sock

      def configure_logging(self) -> None:
-         setup_logging(get_settings())
+         return
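
Note: `MrokBackendConfig` now accepts any `ASGIApp` callable and no longer configures logging itself. A hedged construction sketch, assuming the constructor parameters omitted from this hunk keep their defaults; the app string and identity path are placeholders:

    from mrok.http.config import MrokBackendConfig

    config = MrokBackendConfig("myservice.asgi:app", identity_file="/path/to/identity.json")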
mrok/http/constants.py ADDED
@@ -0,0 +1,22 @@
+ MAX_REQUEST_BODY_BYTES = 2 * 1024 * 1024
+ MAX_RESPONSE_BODY_BYTES = 5 * 1024 * 1024
+
+ BINARY_CONTENT_TYPES = {
+     "application/octet-stream",
+     "application/pdf",
+ }
+
+ BINARY_PREFIXES = (
+     "image/",
+     "video/",
+     "audio/",
+ )
+
+ TEXTUAL_CONTENT_TYPES = {
+     "application/json",
+     "application/xml",
+     "application/javascript",
+     "application/x-www-form-urlencoded",
+ }
+
+ TEXTUAL_PREFIXES = ("text/",)
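
Note: these limits and content-type sets drive the body-capture decisions in `mrok/http/utils.py` (not shown in this diff). A hedged sketch of how a content type might be classified against them; `is_textual` here is illustrative, not the actual helper:

    from mrok.http.constants import (
        BINARY_CONTENT_TYPES,
        BINARY_PREFIXES,
        TEXTUAL_CONTENT_TYPES,
        TEXTUAL_PREFIXES,
    )

    def is_textual(content_type: str) -> bool:
        # Strip parameters such as "; charset=utf-8" and normalise case
        media_type = content_type.split(";", 1)[0].strip().lower()
        if media_type in BINARY_CONTENT_TYPES or media_type.startswith(BINARY_PREFIXES):
            return False
        return media_type in TEXTUAL_CONTENT_TYPES or media_type.startswith(TEXTUAL_PREFIXES)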
mrok/http/forwarder.py CHANGED
@@ -1,14 +1,10 @@
  import abc
  import asyncio
  import logging
- from collections.abc import Awaitable, Callable
- from typing import Any

- logger = logging.getLogger("mrok.proxy")
+ from mrok.http.types import ASGIReceive, ASGISend, Scope, StreamReader, StreamWriter

- Scope = dict[str, Any]
- ASGIReceive = Callable[[], Awaitable[dict[str, Any]]]
- ASGISend = Callable[[dict[str, Any]], Awaitable[None]]
+ logger = logging.getLogger("mrok.proxy")


  class BackendNotFoundError(Exception):
@@ -24,24 +20,71 @@ class ForwardAppBase(abc.ABC):
      and streaming logic (requests and responses).
      """

-     def __init__(self, read_chunk_size: int = 65536) -> None:
-         # number of bytes to read per iteration when streaming bodies
+     def __init__(
+         self,
+         read_chunk_size: int = 65536,
+         lifespan_timeout: float = 10.0,
+     ) -> None:
          self._read_chunk_size: int = int(read_chunk_size)
+         self._lifespan_timeout = lifespan_timeout
+
+     async def handle_lifespan(self, receive: ASGIReceive, send: ASGISend) -> None:
+         while True:
+             event = await receive()
+             etype = event.get("type")
+
+             if etype == "lifespan.startup":
+                 try:
+                     await asyncio.wait_for(self.startup(), self._lifespan_timeout)
+                 except TimeoutError:
+                     logger.exception("Lifespan startup timed out")
+                     await send({"type": "lifespan.startup.failed", "message": "startup timeout"})
+                     continue
+                 except Exception as e:
+                     logger.exception("Exception during lifespan startup")
+                     await send({"type": "lifespan.startup.failed", "message": str(e)})
+                     continue
+                 await send({"type": "lifespan.startup.complete"})
+
+             elif etype == "lifespan.shutdown":
+                 try:
+                     await asyncio.wait_for(self.shutdown(), self._lifespan_timeout)
+                 except TimeoutError:
+                     logger.exception("Lifespan shutdown timed out")
+                     await send({"type": "lifespan.shutdown.failed", "message": "shutdown timeout"})
+                     return
+                 except Exception as exc:
+                     logger.exception("Exception during lifespan shutdown")
+                     await send({"type": "lifespan.shutdown.failed", "message": str(exc)})
+                     return
+                 await send({"type": "lifespan.shutdown.complete"})
+                 return

      @abc.abstractmethod
      async def select_backend(
          self,
          scope: Scope,
          headers: dict[str, str],
-     ) -> tuple[asyncio.StreamReader, asyncio.StreamWriter] | tuple[None, None]:
+     ) -> tuple[StreamReader, StreamWriter] | tuple[None, None]:
          """Return (reader, writer) connected to the target backend."""

+     async def startup(self):
+         return
+
+     async def shutdown(self):
+         return
+
      async def __call__(self, scope: Scope, receive: ASGIReceive, send: ASGISend) -> None:
          """ASGI callable entry point.

          Delegates to smaller helper methods for readability. Subclasses only
          need to implement backend selection.
          """
+         scope_type = scope.get("type")
+         if scope_type == "lifespan":
+             await self.handle_lifespan(receive, send)
+             return
+
          if scope.get("type") != "http":
              await send({"type": "http.response.start", "status": 500, "headers": []})
              await send({"type": "http.response.body", "body": b"Unsupported"})
@@ -116,7 +159,7 @@

      async def write_request_line_and_headers(
          self,
-         writer: asyncio.StreamWriter,
+         writer: StreamWriter,
          method: str,
          path_qs: str,
          headers: list[tuple[bytes, bytes]],
@@ -130,7 +173,7 @@ class ForwardAppBase(abc.ABC):
          await writer.drain()

      async def stream_request_body(
-         self, receive: ASGIReceive, writer: asyncio.StreamWriter, use_chunked: bool
+         self, receive: ASGIReceive, writer: StreamWriter, use_chunked: bool
      ) -> None:
          if use_chunked:
              await self.stream_request_chunked(receive, writer)
@@ -138,9 +181,7 @@ class ForwardAppBase(abc.ABC):

          await self.stream_request_until_end(receive, writer)

-     async def stream_request_chunked(
-         self, receive: ASGIReceive, writer: asyncio.StreamWriter
-     ) -> None:
+     async def stream_request_chunked(self, receive: ASGIReceive, writer: StreamWriter) -> None:
          """Send request body to backend using HTTP/1.1 chunked encoding."""
          while True:
              event = await receive()
@@ -160,9 +201,7 @@ class ForwardAppBase(abc.ABC):
          writer.write(b"0\r\n\r\n")
          await writer.drain()

-     async def stream_request_until_end(
-         self, receive: ASGIReceive, writer: asyncio.StreamWriter
-     ) -> None:
+     async def stream_request_until_end(self, receive: ASGIReceive, writer: StreamWriter) -> None:
          """Send request body to backend when content length/transfer-encoding
          already provided (no chunking).
          """
@@ -180,7 +219,7 @@ class ForwardAppBase(abc.ABC):
                  return

      async def read_status_and_headers(
-         self, reader: asyncio.StreamReader, first_line: bytes
+         self, reader: StreamReader, first_line: bytes
      ) -> tuple[int, list[tuple[bytes, bytes]], dict[bytes, bytes]]:
          parts = first_line.decode(errors="ignore").split(" ", 2)
          status = int(parts[1]) if len(parts) >= 2 and parts[1].isdigit() else 502
@@ -217,14 +256,14 @@ class ForwardAppBase(abc.ABC):
          except Exception:
              return None

-     async def drain_trailers(self, reader: asyncio.StreamReader) -> None:
+     async def drain_trailers(self, reader: StreamReader) -> None:
          """Consume trailer header lines until an empty line is encountered."""
          while True:
              trailer = await reader.readline()
              if trailer in (b"\r\n", b"\n", b""):
                  break

-     async def stream_response_chunked(self, reader: asyncio.StreamReader, send: ASGISend) -> None:
+     async def stream_response_chunked(self, reader: StreamReader, send: ASGISend) -> None:
          """Read chunked-encoded response from reader, decode and forward to ASGI send."""
          while True:
              size_line = await reader.readline()
@@ -253,7 +292,7 @@ class ForwardAppBase(abc.ABC):
          await send({"type": "http.response.body", "body": b"", "more_body": False})

      async def stream_response_with_content_length(
-         self, reader: asyncio.StreamReader, send: ASGISend, content_length: int
+         self, reader: StreamReader, send: ASGISend, content_length: int
      ) -> None:
          """Read exactly content_length bytes and forward to ASGI send events."""
          remaining = content_length
@@ -272,7 +311,7 @@ class ForwardAppBase(abc.ABC):
          if not sent_final:
              await send({"type": "http.response.body", "body": b"", "more_body": False})

-     async def stream_response_until_eof(self, reader: asyncio.StreamReader, send: ASGISend) -> None:
+     async def stream_response_until_eof(self, reader: StreamReader, send: ASGISend) -> None:
          """Read from reader until EOF and forward chunks to ASGI send events."""
          while True:
              chunk = await reader.read(self._read_chunk_size)
@@ -282,7 +321,7 @@ class ForwardAppBase(abc.ABC):
          await send({"type": "http.response.body", "body": b"", "more_body": False})

      async def stream_response_body(
-         self, reader: asyncio.StreamReader, send: ASGISend, raw_headers: dict[bytes, bytes]
+         self, reader: StreamReader, send: ASGISend, raw_headers: dict[bytes, bytes]
      ) -> None:
          te = raw_headers.get(b"transfer-encoding", b"").lower()
          cl = raw_headers.get(b"content-length")
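
Note: `ForwardAppBase` now answers ASGI lifespan events itself and exposes overridable `startup()`/`shutdown()` hooks; only `select_backend()` remains abstract. A minimal subclass sketch assuming a plain TCP backend; the host and port are placeholders:

    import asyncio

    from mrok.http.forwarder import ForwardAppBase
    from mrok.http.types import Scope, StreamReader, StreamWriter

    class TcpForwardApp(ForwardAppBase):
        """Forwards every request to a fixed host:port backend."""

        def __init__(self, host: str, port: int, **kwargs) -> None:
            super().__init__(**kwargs)
            self._host = host
            self._port = port

        async def select_backend(
            self, scope: Scope, headers: dict[str, str]
        ) -> tuple[StreamReader, StreamWriter] | tuple[None, None]:
            try:
                return await asyncio.open_connection(self._host, self._port)
            except OSError:
                # The signature allows (None, None) when no backend is reachable
                return None, None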
mrok/http/lifespan.py CHANGED
@@ -1,10 +1,39 @@
  import logging
+ from collections.abc import Awaitable, Callable

  from uvicorn.config import Config
  from uvicorn.lifespan.on import LifespanOn

+ AsyncCallback = Callable[[], Awaitable[None]]
+

  class MrokLifespan(LifespanOn):
      def __init__(self, config: Config) -> None:
          super().__init__(config)
          self.logger = logging.getLogger("mrok.proxy")
+
+
+ class LifespanWrapper:
+     def __init__(
+         self, app, on_startup: AsyncCallback | None = None, on_shutdown: AsyncCallback | None = None
+     ):
+         self.app = app
+         self.on_startup = on_startup
+         self.on_shutdown = on_shutdown
+
+     async def __call__(self, scope, receive, send):
+         if scope["type"] == "lifespan":
+             while True:
+                 event = await receive()
+                 if event["type"] == "lifespan.startup":
+                     if self.on_startup:
+                         await self.on_startup()
+                     await send({"type": "lifespan.startup.complete"})
+
+                 elif event["type"] == "lifespan.shutdown":
+                     if self.on_shutdown:
+                         await self.on_shutdown()
+                     await send({"type": "lifespan.shutdown.complete"})
+                     break
+         else:
+             await self.app(scope, receive, send)
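
Note: `LifespanWrapper` lets a bare ASGI app gain startup/shutdown hooks without implementing the lifespan protocol itself. A small usage sketch; the wrapped `app` and both callbacks are placeholders:

    from mrok.http.lifespan import LifespanWrapper

    async def open_resources() -> None:
        print("starting up")

    async def close_resources() -> None:
        print("shutting down")

    async def app(scope, receive, send):
        # minimal ASGI app used only for illustration
        await send({"type": "http.response.start", "status": 200, "headers": []})
        await send({"type": "http.response.body", "body": b"ok"})

    wrapped = LifespanWrapper(app, on_startup=open_resources, on_shutdown=close_resources)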
mrok/http/middlewares.py ADDED
@@ -0,0 +1,143 @@
+ import asyncio
+ import inspect
+ import logging
+ import time
+ from collections.abc import Callable, Coroutine
+ from typing import Any
+
+ from mrok.datastructures import FixedSizeByteBuffer, HTTPHeaders, HTTPRequest, HTTPResponse
+ from mrok.http.constants import MAX_REQUEST_BODY_BYTES, MAX_RESPONSE_BODY_BYTES
+ from mrok.http.types import ASGIApp, ASGIReceive, ASGISend, Message, Scope
+ from mrok.http.utils import must_capture_request, must_capture_response
+ from mrok.metrics import WorkerMetricsCollector
+
+ logger = logging.getLogger("mrok.proxy")
+
+ ResponseCompleteCallback = Callable[[HTTPResponse], Coroutine[Any, Any, None] | None]
+
+
+ class CaptureMiddleware:
+     def __init__(
+         self,
+         app: ASGIApp,
+         on_response_complete: ResponseCompleteCallback,
+     ):
+         self.app = app
+         self._on_response_complete = on_response_complete
+
+     async def __call__(self, scope: Scope, receive: ASGIReceive, send: ASGISend):
+         if scope["type"] != "http":
+             await self.app(scope, receive, send)
+             return
+
+         start_time = time.time()
+         method = scope["method"]
+         path = scope["path"]
+         query_string = scope.get("query_string", b"")
+         req_headers_raw = scope.get("headers", [])
+         req_headers = HTTPHeaders.from_asgi(req_headers_raw)
+
+         state = {}
+
+         req_buf = FixedSizeByteBuffer(MAX_REQUEST_BODY_BYTES)
+         capture_req_body = must_capture_request(method, req_headers)
+
+         request = HTTPRequest(
+             method=method,
+             url=path,
+             headers=req_headers,
+             query_string=query_string,
+             start_time=start_time,
+         )
+
+         # Response capture
+         resp_buf = FixedSizeByteBuffer(MAX_RESPONSE_BODY_BYTES)
+
+         async def receive_wrapper() -> Message:
+             msg = await receive()
+             if capture_req_body and msg["type"] == "http.request":
+                 body = msg.get("body", b"")
+                 req_buf.write(body)
+             return msg
+
+         async def send_wrapper(msg: Message):
+             if msg["type"] == "http.response.start":
+                 state["status"] = msg["status"]
+                 resp_headers = HTTPHeaders.from_asgi(msg.get("headers", []))
+                 state["resp_headers_raw"] = resp_headers
+
+                 state["capture_resp_body"] = must_capture_response(resp_headers)
+
+             if state["capture_resp_body"] and msg["type"] == "http.response.body":
+                 body = msg.get("body", b"")
+                 resp_buf.write(body)
+
+             await send(msg)
+
+         await self.app(scope, receive_wrapper, send_wrapper)
+
+         # Finalise request
+         request.body = req_buf.getvalue() if capture_req_body else None
+         request.body_truncated = req_buf.overflow if capture_req_body else None
+
+         # Finalise response
+         end_time = time.time()
+         duration = end_time - start_time
+
+         response = HTTPResponse(
+             request=request,
+             status=state["status"] or 0,
+             headers=state["resp_headers_raw"],
+             duration=duration,
+             body=resp_buf.getvalue() if state["capture_resp_body"] else None,
+             body_truncated=resp_buf.overflow if state["capture_resp_body"] else None,
+         )
+         await asyncio.create_task(self.handle_callback(response))
+
+     async def handle_callback(self, response: HTTPResponse):
+         try:
+             if inspect.iscoroutinefunction(self._on_response_complete):
+                 await self._on_response_complete(response)
+             else:
+                 await asyncio.get_running_loop().run_in_executor(
+                     None, self._on_response_complete, response
+                 )
+         except Exception:
+             logger.exception("Error invoking callback")
+
+
+ class MetricsMiddleware:
+     def __init__(self, app, metrics: WorkerMetricsCollector):
+         self.app = app
+         self.metrics = metrics
+
+     async def __call__(self, scope, receive, send):
+         if scope["type"] != "http":
+             return await self.app(scope, receive, send)
+
+         start_time = await self.metrics.on_request_start(scope)
+         status_code = 500  # default if app errors early
+
+         async def wrapped_receive():
+             msg = await receive()
+             if msg["type"] == "http.request" and msg.get("body"):
+                 await self.metrics.on_request_body(len(msg["body"]))
+             return msg
+
+         async def wrapped_send(msg):
+             nonlocal status_code
+
+             if msg["type"] == "http.response.start":
+                 status_code = msg["status"]
+                 await self.metrics.on_response_start(status_code)
+
+             elif msg["type"] == "http.response.body":
+                 body = msg.get("body", b"")
+                 await self.metrics.on_response_chunk(len(body))
+
+             return await send(msg)
+
+         try:
+             await self.app(scope, wrapped_receive, wrapped_send)
+         finally:
+             await self.metrics.on_request_end(start_time, status_code)
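
Note: both middlewares wrap any ASGI app. `CaptureMiddleware` buffers request and response bodies up to the configured limits and hands the completed `HTTPResponse` to a callback, while `MetricsMiddleware` feeds a `WorkerMetricsCollector`. A composition sketch, assuming `WorkerMetricsCollector()` can be constructed without arguments (its constructor is not part of this diff); the inner `app` and callback are placeholders:

    from mrok.datastructures import HTTPResponse
    from mrok.http.middlewares import CaptureMiddleware, MetricsMiddleware
    from mrok.metrics import WorkerMetricsCollector

    async def on_response_complete(response: HTTPResponse) -> None:
        # e.g. push the captured exchange to the devtools inspector
        print(response.status, response.request.method, response.request.url)

    async def app(scope, receive, send):
        # placeholder ASGI app
        await send({"type": "http.response.start", "status": 200, "headers": []})
        await send({"type": "http.response.body", "body": b"ok"})

    metrics = WorkerMetricsCollector()  # assumption: no-arg constructor
    wrapped = MetricsMiddleware(CaptureMiddleware(app, on_response_complete), metrics)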
mrok/http/types.py ADDED
@@ -0,0 +1,43 @@
+ from __future__ import annotations
+
+ import asyncio
+ from collections.abc import Awaitable, Callable, MutableMapping
+ from typing import Any, Protocol
+
+ from mrok.datastructures import HTTPRequest, HTTPResponse
+
+ Scope = MutableMapping[str, Any]
+ Message = MutableMapping[str, Any]
+
+ ASGIReceive = Callable[[], Awaitable[Message]]
+ ASGISend = Callable[[Message], Awaitable[None]]
+ ASGIApp = Callable[[Scope, ASGIReceive, ASGISend], Awaitable[None]]
+ RequestCompleteCallback = Callable[[HTTPRequest], Awaitable | None]
+ ResponseCompleteCallback = Callable[[HTTPResponse], Awaitable | None]
+
+
+ class StreamReaderWrapper(Protocol):
+     async def read(self, n: int = -1) -> bytes: ...
+     async def readexactly(self, n: int) -> bytes: ...
+     async def readline(self) -> bytes: ...
+     def at_eof(self) -> bool: ...
+
+     @property
+     def underlying(self) -> asyncio.StreamReader: ...
+
+
+ class StreamWriterWrapper(Protocol):
+     def write(self, data: bytes) -> None: ...
+     async def drain(self) -> None: ...
+     def close(self) -> None: ...
+     async def wait_closed(self) -> None: ...
+
+     @property
+     def transport(self): ...
+
+     @property
+     def underlying(self) -> asyncio.StreamWriter: ...
+
+
+ StreamReader = StreamReaderWrapper | asyncio.StreamReader
+ StreamWriter = StreamWriterWrapper | asyncio.StreamWriter
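
Note: the `StreamReader`/`StreamWriter` aliases accept either raw asyncio streams or wrapper objects that satisfy the protocols above. A sketch of a byte-counting writer wrapper, illustrative only:

    import asyncio

    class CountingStreamWriter:
        """Satisfies StreamWriterWrapper while tracking bytes written."""

        def __init__(self, writer: asyncio.StreamWriter) -> None:
            self._writer = writer
            self.bytes_written = 0

        def write(self, data: bytes) -> None:
            self.bytes_written += len(data)
            self._writer.write(data)

        async def drain(self) -> None:
            await self._writer.drain()

        def close(self) -> None:
            self._writer.close()

        async def wait_closed(self) -> None:
            await self._writer.wait_closed()

        @property
        def transport(self):
            return self._writer.transport

        @property
        def underlying(self) -> asyncio.StreamWriter:
            return self._writer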