mrok 0.4.5__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. mrok/agent/devtools/inspector/app.py +2 -2
  2. mrok/agent/sidecar/app.py +64 -16
  3. mrok/agent/sidecar/main.py +5 -5
  4. mrok/agent/ziticorn.py +2 -2
  5. mrok/cli/commands/__init__.py +2 -2
  6. mrok/cli/commands/admin/register/extensions.py +2 -4
  7. mrok/cli/commands/admin/register/instances.py +11 -5
  8. mrok/cli/commands/{proxy → frontend}/__init__.py +1 -1
  9. mrok/cli/commands/{proxy → frontend}/run.py +4 -4
  10. mrok/constants.py +4 -0
  11. mrok/controller/openapi/examples.py +13 -0
  12. mrok/frontend/__init__.py +3 -0
  13. mrok/frontend/app.py +75 -0
  14. mrok/{proxy → frontend}/main.py +4 -10
  15. mrok/proxy/__init__.py +0 -3
  16. mrok/proxy/app.py +160 -57
  17. mrok/proxy/backend.py +43 -0
  18. mrok/{http → proxy}/config.py +3 -3
  19. mrok/{datastructures.py → proxy/datastructures.py} +43 -10
  20. mrok/proxy/exceptions.py +22 -0
  21. mrok/proxy/lifespan.py +10 -0
  22. mrok/{master.py → proxy/master.py} +35 -38
  23. mrok/{metrics.py → proxy/metrics.py} +37 -49
  24. mrok/{http → proxy}/middlewares.py +47 -26
  25. mrok/proxy/streams.py +45 -0
  26. mrok/proxy/types.py +15 -0
  27. mrok/{http → proxy}/utils.py +1 -1
  28. {mrok-0.4.5.dist-info → mrok-0.5.0.dist-info}/METADATA +2 -5
  29. {mrok-0.4.5.dist-info → mrok-0.5.0.dist-info}/RECORD +35 -31
  30. mrok/http/__init__.py +0 -0
  31. mrok/http/forwarder.py +0 -338
  32. mrok/http/lifespan.py +0 -39
  33. mrok/http/types.py +0 -43
  34. /mrok/{http → proxy}/constants.py +0 -0
  35. /mrok/{http → proxy}/protocol.py +0 -0
  36. /mrok/{http → proxy}/server.py +0 -0
  37. {mrok-0.4.5.dist-info → mrok-0.5.0.dist-info}/WHEEL +0 -0
  38. {mrok-0.4.5.dist-info → mrok-0.5.0.dist-info}/entry_points.txt +0 -0
  39. {mrok-0.4.5.dist-info → mrok-0.5.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -27,7 +27,7 @@ from textual.widgets.data_table import ColumnKey
27
27
  from textual.worker import get_current_worker
28
28
 
29
29
  from mrok import __version__
30
- from mrok.datastructures import Event, HTTPHeaders, HTTPResponse, Meta, WorkerMetrics
30
+ from mrok.proxy.datastructures import Event, HTTPHeaders, HTTPResponse, WorkerMetrics, ZitiMrokMeta
31
31
 
32
32
 
33
33
  def build_tree(node, data):
@@ -185,7 +185,7 @@ class InfoPanel(Static):
185
185
  # mem=int(mean([m.process.mem for m in self.workers_metrics.values()])),
186
186
  # )
187
187
 
188
- def update_meta(self, meta: Meta) -> None:
188
+ def update_meta(self, meta: ZitiMrokMeta) -> None:
189
189
  table = self.query_one(DataTable)
190
190
  if len(table.rows) == 0:
191
191
  table.add_row("URL", f"https://{meta.extension}.{meta.domain}")
mrok/agent/sidecar/app.py CHANGED
@@ -1,29 +1,77 @@
1
- import asyncio
2
1
  import logging
3
2
  from pathlib import Path
3
+ from typing import Literal
4
4
 
5
- from mrok.http.forwarder import ForwardAppBase
6
- from mrok.http.types import Scope, StreamReader, StreamWriter
5
+ from httpcore import AsyncConnectionPool
6
+
7
+ from mrok.proxy.app import ProxyAppBase
8
+ from mrok.proxy.types import Scope
7
9
 
8
10
  logger = logging.getLogger("mrok.agent")
9
11
 
10
12
 
11
- class ForwardApp(ForwardAppBase):
13
+ TargetType = Literal["tcp", "unix"]
14
+
15
+
16
+ class SidecarProxyApp(ProxyAppBase):
12
17
  def __init__(
13
18
  self,
14
- target_address: str | Path | tuple[str, int],
15
- read_chunk_size: int = 65536,
16
- ) -> None:
19
+ target: str | Path | tuple[str, int],
20
+ *,
21
+ max_connections=1000,
22
+ max_keepalive_connections=10,
23
+ keepalive_expiry=120,
24
+ retries=0,
25
+ ):
26
+ self._target = target
27
+ self._target_type, self._target_address = self._parse_target()
17
28
  super().__init__(
18
- read_chunk_size=read_chunk_size,
29
+ max_connections=max_connections,
30
+ max_keepalive_connections=max_keepalive_connections,
31
+ keepalive_expiry=keepalive_expiry,
32
+ retries=retries,
19
33
  )
20
- self._target_address = target_address
21
34
 
22
- async def select_backend(
35
+ def setup_connection_pool(
23
36
  self,
24
- scope: Scope,
25
- headers: dict[str, str],
26
- ) -> tuple[StreamReader, StreamWriter] | tuple[None, None]:
27
- if isinstance(self._target_address, tuple):
28
- return await asyncio.open_connection(*self._target_address)
29
- return await asyncio.open_unix_connection(str(self._target_address))
37
+ max_connections: int | None = 1000,
38
+ max_keepalive_connections: int | None = 10,
39
+ keepalive_expiry: float | None = 120.0,
40
+ retries: int = 0,
41
+ ) -> AsyncConnectionPool:
42
+ if self._target_type == "unix":
43
+ return AsyncConnectionPool(
44
+ max_connections=max_connections,
45
+ max_keepalive_connections=max_keepalive_connections,
46
+ keepalive_expiry=keepalive_expiry,
47
+ retries=retries,
48
+ uds=self._target_address,
49
+ )
50
+ return AsyncConnectionPool(
51
+ max_connections=max_connections,
52
+ max_keepalive_connections=max_keepalive_connections,
53
+ keepalive_expiry=keepalive_expiry,
54
+ retries=retries,
55
+ )
56
+
57
+ def get_upstream_base_url(self, scope: Scope) -> str:
58
+ if self._target_type == "unix":
59
+ return "http://localhost"
60
+ return f"http://{self._target_address}"
61
+
62
+ def _parse_target(self) -> tuple[TargetType, str]:
63
+ if isinstance(self._target, Path) or (
64
+ isinstance(self._target, str) and ":" not in self._target
65
+ ):
66
+ return "unix", str(self._target)
67
+
68
+ if isinstance(self._target, str) and ":" in self._target:
69
+ host, port = str(self._target).split(":", 1)
70
+ host = host or "127.0.0.1"
71
+ elif isinstance(self._target, tuple) and len(self._target) == 2:
72
+ host = self._target[0]
73
+ port = str(self._target[1])
74
+ else:
75
+ raise Exception(f"Invalid target address: {self._target}")
76
+
77
+ return "tcp", f"{host}:{port}"
@@ -1,8 +1,8 @@
1
1
  import logging
2
2
  from pathlib import Path
3
3
 
4
- from mrok.agent.sidecar.app import ForwardApp
5
- from mrok.master import MasterBase
4
+ from mrok.agent.sidecar.app import SidecarProxyApp
5
+ from mrok.proxy.master import MasterBase
6
6
 
7
7
  logger = logging.getLogger("mrok.proxy")
8
8
 
@@ -11,7 +11,7 @@ class SidecarAgent(MasterBase):
11
11
  def __init__(
12
12
  self,
13
13
  identity_file: str,
14
- target_addr: str | Path | tuple[str, int],
14
+ target: str | Path | tuple[str, int],
15
15
  workers: int = 4,
16
16
  publishers_port: int = 50000,
17
17
  subscribers_port: int = 50001,
@@ -23,10 +23,10 @@ class SidecarAgent(MasterBase):
23
23
  publishers_port,
24
24
  subscribers_port,
25
25
  )
26
- self.target_address = target_addr
26
+ self._target = target
27
27
 
28
28
  def get_asgi_app(self):
29
- return ForwardApp(self.target_address)
29
+ return SidecarProxyApp(self._target)
30
30
 
31
31
 
32
32
  def run(
mrok/agent/ziticorn.py CHANGED
@@ -1,5 +1,5 @@
1
- from mrok.http.types import ASGIApp
2
- from mrok.master import MasterBase
1
+ from mrok.proxy.master import MasterBase
2
+ from mrok.proxy.types import ASGIApp
3
3
 
4
4
 
5
5
  class ZiticornAgent(MasterBase):
@@ -1,8 +1,8 @@
1
- from mrok.cli.commands import admin, agent, controller, proxy
1
+ from mrok.cli.commands import admin, agent, controller, frontend
2
2
 
3
3
  __all__ = [
4
4
  "admin",
5
5
  "agent",
6
6
  "controller",
7
- "proxy",
7
+ "frontend",
8
8
  ]
@@ -1,5 +1,4 @@
1
1
  import asyncio
2
- import re
3
2
  from typing import Annotated
4
3
 
5
4
  import typer
@@ -7,11 +6,10 @@ from rich import print
7
6
 
8
7
  from mrok.cli.commands.admin.utils import parse_tags
9
8
  from mrok.conf import Settings
9
+ from mrok.constants import RE_EXTENSION_ID
10
10
  from mrok.ziti.api import ZitiManagementAPI
11
11
  from mrok.ziti.services import register_service
12
12
 
13
- RE_EXTENSION_ID = re.compile(r"(?i)EXT-\d{4}-\d{4}")
14
-
15
13
 
16
14
  async def do_register(settings: Settings, extension_id: str, tags: list[str] | None):
17
15
  async with ZitiManagementAPI(settings) as api:
@@ -20,7 +18,7 @@ async def do_register(settings: Settings, extension_id: str, tags: list[str] | N
20
18
 
21
19
  def validate_extension_id(extension_id: str) -> str:
22
20
  if not RE_EXTENSION_ID.fullmatch(extension_id):
23
- raise typer.BadParameter("ext_id must match EXT-xxxx-yyyy (case-insensitive)")
21
+ raise typer.BadParameter("it must match EXT-xxxx-yyyy (case-insensitive)")
24
22
  return extension_id
25
23
 
26
24
 
@@ -1,6 +1,5 @@
1
1
  import asyncio
2
2
  import json
3
- import re
4
3
  from pathlib import Path
5
4
  from typing import Annotated
6
5
 
@@ -8,11 +7,10 @@ import typer
8
7
 
9
8
  from mrok.cli.commands.admin.utils import parse_tags
10
9
  from mrok.conf import Settings
10
+ from mrok.constants import RE_EXTENSION_ID, RE_INSTANCE_ID
11
11
  from mrok.ziti.api import ZitiClientAPI, ZitiManagementAPI
12
12
  from mrok.ziti.identities import register_identity
13
13
 
14
- RE_EXTENSION_ID = re.compile(r"(?i)EXT-\d{4}-\d{4}")
15
-
16
14
 
17
15
  async def do_register(
18
16
  settings: Settings, extension_id: str, instance_id: str, tags: list[str] | None
@@ -25,10 +23,16 @@ async def do_register(
25
23
 
26
24
  def validate_extension_id(extension_id: str):
27
25
  if not RE_EXTENSION_ID.fullmatch(extension_id):
28
- raise typer.BadParameter("ext_id must match EXT-xxxx-yyyy (case-insensitive)")
26
+ raise typer.BadParameter("it must match EXT-xxxx-yyyy (case-insensitive)")
29
27
  return extension_id
30
28
 
31
29
 
30
+ def validate_instance_id(instance_id: str):
31
+ if not RE_INSTANCE_ID.fullmatch(instance_id):
32
+ raise typer.BadParameter("it must match INS-xxxx-yyyy-zzzz (case-insensitive)")
33
+ return instance_id
34
+
35
+
32
36
  def register(app: typer.Typer) -> None:
33
37
  @app.command("instance")
34
38
  def register_instance(
@@ -36,7 +40,9 @@ def register(app: typer.Typer) -> None:
36
40
  extension_id: str = typer.Argument(
37
41
  ..., callback=validate_extension_id, help="Extension ID in format EXT-xxxx-yyyy"
38
42
  ),
39
- instance_id: str = typer.Argument(..., help="Instance ID"),
43
+ instance_id: str = typer.Argument(
44
+ ..., callback=validate_instance_id, help="Instance ID in format INS-xxxx-yyyy-zzzz"
45
+ ),
40
46
  output: Path = typer.Argument(
41
47
  ...,
42
48
  file_okay=True,
@@ -1,6 +1,6 @@
1
1
  import typer
2
2
 
3
- from mrok.cli.commands.proxy import run
3
+ from mrok.cli.commands.frontend import run
4
4
 
5
5
  app = typer.Typer(help="mrok proxy commands.")
6
6
  run.register(app)
@@ -3,7 +3,7 @@ from typing import Annotated
3
3
 
4
4
  import typer
5
5
 
6
- from mrok import proxy
6
+ from mrok import frontend
7
7
  from mrok.cli.utils import number_of_workers
8
8
 
9
9
  default_workers = number_of_workers()
@@ -11,7 +11,7 @@ default_workers = number_of_workers()
11
11
 
12
12
  def register(app: typer.Typer) -> None:
13
13
  @app.command("run")
14
- def run_proxy(
14
+ def run_frontend(
15
15
  ctx: typer.Context,
16
16
  identity_file: Path = typer.Argument(
17
17
  ...,
@@ -45,5 +45,5 @@ def register(app: typer.Typer) -> None:
45
45
  ),
46
46
  ] = default_workers,
47
47
  ):
48
- """Run the mrok proxy with Gunicorn and Uvicorn workers."""
49
- proxy.run(identity_file, host, port, workers)
48
+ """Run the mrok frontend with Gunicorn and Uvicorn workers."""
49
+ frontend.run(identity_file, host, port, workers)
mrok/constants.py ADDED
@@ -0,0 +1,4 @@
1
+ import re
2
+
3
+ RE_EXTENSION_ID = re.compile(r"(?i)EXT-\d{4}-\d{4}")
4
+ RE_INSTANCE_ID = re.compile(r"(?i)INS-\d{4}-\d{4}-\d{4}")
@@ -13,6 +13,7 @@ INSTANCE_RESPONSE = {
13
13
  "name": "ins-1234-5678-0001.ext-1234-5678",
14
14
  "extension": {"id": "EXT-1234-5678"},
15
15
  "instance": {"id": "INS-1234-5678-0001"},
16
+ "status": "offline",
16
17
  "tags": {
17
18
  "account": "ACC-5555-3333",
18
19
  MROK_VERSION_TAG_NAME: "1.0",
@@ -25,6 +26,7 @@ INSTANCE_CREATE_RESPONSE = {
25
26
  "name": "ins-1234-5678-0001.ext-1234-5678",
26
27
  "extension": {"id": "EXT-1234-5678"},
27
28
  "instance": {"id": "INS-1234-5678-0001"},
29
+ "status": "online",
28
30
  "identity": {
29
31
  "ztAPI": "https://ziti.exts.platform.softwareone.com/edge/client/v1",
30
32
  "ztAPIs": None,
@@ -35,6 +37,17 @@ INSTANCE_CREATE_RESPONSE = {
35
37
  "ca": "pem:-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----\n",
36
38
  },
37
39
  "enableHa": None,
40
+ "mrok": {
41
+ "identity": "ins-0000-0000-0000.ext-0000-0000",
42
+ "extension": "ext-0000-0000",
43
+ "instance": "ins-0000-0000-0000",
44
+ "domain": "ext.s1.today",
45
+ "tags": {
46
+ "mrok-service": "ext-0000-0000",
47
+ "mrok-identity-type": "instance",
48
+ "mrok": "0.4.0",
49
+ },
50
+ },
38
51
  },
39
52
  "tags": {
40
53
  "account": "ACC-5555-3333",
@@ -0,0 +1,3 @@
1
+ from mrok.frontend.main import run
2
+
3
+ __all__ = ["run"]
mrok/frontend/app.py ADDED
@@ -0,0 +1,75 @@
1
+ import re
2
+
3
+ from httpcore import AsyncConnectionPool
4
+
5
+ from mrok.conf import get_settings
6
+ from mrok.proxy.app import ProxyAppBase
7
+ from mrok.proxy.backend import AIOZitiNetworkBackend
8
+ from mrok.proxy.exceptions import InvalidTargetError
9
+ from mrok.proxy.types import Scope
10
+
11
+ RE_SUBDOMAIN = re.compile(r"(?i)^(?:EXT-\d{4}-\d{4}|INS-\d{4}-\d{4}-\d{4})$")
12
+
13
+
14
+ class FrontendProxyApp(ProxyAppBase):
15
+ def __init__(
16
+ self,
17
+ identity_file: str,
18
+ *,
19
+ max_connections: int = 1000,
20
+ max_keepalive_connections: int = 10,
21
+ keepalive_expiry: float = 120.0,
22
+ retries=0,
23
+ ):
24
+ self._identity_file = identity_file
25
+ self._proxy_domain = self._get_proxy_domain()
26
+ super().__init__(
27
+ max_connections=max_connections,
28
+ max_keepalive_connections=max_keepalive_connections,
29
+ keepalive_expiry=keepalive_expiry,
30
+ retries=retries,
31
+ )
32
+
33
+ def setup_connection_pool(
34
+ self,
35
+ max_connections: int | None = 1000,
36
+ max_keepalive_connections: int | None = 100,
37
+ keepalive_expiry: float | None = 120.0,
38
+ retries: int = 0,
39
+ ) -> AsyncConnectionPool:
40
+ return AsyncConnectionPool(
41
+ max_connections=max_connections,
42
+ max_keepalive_connections=max_keepalive_connections,
43
+ keepalive_expiry=keepalive_expiry,
44
+ retries=retries,
45
+ network_backend=AIOZitiNetworkBackend(self._identity_file),
46
+ )
47
+
48
+ def get_upstream_base_url(self, scope: Scope) -> str:
49
+ target = self._get_target_name(
50
+ {k.decode("latin1"): v.decode("latin1") for k, v in scope.get("headers", {})}
51
+ )
52
+ return f"http://{target.lower()}"
53
+
54
+ def _get_proxy_domain(self):
55
+ settings = get_settings()
56
+ return (
57
+ settings.proxy.domain
58
+ if settings.proxy.domain[0] == "."
59
+ else f".{settings.proxy.domain}"
60
+ )
61
+
62
+ def _get_target_from_header(self, headers: dict[str, str], name: str) -> str | None:
63
+ header_value = headers.get(name, "")
64
+ if self._proxy_domain in header_value:
65
+ if ":" in header_value:
66
+ header_value, _ = header_value.split(":", 1)
67
+ return header_value[: -len(self._proxy_domain)]
68
+
69
+ def _get_target_name(self, headers: dict[str, str]) -> str:
70
+ target = self._get_target_from_header(headers, "x-forwarded-host")
71
+ if not target:
72
+ target = self._get_target_from_header(headers, "host")
73
+ if not target or not RE_SUBDOMAIN.fullmatch(target):
74
+ raise InvalidTargetError()
75
+ return target
@@ -6,9 +6,8 @@ from gunicorn.app.base import BaseApplication
6
6
  from uvicorn_worker import UvicornWorker
7
7
 
8
8
  from mrok.conf import get_settings
9
- from mrok.http.lifespan import LifespanWrapper
9
+ from mrok.frontend.app import FrontendProxyApp
10
10
  from mrok.logging import get_logging_config
11
- from mrok.proxy.app import ProxyApp
12
11
 
13
12
 
14
13
  class MrokUvicornWorker(UvicornWorker):
@@ -40,19 +39,14 @@ def run(
40
39
  port: int,
41
40
  workers: int,
42
41
  ):
43
- proxy_app = ProxyApp(identity_file)
42
+ app = FrontendProxyApp(str(identity_file))
44
43
 
45
- asgi_app = LifespanWrapper(
46
- proxy_app,
47
- proxy_app.startup,
48
- proxy_app.shutdown,
49
- )
50
44
  options = {
51
45
  "bind": f"{host}:{port}",
52
46
  "workers": workers,
53
- "worker_class": "mrok.proxy.main.MrokUvicornWorker",
47
+ "worker_class": "mrok.frontend.main.MrokUvicornWorker",
54
48
  "logconfig_dict": get_logging_config(get_settings()),
55
49
  "reload": False,
56
50
  }
57
51
 
58
- StandaloneApplication(asgi_app, options).run()
52
+ StandaloneApplication(app, options).run()
mrok/proxy/__init__.py CHANGED
@@ -1,3 +0,0 @@
1
- from mrok.proxy.main import run
2
-
3
- __all__ = ["run"]
mrok/proxy/app.py CHANGED
@@ -1,72 +1,175 @@
1
- import asyncio
1
+ import abc
2
2
  import logging
3
- from pathlib import Path
4
3
 
5
- import openziti
6
- from openziti.context import ZitiContext
4
+ from httpcore import AsyncConnectionPool, Request
7
5
 
8
- from mrok.conf import get_settings
9
- from mrok.http.forwarder import ForwardAppBase
10
- from mrok.http.types import Scope, StreamReader, StreamWriter
11
- from mrok.logging import setup_logging
6
+ from mrok.proxy.exceptions import ProxyError
7
+ from mrok.proxy.streams import ASGIRequestBodyStream
8
+ from mrok.proxy.types import ASGIReceive, ASGISend, Scope
12
9
 
13
10
  logger = logging.getLogger("mrok.proxy")
14
11
 
15
12
 
16
- class ProxyError(Exception):
17
- pass
13
+ HOP_BY_HOP_HEADERS = [
14
+ b"connection",
15
+ b"keep-alive",
16
+ b"proxy-authenticate",
17
+ b"proxy-authorization",
18
+ b"te",
19
+ b"trailers",
20
+ b"transfer-encoding",
21
+ b"upgrade",
22
+ ]
18
23
 
19
24
 
20
- class ProxyApp(ForwardAppBase):
25
+ class ProxyAppBase(abc.ABC):
21
26
  def __init__(
22
27
  self,
23
- identity_file: str | Path,
24
28
  *,
25
- read_chunk_size: int = 65536,
29
+ max_connections: int | None = 1000,
30
+ max_keepalive_connections: int | None = 10,
31
+ keepalive_expiry: float | None = 120.0,
32
+ retries: int = 0,
26
33
  ) -> None:
27
- super().__init__(read_chunk_size=read_chunk_size)
28
- self._identity_file = identity_file
29
- settings = get_settings()
30
- self._proxy_wildcard_domain = (
31
- settings.proxy.domain
32
- if settings.proxy.domain[0] == "."
33
- else f".{settings.proxy.domain}"
34
+ self._pool = self.setup_connection_pool(
35
+ max_connections=max_connections,
36
+ max_keepalive_connections=max_keepalive_connections,
37
+ keepalive_expiry=keepalive_expiry,
38
+ retries=retries,
34
39
  )
35
- self._ziti_ctx: ZitiContext | None = None
36
-
37
- def get_target_from_header(self, headers: dict[str, str], name: str) -> str | None:
38
- header_value = headers.get(name, "")
39
- if self._proxy_wildcard_domain in header_value:
40
- if ":" in header_value:
41
- header_value, _ = header_value.split(":", 1)
42
- return header_value[: -len(self._proxy_wildcard_domain)]
43
-
44
- def get_target_name(self, headers: dict[str, str]) -> str:
45
- target = self.get_target_from_header(headers, "x-forwarded-host")
46
- if not target:
47
- target = self.get_target_from_header(headers, "host")
48
- if not target:
49
- raise ProxyError("Neither Host nor X-Forwarded-Host contain a valid target name")
50
- return target
51
-
52
- def _get_ziti_ctx(self) -> ZitiContext:
53
- if self._ziti_ctx is None:
54
- ctx, err = openziti.load(str(self._identity_file), timeout=10_000)
55
- if err != 0:
56
- raise Exception(f"Cannot create a Ziti context from the identity file: {err}")
57
- self._ziti_ctx = ctx
58
- return self._ziti_ctx
59
-
60
- async def startup(self):
61
- setup_logging(get_settings())
62
- self._get_ziti_ctx()
63
-
64
- async def select_backend(
40
+
41
+ @abc.abstractmethod
42
+ def setup_connection_pool(
65
43
  self,
66
- scope: Scope,
67
- headers: dict[str, str],
68
- ) -> tuple[StreamReader, StreamWriter] | tuple[None, None]:
69
- target_name = self.get_target_name(headers)
70
- sock = self._get_ziti_ctx().connect(target_name)
71
- reader, writer = await asyncio.open_connection(sock=sock)
72
- return reader, writer
44
+ max_connections: int | None = 1000,
45
+ max_keepalive_connections: int | None = 10,
46
+ keepalive_expiry: float | None = 120.0,
47
+ retries: int = 0,
48
+ ) -> AsyncConnectionPool:
49
+ raise NotImplementedError()
50
+
51
+ @abc.abstractmethod
52
+ def get_upstream_base_url(self, scope: Scope) -> str:
53
+ raise NotImplementedError()
54
+
55
+ async def __call__(self, scope: Scope, receive: ASGIReceive, send: ASGISend) -> None:
56
+ if scope.get("type") == "lifespan":
57
+ return
58
+
59
+ if scope.get("type") != "http":
60
+ await self._send_error(send, 500, "Unsupported")
61
+ return
62
+
63
+ try:
64
+ base_url = self.get_upstream_base_url(scope)
65
+ if base_url.endswith("/"): # pragma: no cover
66
+ base_url = base_url[:-1]
67
+ full_path = self._format_path(scope)
68
+ url = f"{base_url}{full_path}"
69
+ method = scope.get("method", "GET").encode()
70
+ headers = self._prepare_headers(scope)
71
+
72
+ body_stream = ASGIRequestBodyStream(receive)
73
+
74
+ request = Request(
75
+ method=method,
76
+ url=url,
77
+ headers=headers,
78
+ content=body_stream,
79
+ )
80
+ response = await self._pool.handle_async_request(request)
81
+ response_headers = []
82
+ for k, v in response.headers:
83
+ if k.lower() not in HOP_BY_HOP_HEADERS:
84
+ response_headers.append((k, v))
85
+
86
+ await send(
87
+ {
88
+ "type": "http.response.start",
89
+ "status": response.status,
90
+ "headers": response_headers,
91
+ }
92
+ )
93
+
94
+ async for chunk in response.stream: # type: ignore[union-attr]
95
+ await send(
96
+ {
97
+ "type": "http.response.body",
98
+ "body": chunk,
99
+ "more_body": True,
100
+ }
101
+ )
102
+
103
+ await send({"type": "http.response.body", "body": b"", "more_body": False})
104
+ await response.aclose()
105
+
106
+ except ProxyError as pe:
107
+ await self._send_error(send, pe.http_status, pe.message)
108
+
109
+ except Exception:
110
+ logger.exception("Unexpected error in forwarder")
111
+ await self._send_error(send, 502, "Bad Gateway")
112
+
113
+ async def _send_error(self, send: ASGISend, http_status: int, body: str):
114
+ try:
115
+ await send({"type": "http.response.start", "status": http_status, "headers": []})
116
+ await send({"type": "http.response.body", "body": body.encode()})
117
+ except Exception as e: # pragma: no cover
118
+ logger.error(f"Cannot send error response: {e}")
119
+
120
+ def _prepare_headers(self, scope: Scope) -> list[tuple[bytes, bytes]]:
121
+ headers: list[tuple[bytes, bytes]] = []
122
+ scope_headers = scope.get("headers", [])
123
+
124
+ for k, v in scope_headers:
125
+ if k.lower() not in HOP_BY_HOP_HEADERS:
126
+ headers.append((k, v))
127
+
128
+ self._merge_x_forwarded(headers, scope)
129
+
130
+ return headers
131
+
132
+ def _find_header(self, headers: list[tuple[bytes, bytes]], name: bytes) -> int | None:
133
+ """Return index of header `name` in `headers`, or None if missing."""
134
+ lname = name.lower()
135
+ for i, (k, _) in enumerate(headers):
136
+ if k.lower() == lname:
137
+ return i
138
+ return None
139
+
140
+ def _merge_x_forwarded(self, headers: list[tuple[bytes, bytes]], scope: Scope) -> None:
141
+ client = scope.get("client")
142
+ if client:
143
+ client_ip = client[0].encode()
144
+ idx = self._find_header(headers, b"x-forwarded-for")
145
+ if idx is None:
146
+ headers.append((b"x-forwarded-for", client_ip))
147
+ else:
148
+ k, v = headers[idx]
149
+ headers[idx] = (k, v + b", " + client_ip)
150
+
151
+ server = scope.get("server")
152
+ if server:
153
+ if self._find_header(headers, b"x-forwarded-host") is None:
154
+ headers.append((b"x-forwarded-host", server[0].encode()))
155
+ if server[1] and self._find_header(headers, b"x-forwarded-port") is None:
156
+ headers.append((b"x-forwarded-port", str(server[1]).encode()))
157
+
158
+ # Always set the protocol to https for upstream
159
+ idx_proto = self._find_header(headers, b"x-forwarded-proto")
160
+ if idx_proto is None:
161
+ headers.append((b"x-forwarded-proto", b"https"))
162
+ else:
163
+ k, _ = headers[idx_proto]
164
+ headers[idx_proto] = (k, b"https")
165
+
166
+ def _format_path(self, scope: Scope) -> str:
167
+ raw_path = scope.get("raw_path")
168
+ if raw_path:
169
+ return raw_path.decode()
170
+ q = scope.get("query_string", b"")
171
+ path = scope.get("path", "/")
172
+ path_qs = path
173
+ if q:
174
+ path_qs += "?" + q.decode()
175
+ return path_qs