mrok 0.4.6__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. mrok/agent/devtools/inspector/app.py +2 -2
  2. mrok/agent/sidecar/app.py +61 -35
  3. mrok/agent/sidecar/main.py +5 -5
  4. mrok/agent/ziticorn.py +2 -2
  5. mrok/cli/commands/__init__.py +2 -2
  6. mrok/cli/commands/{proxy → frontend}/__init__.py +1 -1
  7. mrok/cli/commands/{proxy → frontend}/run.py +4 -4
  8. mrok/constants.py +0 -2
  9. mrok/controller/openapi/examples.py +13 -0
  10. mrok/frontend/__init__.py +3 -0
  11. mrok/frontend/app.py +75 -0
  12. mrok/{proxy → frontend}/main.py +4 -10
  13. mrok/proxy/__init__.py +0 -3
  14. mrok/proxy/app.py +157 -83
  15. mrok/proxy/backend.py +43 -0
  16. mrok/{http → proxy}/config.py +3 -3
  17. mrok/{datastructures.py → proxy/datastructures.py} +43 -10
  18. mrok/proxy/exceptions.py +22 -0
  19. mrok/proxy/lifespan.py +10 -0
  20. mrok/{master.py → proxy/master.py} +35 -38
  21. mrok/{metrics.py → proxy/metrics.py} +37 -49
  22. mrok/{http → proxy}/middlewares.py +47 -26
  23. mrok/proxy/streams.py +45 -0
  24. mrok/proxy/types.py +15 -0
  25. mrok/{http → proxy}/utils.py +1 -1
  26. {mrok-0.4.6.dist-info → mrok-0.5.0.dist-info}/METADATA +2 -5
  27. {mrok-0.4.6.dist-info → mrok-0.5.0.dist-info}/RECORD +33 -31
  28. mrok/http/__init__.py +0 -0
  29. mrok/http/forwarder.py +0 -354
  30. mrok/http/lifespan.py +0 -39
  31. mrok/http/pool.py +0 -239
  32. mrok/http/types.py +0 -18
  33. /mrok/{http → proxy}/constants.py +0 -0
  34. /mrok/{http → proxy}/protocol.py +0 -0
  35. /mrok/{http → proxy}/server.py +0 -0
  36. {mrok-0.4.6.dist-info → mrok-0.5.0.dist-info}/WHEEL +0 -0
  37. {mrok-0.4.6.dist-info → mrok-0.5.0.dist-info}/entry_points.txt +0 -0
  38. {mrok-0.4.6.dist-info → mrok-0.5.0.dist-info}/licenses/LICENSE.txt +0 -0
mrok/agent/devtools/inspector/app.py CHANGED
@@ -27,7 +27,7 @@ from textual.widgets.data_table import ColumnKey
27
27
  from textual.worker import get_current_worker
28
28
 
29
29
  from mrok import __version__
30
- from mrok.datastructures import Event, HTTPHeaders, HTTPResponse, Meta, WorkerMetrics
30
+ from mrok.proxy.datastructures import Event, HTTPHeaders, HTTPResponse, WorkerMetrics, ZitiMrokMeta
31
31
 
32
32
 
33
33
  def build_tree(node, data):
@@ -185,7 +185,7 @@ class InfoPanel(Static):
185
185
  # mem=int(mean([m.process.mem for m in self.workers_metrics.values()])),
186
186
  # )
187
187
 
188
- def update_meta(self, meta: Meta) -> None:
188
+ def update_meta(self, meta: ZitiMrokMeta) -> None:
189
189
  table = self.query_one(DataTable)
190
190
  if len(table.rows) == 0:
191
191
  table.add_row("URL", f"https://{meta.extension}.{meta.domain}")
mrok/agent/sidecar/app.py CHANGED
@@ -1,51 +1,77 @@
1
- import asyncio
2
1
  import logging
3
- from collections.abc import AsyncGenerator
4
- from contextlib import asynccontextmanager
5
2
  from pathlib import Path
3
+ from typing import Literal
6
4
 
7
- from mrok.http.forwarder import ForwardAppBase
8
- from mrok.http.pool import ConnectionPool
9
- from mrok.http.types import Scope, StreamPair
5
+ from httpcore import AsyncConnectionPool
6
+
7
+ from mrok.proxy.app import ProxyAppBase
8
+ from mrok.proxy.types import Scope
10
9
 
11
10
  logger = logging.getLogger("mrok.agent")
12
11
 
13
12
 
14
- class ForwardApp(ForwardAppBase):
13
+ TargetType = Literal["tcp", "unix"]
14
+
15
+
16
+ class SidecarProxyApp(ProxyAppBase):
15
17
  def __init__(
16
18
  self,
17
- target_address: str | Path | tuple[str, int],
18
- read_chunk_size: int = 65536,
19
- ) -> None:
19
+ target: str | Path | tuple[str, int],
20
+ *,
21
+ max_connections=1000,
22
+ max_keepalive_connections=10,
23
+ keepalive_expiry=120,
24
+ retries=0,
25
+ ):
26
+ self._target = target
27
+ self._target_type, self._target_address = self._parse_target()
20
28
  super().__init__(
21
- read_chunk_size=read_chunk_size,
29
+ max_connections=max_connections,
30
+ max_keepalive_connections=max_keepalive_connections,
31
+ keepalive_expiry=keepalive_expiry,
32
+ retries=retries,
22
33
  )
23
- self._target_address = target_address
24
- self._pool = ConnectionPool(
25
- pool_name=str(self._target_address),
26
- factory=self.connect,
27
- initial_connections=5,
28
- max_size=100,
29
- idle_timeout=20.0,
30
- reaper_interval=5.0,
34
+
35
+ def setup_connection_pool(
36
+ self,
37
+ max_connections: int | None = 1000,
38
+ max_keepalive_connections: int | None = 10,
39
+ keepalive_expiry: float | None = 120.0,
40
+ retries: int = 0,
41
+ ) -> AsyncConnectionPool:
42
+ if self._target_type == "unix":
43
+ return AsyncConnectionPool(
44
+ max_connections=max_connections,
45
+ max_keepalive_connections=max_keepalive_connections,
46
+ keepalive_expiry=keepalive_expiry,
47
+ retries=retries,
48
+ uds=self._target_address,
49
+ )
50
+ return AsyncConnectionPool(
51
+ max_connections=max_connections,
52
+ max_keepalive_connections=max_keepalive_connections,
53
+ keepalive_expiry=keepalive_expiry,
54
+ retries=retries,
31
55
  )
32
56
 
33
- async def connect(self) -> StreamPair:
34
- if isinstance(self._target_address, tuple):
35
- return await asyncio.open_connection(*self._target_address)
36
- return await asyncio.open_unix_connection(str(self._target_address))
57
+ def get_upstream_base_url(self, scope: Scope) -> str:
58
+ if self._target_type == "unix":
59
+ return "http://localhost"
60
+ return f"http://{self._target_address}"
37
61
 
38
- async def startup(self):
39
- await self._pool.start()
62
+ def _parse_target(self) -> tuple[TargetType, str]:
63
+ if isinstance(self._target, Path) or (
64
+ isinstance(self._target, str) and ":" not in self._target
65
+ ):
66
+ return "unix", str(self._target)
40
67
 
41
- async def shutdown(self):
42
- await self._pool.stop()
68
+ if isinstance(self._target, str) and ":" in self._target:
69
+ host, port = str(self._target).split(":", 1)
70
+ host = host or "127.0.0.1"
71
+ elif isinstance(self._target, tuple) and len(self._target) == 2:
72
+ host = self._target[0]
73
+ port = str(self._target[1])
74
+ else:
75
+ raise Exception(f"Invalid target address: {self._target}")
43
76
 
44
- @asynccontextmanager
45
- async def select_backend(
46
- self,
47
- scope: Scope,
48
- headers: dict[str, str],
49
- ) -> AsyncGenerator[StreamPair, None]:
50
- async with self._pool.acquire() as (reader, writer):
51
- yield reader, writer
77
+ return "tcp", f"{host}:{port}"
mrok/agent/sidecar/main.py CHANGED
@@ -1,8 +1,8 @@
1
1
  import logging
2
2
  from pathlib import Path
3
3
 
4
- from mrok.agent.sidecar.app import ForwardApp
5
- from mrok.master import MasterBase
4
+ from mrok.agent.sidecar.app import SidecarProxyApp
5
+ from mrok.proxy.master import MasterBase
6
6
 
7
7
  logger = logging.getLogger("mrok.proxy")
8
8
 
@@ -11,7 +11,7 @@ class SidecarAgent(MasterBase):
11
11
  def __init__(
12
12
  self,
13
13
  identity_file: str,
14
- target_addr: str | Path | tuple[str, int],
14
+ target: str | Path | tuple[str, int],
15
15
  workers: int = 4,
16
16
  publishers_port: int = 50000,
17
17
  subscribers_port: int = 50001,
@@ -23,10 +23,10 @@ class SidecarAgent(MasterBase):
23
23
  publishers_port,
24
24
  subscribers_port,
25
25
  )
26
- self.target_address = target_addr
26
+ self._target = target
27
27
 
28
28
  def get_asgi_app(self):
29
- return ForwardApp(self.target_address)
29
+ return SidecarProxyApp(self._target)
30
30
 
31
31
 
32
32
  def run(
mrok/agent/ziticorn.py CHANGED
@@ -1,5 +1,5 @@
1
- from mrok.http.types import ASGIApp
2
- from mrok.master import MasterBase
1
+ from mrok.proxy.master import MasterBase
2
+ from mrok.proxy.types import ASGIApp
3
3
 
4
4
 
5
5
  class ZiticornAgent(MasterBase):
mrok/cli/commands/__init__.py CHANGED
@@ -1,8 +1,8 @@
1
- from mrok.cli.commands import admin, agent, controller, proxy
1
+ from mrok.cli.commands import admin, agent, controller, frontend
2
2
 
3
3
  __all__ = [
4
4
  "admin",
5
5
  "agent",
6
6
  "controller",
7
- "proxy",
7
+ "frontend",
8
8
  ]
mrok/cli/commands/{proxy → frontend}/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
1
  import typer
2
2
 
3
- from mrok.cli.commands.proxy import run
3
+ from mrok.cli.commands.frontend import run
4
4
 
5
5
  app = typer.Typer(help="mrok proxy commands.")
6
6
  run.register(app)
mrok/cli/commands/{proxy → frontend}/run.py CHANGED
@@ -3,7 +3,7 @@ from typing import Annotated
3
3
 
4
4
  import typer
5
5
 
6
- from mrok import proxy
6
+ from mrok import frontend
7
7
  from mrok.cli.utils import number_of_workers
8
8
 
9
9
  default_workers = number_of_workers()
@@ -11,7 +11,7 @@ default_workers = number_of_workers()
11
11
 
12
12
  def register(app: typer.Typer) -> None:
13
13
  @app.command("run")
14
- def run_proxy(
14
+ def run_frontend(
15
15
  ctx: typer.Context,
16
16
  identity_file: Path = typer.Argument(
17
17
  ...,
@@ -45,5 +45,5 @@ def register(app: typer.Typer) -> None:
45
45
  ),
46
46
  ] = default_workers,
47
47
  ):
48
- """Run the mrok proxy with Gunicorn and Uvicorn workers."""
49
- proxy.run(identity_file, host, port, workers)
48
+ """Run the mrok frontend with Gunicorn and Uvicorn workers."""
49
+ frontend.run(identity_file, host, port, workers)
mrok/constants.py CHANGED
@@ -2,5 +2,3 @@ import re
2
2
 
3
3
  RE_EXTENSION_ID = re.compile(r"(?i)EXT-\d{4}-\d{4}")
4
4
  RE_INSTANCE_ID = re.compile(r"(?i)INS-\d{4}-\d{4}-\d{4}")
5
-
6
- RE_SUBDOMAIN = re.compile(r"(?i)^(?:EXT-\d{4}-\d{4}|INS-\d{4}-\d{4}-\d{4})$")
mrok/controller/openapi/examples.py CHANGED
@@ -13,6 +13,7 @@ INSTANCE_RESPONSE = {
13
13
  "name": "ins-1234-5678-0001.ext-1234-5678",
14
14
  "extension": {"id": "EXT-1234-5678"},
15
15
  "instance": {"id": "INS-1234-5678-0001"},
16
+ "status": "offline",
16
17
  "tags": {
17
18
  "account": "ACC-5555-3333",
18
19
  MROK_VERSION_TAG_NAME: "1.0",
@@ -25,6 +26,7 @@ INSTANCE_CREATE_RESPONSE = {
25
26
  "name": "ins-1234-5678-0001.ext-1234-5678",
26
27
  "extension": {"id": "EXT-1234-5678"},
27
28
  "instance": {"id": "INS-1234-5678-0001"},
29
+ "status": "online",
28
30
  "identity": {
29
31
  "ztAPI": "https://ziti.exts.platform.softwareone.com/edge/client/v1",
30
32
  "ztAPIs": None,
@@ -35,6 +37,17 @@ INSTANCE_CREATE_RESPONSE = {
35
37
  "ca": "pem:-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----\n",
36
38
  },
37
39
  "enableHa": None,
40
+ "mrok": {
41
+ "identity": "ins-0000-0000-0000.ext-0000-0000",
42
+ "extension": "ext-0000-0000",
43
+ "instance": "ins-0000-0000-0000",
44
+ "domain": "ext.s1.today",
45
+ "tags": {
46
+ "mrok-service": "ext-0000-0000",
47
+ "mrok-identity-type": "instance",
48
+ "mrok": "0.4.0",
49
+ },
50
+ },
38
51
  },
39
52
  "tags": {
40
53
  "account": "ACC-5555-3333",
mrok/frontend/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from mrok.frontend.main import run
2
+
3
+ __all__ = ["run"]
mrok/frontend/app.py ADDED
@@ -0,0 +1,75 @@
1
+ import re
2
+
3
+ from httpcore import AsyncConnectionPool
4
+
5
+ from mrok.conf import get_settings
6
+ from mrok.proxy.app import ProxyAppBase
7
+ from mrok.proxy.backend import AIOZitiNetworkBackend
8
+ from mrok.proxy.exceptions import InvalidTargetError
9
+ from mrok.proxy.types import Scope
10
+
11
+ RE_SUBDOMAIN = re.compile(r"(?i)^(?:EXT-\d{4}-\d{4}|INS-\d{4}-\d{4}-\d{4})$")
12
+
13
+
14
+ class FrontendProxyApp(ProxyAppBase):
15
+ def __init__(
16
+ self,
17
+ identity_file: str,
18
+ *,
19
+ max_connections: int = 1000,
20
+ max_keepalive_connections: int = 10,
21
+ keepalive_expiry: float = 120.0,
22
+ retries=0,
23
+ ):
24
+ self._identity_file = identity_file
25
+ self._proxy_domain = self._get_proxy_domain()
26
+ super().__init__(
27
+ max_connections=max_connections,
28
+ max_keepalive_connections=max_keepalive_connections,
29
+ keepalive_expiry=keepalive_expiry,
30
+ retries=retries,
31
+ )
32
+
33
+ def setup_connection_pool(
34
+ self,
35
+ max_connections: int | None = 1000,
36
+ max_keepalive_connections: int | None = 100,
37
+ keepalive_expiry: float | None = 120.0,
38
+ retries: int = 0,
39
+ ) -> AsyncConnectionPool:
40
+ return AsyncConnectionPool(
41
+ max_connections=max_connections,
42
+ max_keepalive_connections=max_keepalive_connections,
43
+ keepalive_expiry=keepalive_expiry,
44
+ retries=retries,
45
+ network_backend=AIOZitiNetworkBackend(self._identity_file),
46
+ )
47
+
48
+ def get_upstream_base_url(self, scope: Scope) -> str:
49
+ target = self._get_target_name(
50
+ {k.decode("latin1"): v.decode("latin1") for k, v in scope.get("headers", {})}
51
+ )
52
+ return f"http://{target.lower()}"
53
+
54
+ def _get_proxy_domain(self):
55
+ settings = get_settings()
56
+ return (
57
+ settings.proxy.domain
58
+ if settings.proxy.domain[0] == "."
59
+ else f".{settings.proxy.domain}"
60
+ )
61
+
62
+ def _get_target_from_header(self, headers: dict[str, str], name: str) -> str | None:
63
+ header_value = headers.get(name, "")
64
+ if self._proxy_domain in header_value:
65
+ if ":" in header_value:
66
+ header_value, _ = header_value.split(":", 1)
67
+ return header_value[: -len(self._proxy_domain)]
68
+
69
+ def _get_target_name(self, headers: dict[str, str]) -> str:
70
+ target = self._get_target_from_header(headers, "x-forwarded-host")
71
+ if not target:
72
+ target = self._get_target_from_header(headers, "host")
73
+ if not target or not RE_SUBDOMAIN.fullmatch(target):
74
+ raise InvalidTargetError()
75
+ return target
mrok/{proxy → frontend}/main.py CHANGED
@@ -6,9 +6,8 @@ from gunicorn.app.base import BaseApplication
6
6
  from uvicorn_worker import UvicornWorker
7
7
 
8
8
  from mrok.conf import get_settings
9
- from mrok.http.lifespan import LifespanWrapper
9
+ from mrok.frontend.app import FrontendProxyApp
10
10
  from mrok.logging import get_logging_config
11
- from mrok.proxy.app import ProxyApp
12
11
 
13
12
 
14
13
  class MrokUvicornWorker(UvicornWorker):
@@ -40,19 +39,14 @@ def run(
40
39
  port: int,
41
40
  workers: int,
42
41
  ):
43
- proxy_app = ProxyApp(identity_file)
42
+ app = FrontendProxyApp(str(identity_file))
44
43
 
45
- asgi_app = LifespanWrapper(
46
- proxy_app,
47
- proxy_app.startup,
48
- proxy_app.shutdown,
49
- )
50
44
  options = {
51
45
  "bind": f"{host}:{port}",
52
46
  "workers": workers,
53
- "worker_class": "mrok.proxy.main.MrokUvicornWorker",
47
+ "worker_class": "mrok.frontend.main.MrokUvicornWorker",
54
48
  "logconfig_dict": get_logging_config(get_settings()),
55
49
  "reload": False,
56
50
  }
57
51
 
58
- StandaloneApplication(asgi_app, options).run()
52
+ StandaloneApplication(app, options).run()
mrok/proxy/__init__.py CHANGED
@@ -1,3 +0,0 @@
1
- from mrok.proxy.main import run
2
-
3
- __all__ = ["run"]
mrok/proxy/app.py CHANGED
@@ -1,101 +1,175 @@
1
- import asyncio
1
+ import abc
2
2
  import logging
3
- from collections.abc import AsyncGenerator
4
- from contextlib import asynccontextmanager
5
- from pathlib import Path
6
3
 
7
- import openziti
8
- from openziti.context import ZitiContext
4
+ from httpcore import AsyncConnectionPool, Request
9
5
 
10
- from mrok.conf import get_settings
11
- from mrok.constants import RE_SUBDOMAIN
12
- from mrok.http.forwarder import BackendUnavailableError, ForwardAppBase, InvalidBackendError
13
- from mrok.http.pool import ConnectionPool, PoolManager
14
- from mrok.http.types import Scope, StreamPair
15
- from mrok.logging import setup_logging
6
+ from mrok.proxy.exceptions import ProxyError
7
+ from mrok.proxy.streams import ASGIRequestBodyStream
8
+ from mrok.proxy.types import ASGIReceive, ASGISend, Scope
16
9
 
17
10
  logger = logging.getLogger("mrok.proxy")
18
11
 
19
12
 
20
- class ProxyError(Exception):
21
- pass
13
+ HOP_BY_HOP_HEADERS = [
14
+ b"connection",
15
+ b"keep-alive",
16
+ b"proxy-authenticate",
17
+ b"proxy-authorization",
18
+ b"te",
19
+ b"trailers",
20
+ b"transfer-encoding",
21
+ b"upgrade",
22
+ ]
22
23
 
23
24
 
24
- class ProxyApp(ForwardAppBase):
25
+ class ProxyAppBase(abc.ABC):
25
26
  def __init__(
26
27
  self,
27
- identity_file: str | Path,
28
28
  *,
29
- read_chunk_size: int = 65536,
29
+ max_connections: int | None = 1000,
30
+ max_keepalive_connections: int | None = 10,
31
+ keepalive_expiry: float | None = 120.0,
32
+ retries: int = 0,
30
33
  ) -> None:
31
- super().__init__(read_chunk_size=read_chunk_size)
32
- self._identity_file = identity_file
33
- settings = get_settings()
34
- self._proxy_wildcard_domain = (
35
- settings.proxy.domain
36
- if settings.proxy.domain[0] == "."
37
- else f".{settings.proxy.domain}"
38
- )
39
- self._ziti_ctx: ZitiContext | None = None
40
- self._pool_manager = PoolManager(self.build_connection_pool)
41
-
42
- def get_target_from_header(self, headers: dict[str, str], name: str) -> str | None:
43
- header_value = headers.get(name, "")
44
- if self._proxy_wildcard_domain in header_value:
45
- if ":" in header_value:
46
- header_value, _ = header_value.split(":", 1)
47
- return header_value[: -len(self._proxy_wildcard_domain)]
48
-
49
- def get_target_name(self, headers: dict[str, str]) -> str:
50
- target = self.get_target_from_header(headers, "x-forwarded-host")
51
- if not target:
52
- target = self.get_target_from_header(headers, "host")
53
- if not target:
54
- raise ProxyError("Neither Host nor X-Forwarded-Host contain a valid target name")
55
- return target
56
-
57
- def _get_ziti_ctx(self) -> ZitiContext:
58
- if self._ziti_ctx is None:
59
- ctx, err = openziti.load(str(self._identity_file), timeout=10_000)
60
- if err != 0:
61
- raise Exception(f"Cannot create a Ziti context from the identity file: {err}")
62
- self._ziti_ctx = ctx
63
- return self._ziti_ctx
64
-
65
- async def startup(self):
66
- setup_logging(get_settings())
67
- self._get_ziti_ctx()
68
-
69
- async def shutdown(self):
70
- await self._pool_manager.shutdown()
71
-
72
- async def build_connection_pool(self, key: str) -> ConnectionPool:
73
- async def connect():
74
- sock = self._get_ziti_ctx().connect(key)
75
- reader, writer = await asyncio.open_connection(sock=sock)
76
- return reader, writer
77
-
78
- return ConnectionPool(
79
- pool_name=key,
80
- factory=connect,
81
- initial_connections=5,
82
- max_size=100,
83
- idle_timeout=20.0,
84
- reaper_interval=5.0,
34
+ self._pool = self.setup_connection_pool(
35
+ max_connections=max_connections,
36
+ max_keepalive_connections=max_keepalive_connections,
37
+ keepalive_expiry=keepalive_expiry,
38
+ retries=retries,
85
39
  )
86
40
 
87
- @asynccontextmanager
88
- async def select_backend(
41
+ @abc.abstractmethod
42
+ def setup_connection_pool(
89
43
  self,
90
- scope: Scope,
91
- headers: dict[str, str],
92
- ) -> AsyncGenerator[StreamPair]:
93
- target_name = self.get_target_name(headers)
94
- if not target_name or not RE_SUBDOMAIN.fullmatch(target_name):
95
- raise InvalidBackendError()
96
- pool = await self._pool_manager.get_pool(target_name)
44
+ max_connections: int | None = 1000,
45
+ max_keepalive_connections: int | None = 10,
46
+ keepalive_expiry: float | None = 120.0,
47
+ retries: int = 0,
48
+ ) -> AsyncConnectionPool:
49
+ raise NotImplementedError()
50
+
51
+ @abc.abstractmethod
52
+ def get_upstream_base_url(self, scope: Scope) -> str:
53
+ raise NotImplementedError()
54
+
55
+ async def __call__(self, scope: Scope, receive: ASGIReceive, send: ASGISend) -> None:
56
+ if scope.get("type") == "lifespan":
57
+ return
58
+
59
+ if scope.get("type") != "http":
60
+ await self._send_error(send, 500, "Unsupported")
61
+ return
62
+
97
63
  try:
98
- async with pool.acquire() as (reader, writer):
99
- yield reader, writer
64
+ base_url = self.get_upstream_base_url(scope)
65
+ if base_url.endswith("/"): # pragma: no cover
66
+ base_url = base_url[:-1]
67
+ full_path = self._format_path(scope)
68
+ url = f"{base_url}{full_path}"
69
+ method = scope.get("method", "GET").encode()
70
+ headers = self._prepare_headers(scope)
71
+
72
+ body_stream = ASGIRequestBodyStream(receive)
73
+
74
+ request = Request(
75
+ method=method,
76
+ url=url,
77
+ headers=headers,
78
+ content=body_stream,
79
+ )
80
+ response = await self._pool.handle_async_request(request)
81
+ response_headers = []
82
+ for k, v in response.headers:
83
+ if k.lower() not in HOP_BY_HOP_HEADERS:
84
+ response_headers.append((k, v))
85
+
86
+ await send(
87
+ {
88
+ "type": "http.response.start",
89
+ "status": response.status,
90
+ "headers": response_headers,
91
+ }
92
+ )
93
+
94
+ async for chunk in response.stream: # type: ignore[union-attr]
95
+ await send(
96
+ {
97
+ "type": "http.response.body",
98
+ "body": chunk,
99
+ "more_body": True,
100
+ }
101
+ )
102
+
103
+ await send({"type": "http.response.body", "body": b"", "more_body": False})
104
+ await response.aclose()
105
+
106
+ except ProxyError as pe:
107
+ await self._send_error(send, pe.http_status, pe.message)
108
+
100
109
  except Exception:
101
- raise BackendUnavailableError()
110
+ logger.exception("Unexpected error in forwarder")
111
+ await self._send_error(send, 502, "Bad Gateway")
112
+
113
+ async def _send_error(self, send: ASGISend, http_status: int, body: str):
114
+ try:
115
+ await send({"type": "http.response.start", "status": http_status, "headers": []})
116
+ await send({"type": "http.response.body", "body": body.encode()})
117
+ except Exception as e: # pragma: no cover
118
+ logger.error(f"Cannot send error response: {e}")
119
+
120
+ def _prepare_headers(self, scope: Scope) -> list[tuple[bytes, bytes]]:
121
+ headers: list[tuple[bytes, bytes]] = []
122
+ scope_headers = scope.get("headers", [])
123
+
124
+ for k, v in scope_headers:
125
+ if k.lower() not in HOP_BY_HOP_HEADERS:
126
+ headers.append((k, v))
127
+
128
+ self._merge_x_forwarded(headers, scope)
129
+
130
+ return headers
131
+
132
+ def _find_header(self, headers: list[tuple[bytes, bytes]], name: bytes) -> int | None:
133
+ """Return index of header `name` in `headers`, or None if missing."""
134
+ lname = name.lower()
135
+ for i, (k, _) in enumerate(headers):
136
+ if k.lower() == lname:
137
+ return i
138
+ return None
139
+
140
+ def _merge_x_forwarded(self, headers: list[tuple[bytes, bytes]], scope: Scope) -> None:
141
+ client = scope.get("client")
142
+ if client:
143
+ client_ip = client[0].encode()
144
+ idx = self._find_header(headers, b"x-forwarded-for")
145
+ if idx is None:
146
+ headers.append((b"x-forwarded-for", client_ip))
147
+ else:
148
+ k, v = headers[idx]
149
+ headers[idx] = (k, v + b", " + client_ip)
150
+
151
+ server = scope.get("server")
152
+ if server:
153
+ if self._find_header(headers, b"x-forwarded-host") is None:
154
+ headers.append((b"x-forwarded-host", server[0].encode()))
155
+ if server[1] and self._find_header(headers, b"x-forwarded-port") is None:
156
+ headers.append((b"x-forwarded-port", str(server[1]).encode()))
157
+
158
+ # Always set the protocol to https for upstream
159
+ idx_proto = self._find_header(headers, b"x-forwarded-proto")
160
+ if idx_proto is None:
161
+ headers.append((b"x-forwarded-proto", b"https"))
162
+ else:
163
+ k, _ = headers[idx_proto]
164
+ headers[idx_proto] = (k, b"https")
165
+
166
+ def _format_path(self, scope: Scope) -> str:
167
+ raw_path = scope.get("raw_path")
168
+ if raw_path:
169
+ return raw_path.decode()
170
+ q = scope.get("query_string", b"")
171
+ path = scope.get("path", "/")
172
+ path_qs = path
173
+ if q:
174
+ path_qs += "?" + q.decode()
175
+ return path_qs