stac-auth-proxy 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. stac_auth_proxy/__init__.py +12 -0
  2. stac_auth_proxy/__main__.py +18 -0
  3. stac_auth_proxy/app.py +178 -0
  4. stac_auth_proxy/config.py +91 -0
  5. stac_auth_proxy/filters/__init__.py +9 -0
  6. stac_auth_proxy/filters/opa.py +44 -0
  7. stac_auth_proxy/filters/template.py +22 -0
  8. stac_auth_proxy/handlers/__init__.py +7 -0
  9. stac_auth_proxy/handlers/healthz.py +31 -0
  10. stac_auth_proxy/handlers/reverse_proxy.py +101 -0
  11. stac_auth_proxy/handlers/swagger_ui.py +41 -0
  12. stac_auth_proxy/middleware/AddProcessTimeHeaderMiddleware.py +18 -0
  13. stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py +202 -0
  14. stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py +101 -0
  15. stac_auth_proxy/middleware/BuildCql2FilterMiddleware.py +112 -0
  16. stac_auth_proxy/middleware/EnforceAuthMiddleware.py +180 -0
  17. stac_auth_proxy/middleware/ProcessLinksMiddleware.py +73 -0
  18. stac_auth_proxy/middleware/RemoveRootPathMiddleware.py +45 -0
  19. stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py +76 -0
  20. stac_auth_proxy/middleware/__init__.py +21 -0
  21. stac_auth_proxy/utils/__init__.py +1 -0
  22. stac_auth_proxy/utils/cache.py +92 -0
  23. stac_auth_proxy/utils/filters.py +47 -0
  24. stac_auth_proxy/utils/lifespan.py +93 -0
  25. stac_auth_proxy/utils/middleware.py +114 -0
  26. stac_auth_proxy/utils/requests.py +82 -0
  27. stac_auth_proxy/utils/stac.py +18 -0
  28. stac_auth_proxy-0.6.1.dist-info/METADATA +511 -0
  29. stac_auth_proxy-0.6.1.dist-info/RECORD +31 -0
  30. stac_auth_proxy-0.6.1.dist-info/WHEEL +4 -0
  31. stac_auth_proxy-0.6.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,12 @@
1
+ """
2
+ STAC Auth Proxy package.
3
+
4
+ This package contains the components for the STAC authentication and proxying system.
5
+ It includes FastAPI routes for handling authentication, authorization, and interaction
6
+ with some internal STAC API.
7
+ """
8
+
9
from .app import create_app
from .config import Settings

# Public API re-exported at the package root.
__all__ = [
    "create_app",
    "Settings",
]
@@ -0,0 +1,18 @@
1
+ """Entry point for running the module without customized code."""
2
+
3
import copy

import uvicorn
from uvicorn.config import LOGGING_CONFIG


def _build_log_config() -> dict:
    """Return uvicorn's default logging config plus a DEBUG logger for this package.

    Works on a deep copy so uvicorn's module-level ``LOGGING_CONFIG`` is left
    untouched.
    """
    log_config = copy.deepcopy(LOGGING_CONFIG)
    log_config["loggers"][__package__] = {
        "level": "DEBUG",
        "handlers": ["default"],
    }
    return log_config


# Only start the development server when executed (``python -m stac_auth_proxy``),
# not when this module is merely imported.
if __name__ == "__main__":
    uvicorn.run(
        f"{__package__}.app:create_app",
        host="0.0.0.0",
        port=8000,
        log_config=_build_log_config(),
        reload=True,
        # The app is built by calling the ``create_app`` factory.
        factory=True,
    )
stac_auth_proxy/app.py ADDED
@@ -0,0 +1,178 @@
1
+ """
2
+ STAC Auth Proxy API.
3
+
4
+ This module defines the FastAPI application for the STAC Auth Proxy, which handles
5
+ authentication, authorization, and proxying of requests to some internal STAC API.
6
+ """
7
+
8
+ import logging
9
+ from contextlib import asynccontextmanager
10
+ from typing import Optional
11
+
12
+ from fastapi import FastAPI
13
+ from starlette_cramjam.middleware import CompressionMiddleware
14
+
15
+ from .config import Settings
16
+ from .handlers import HealthzHandler, ReverseProxyHandler, SwaggerUI
17
+ from .middleware import (
18
+ AddProcessTimeHeaderMiddleware,
19
+ ApplyCql2FilterMiddleware,
20
+ AuthenticationExtensionMiddleware,
21
+ BuildCql2FilterMiddleware,
22
+ EnforceAuthMiddleware,
23
+ OpenApiMiddleware,
24
+ ProcessLinksMiddleware,
25
+ RemoveRootPathMiddleware,
26
+ )
27
+ from .utils.lifespan import check_conformance, check_server_health
28
+
29
logger = logging.getLogger(__name__)


def create_app(settings: Optional[Settings] = None) -> FastAPI:
    """FastAPI Application Factory.

    Build the STAC Auth Proxy application:
    - optional Swagger UI and health-check routes;
    - a catch-all route that reverse-proxies every other request to
      ``settings.upstream_url``;
    - a middleware stack selected from ``settings``.

    :param settings: Proxy configuration. When omitted, ``Settings()`` is
        instantiated, so values come from ``Settings``' own sources.
    :returns: The configured FastAPI application.
    :raises AssertionError: If ``swagger_ui_endpoint`` is set while
        ``openapi_spec_endpoint`` is not (only when assertions are enabled).
    """
    settings = settings or Settings()

    #
    # Application
    #

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        """Run startup checks, then hand control to the running application.

        Nothing runs on shutdown (there is no code after ``yield``).
        """
        assert settings

        # Wait for upstream servers to become available
        if settings.wait_for_upstream:
            logger.info("Running upstream server health checks...")
            urls = [settings.upstream_url, settings.oidc_discovery_internal_url]
            for url in urls:
                await check_server_health(url=url)
            logger.info(
                "Upstream servers are healthy:\n%s",
                "\n".join([f" - {url}" for url in urls]),
            )

        # Log all middleware connected to the app
        logger.info(
            "Connected middleware:\n%s",
            "\n".join([f" - {m.cls.__name__}" for m in app.user_middleware]),
        )

        # NOTE(review): presumably checks that the upstream API advertises the
        # conformance classes required by the enabled middleware — confirm in
        # utils.lifespan.check_conformance.
        if settings.check_conformance:
            await check_conformance(
                app.user_middleware,
                str(settings.upstream_url),
            )

        yield

    app = FastAPI(
        openapi_url=None,  # Disable OpenAPI schema endpoint, we want to serve upstream's schema
        lifespan=lifespan,
        root_path=settings.root_path,
    )
    if app.root_path:
        logger.debug("Mounted app at %s", app.root_path)

    #
    # Handlers (place catch-all proxy handler last)
    #

    if settings.swagger_ui_endpoint:
        # NOTE(review): `assert` is stripped under `python -O`, so this
        # configuration check disappears in optimized runs.
        assert (
            settings.openapi_spec_endpoint
        ), "openapi_spec_endpoint must be set when using swagger_ui_endpoint"
        app.add_route(
            settings.swagger_ui_endpoint,
            SwaggerUI(
                openapi_url=settings.openapi_spec_endpoint,
                init_oauth=settings.swagger_ui_init_oauth,
            ).route,
            include_in_schema=False,
        )
    if settings.healthz_prefix:
        app.include_router(
            HealthzHandler(upstream_url=str(settings.upstream_url)).router,
            prefix=settings.healthz_prefix,
        )

    # Catch-all: any path not matched above is forwarded upstream.
    app.add_api_route(
        "/{path:path}",
        ReverseProxyHandler(
            upstream=str(settings.upstream_url),
            override_host=settings.override_host,
        ).proxy_request,
        methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
    )

    #
    # Middleware (order is important, last added = first to run)
    #
    # Following that rule, an incoming request passes through the enabled
    # middleware in this order:
    #   CompressionMiddleware -> RemoveRootPathMiddleware -> ProcessLinksMiddleware
    #   -> EnforceAuthMiddleware -> AddProcessTimeHeaderMiddleware
    #   -> BuildCql2FilterMiddleware -> ApplyCql2FilterMiddleware
    #   -> OpenApiMiddleware -> AuthenticationExtensionMiddleware -> routes
    #

    if settings.enable_authentication_extension:
        app.add_middleware(
            AuthenticationExtensionMiddleware,
            default_public=settings.default_public,
            public_endpoints=settings.public_endpoints,
            private_endpoints=settings.private_endpoints,
            oidc_discovery_url=str(settings.oidc_discovery_url),
        )

    if settings.openapi_spec_endpoint:
        app.add_middleware(
            OpenApiMiddleware,
            openapi_spec_path=settings.openapi_spec_endpoint,
            oidc_discovery_url=str(settings.oidc_discovery_url),
            public_endpoints=settings.public_endpoints,
            private_endpoints=settings.private_endpoints,
            default_public=settings.default_public,
            root_path=settings.root_path,
            auth_scheme_name=settings.openapi_auth_scheme_name,
            auth_scheme_override=settings.openapi_auth_scheme_override,
        )

    # Build is added after Apply, so it runs before Apply on the request
    # path (presumably Build computes the filter that Apply then enforces).
    # Calling a ClassInput imports and instantiates the configured filter.
    if settings.items_filter or settings.collections_filter:
        app.add_middleware(
            ApplyCql2FilterMiddleware,
        )
        app.add_middleware(
            BuildCql2FilterMiddleware,
            items_filter=settings.items_filter() if settings.items_filter else None,
            collections_filter=(
                settings.collections_filter() if settings.collections_filter else None
            ),
            collections_filter_path=settings.collections_filter_path,
            items_filter_path=settings.items_filter_path,
        )

    app.add_middleware(
        AddProcessTimeHeaderMiddleware,
    )

    app.add_middleware(
        EnforceAuthMiddleware,
        public_endpoints=settings.public_endpoints,
        private_endpoints=settings.private_endpoints,
        default_public=settings.default_public,
        oidc_discovery_url=settings.oidc_discovery_internal_url,
    )

    # Link processing is only registered when the public path (root_path) or
    # the upstream's base path is non-trivial.
    if settings.root_path or settings.upstream_url.path != "/":
        app.add_middleware(
            ProcessLinksMiddleware,
            upstream_url=str(settings.upstream_url),
            root_path=settings.root_path,
        )

    if settings.root_path:
        app.add_middleware(
            RemoveRootPathMiddleware,
            root_path=settings.root_path,
        )

    if settings.enable_compression:
        app.add_middleware(
            CompressionMiddleware,
        )

    return app
@@ -0,0 +1,91 @@
1
+ """Configuration for the STAC Auth Proxy."""
2
+
3
+ import importlib
4
+ from typing import Any, Literal, Optional, Sequence, TypeAlias, Union
5
+
6
+ from pydantic import BaseModel, Field, model_validator
7
+ from pydantic.networks import HttpUrl
8
+ from pydantic_settings import BaseSettings, SettingsConfigDict
9
+
10
# HTTP methods that endpoint rules may list.
METHODS = Literal["GET", "POST", "PUT", "DELETE", "PATCH"]
# Path regex -> methods.
EndpointMethodsNoScope: TypeAlias = dict[str, Sequence[METHODS]]
# Path regex -> methods, where each entry is either a bare method or a
# (method, str) pair. NOTE(review): the str is presumably a required scope —
# confirm against EnforceAuthMiddleware.
EndpointMethods: TypeAlias = dict[str, Sequence[Union[METHODS, tuple[METHODS, str]]]]

# Regex used to require that path-like settings start with "/".
_PREFIX_PATTERN = r"^/.*$"
15
+
16
+
17
class ClassInput(BaseModel):
    """Input model for dynamically loading a class or function.

    ``cls`` is an import reference of the form ``"package.module:Name"``.
    Calling the model imports the module, looks up ``Name`` and calls it with
    ``args`` and ``kwargs``.
    """

    cls: str  # Import reference, "package.module:Name".
    args: Sequence[str] = Field(default_factory=list)  # Positional arguments for the call.
    kwargs: dict[str, str] = Field(default_factory=dict)  # Keyword arguments for the call.

    def __call__(self):
        """Dynamically load a class and instantiate it with args & kwargs.

        :returns: The result of calling the referenced object.
        :raises ValueError: If ``cls`` is not of the form ``"module:Name"``.
        :raises ImportError: If the module cannot be imported.
        :raises AttributeError: If the module has no attribute ``Name``.
        """
        # Explicit check instead of `assert`: it survives `python -O` and gives
        # a clear message instead of a tuple-unpacking error.
        module_path, sep, class_name = self.cls.rpartition(":")
        if not sep or not module_path or not class_name:
            raise ValueError(
                f"Expected an import reference like 'package.module:Name', got {self.cls!r}"
            )
        module = importlib.import_module(module_path)
        cls = getattr(module, class_name)
        return cls(*self.args, **self.kwargs)
31
+
32
+
33
class Settings(BaseSettings):
    """Configuration settings for the STAC Auth Proxy.

    Loaded through ``pydantic_settings.BaseSettings``; nested values use ``_``
    as the delimiter (``env_nested_delimiter``).
    """

    # External URLs
    upstream_url: HttpUrl
    oidc_discovery_url: HttpUrl
    # Falls back to ``oidc_discovery_url`` when not provided (see
    # ``default_oidc_discovery_internal_url`` below).
    oidc_discovery_internal_url: HttpUrl

    root_path: str = ""
    override_host: bool = True
    healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz")
    wait_for_upstream: bool = True
    check_conformance: bool = True
    enable_compression: bool = True

    # OpenAPI / Swagger UI
    openapi_spec_endpoint: Optional[str] = Field(pattern=_PREFIX_PATTERN, default=None)
    openapi_auth_scheme_name: str = "oidcAuth"
    openapi_auth_scheme_override: Optional[dict] = None
    swagger_ui_endpoint: Optional[str] = None
    swagger_ui_init_oauth: dict = Field(default_factory=dict)

    # Auth
    enable_authentication_extension: bool = True
    default_public: bool = False
    # Path regex -> methods treated as public (enforcement lives in the auth middleware).
    public_endpoints: EndpointMethodsNoScope = {
        # Escaped dot: an unescaped "." would also match e.g. "/apiXhtml".
        r"^/api\.html$": ["GET"],
        r"^/api$": ["GET"],
        r"^/docs/oauth2-redirect": ["GET"],
        r"^/healthz": ["GET"],
    }
    # Path regex -> methods treated as private.
    private_endpoints: EndpointMethods = {
        # https://github.com/stac-api-extensions/collection-transaction/blob/v1.0.0-beta.1/README.md#methods
        r"^/collections$": ["POST"],
        r"^/collections/([^/]+)$": ["PUT", "PATCH", "DELETE"],
        # https://github.com/stac-api-extensions/transaction/blob/v1.0.0-rc.3/README.md#methods
        r"^/collections/([^/]+)/items$": ["POST"],
        r"^/collections/([^/]+)/items/([^/]+)$": ["PUT", "PATCH", "DELETE"],
        # https://stac-utils.github.io/stac-fastapi/api/stac_fastapi/extensions/third_party/bulk_transactions/#bulktransactionextension
        r"^/collections/([^/]+)/bulk_items$": ["POST"],
    }

    # Filters: import references for CQL2 filter generators, plus regexes
    # (presumably) selecting which request paths each filter applies to.
    items_filter: Optional[ClassInput] = None
    items_filter_path: str = r"^(/collections/([^/]+)/items(/[^/]+)?$|/search$)"
    collections_filter: Optional[ClassInput] = None
    collections_filter_path: str = r"^/collections(/[^/]+)?$"

    model_config = SettingsConfigDict(
        env_nested_delimiter="_",
    )

    @model_validator(mode="before")
    @classmethod
    def default_oidc_discovery_internal_url(cls, data: Any) -> Any:
        """Set the internal OIDC discovery URL to the public URL if not set.

        Runs before field validation. Input that is not a dict is returned
        untouched, and the caller's dict is copied rather than mutated.
        """
        if not isinstance(data, dict):
            return data
        if data.get("oidc_discovery_internal_url"):
            return data
        return {**data, "oidc_discovery_internal_url": data.get("oidc_discovery_url")}
@@ -0,0 +1,9 @@
1
+ """CQL2 filter generators."""
2
+
3
from .opa import Opa
from .template import Template

# CQL2 filter generators exposed by this subpackage.
__all__ = ["Opa", "Template"]
@@ -0,0 +1,44 @@
1
+ """Integration with Open Policy Agent (OPA) to generate CQL2 filters for requests to a STAC API."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any
5
+
6
+ import httpx
7
+
8
+ from ..utils.cache import MemoryCache, get_value_by_path
9
+
10
+
11
@dataclass
class Opa:
    """Ask an Open Policy Agent (OPA) server for a CQL2 filter for a request.

    The request context is POSTed to ``/v1/data/<decision>``, and the
    response's ``result`` is used as the filter. Results are cached in memory
    for ``cache_ttl`` seconds, keyed by the value ``get_value_by_path`` finds
    at ``cache_key`` in the context.

    NOTE(review): the cached filter is reused for every request that yields
    the same key, even if other parts of the context differ — confirm this
    matches the policy's inputs.
    """

    host: str  # Base URL of the OPA server.
    decision: str  # Decision path appended to ``/v1/data/``.

    client: httpx.AsyncClient = field(init=False)
    cache: MemoryCache = field(init=False)
    cache_key: str = "req.headers.authorization"
    cache_ttl: float = 5.0

    def __post_init__(self):
        """Create the HTTP client bound to ``host`` and the TTL cache."""
        self.client = httpx.AsyncClient(base_url=self.host)
        self.cache = MemoryCache(ttl=self.cache_ttl)

    async def __call__(self, context: dict[str, Any]) -> str:
        """Return the CQL2 filter for ``context``, consulting the cache first."""
        key = get_value_by_path(context, self.cache_key)
        try:
            return self.cache[key]
        except KeyError:
            pass
        expression = await self._fetch(context)
        self.cache[key] = expression
        return expression

    async def _fetch(self, context: dict[str, Any]) -> str:
        """POST ``context`` to OPA and return the decision's ``result``.

        :raises httpx.HTTPStatusError: On a 4xx/5xx response from OPA.
        :raises KeyError: If the response body has no ``result`` key.
        """
        url = f"/v1/data/{self.decision}"
        response = await self.client.post(url, json={"input": context})
        return response.raise_for_status().json()["result"]
@@ -0,0 +1,22 @@
1
+ """Generate CQL2 filter expressions via Jinja2 templating."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any
5
+
6
+ from jinja2 import BaseLoader, Environment
7
+
8
+
9
@dataclass
class Template:
    """Generate CQL2 filter expressions via Jinja2 templating."""

    template_str: str
    # Holds the compiled template returned by ``Environment.from_string``
    # (not an Environment, despite the name/annotation, which are kept for
    # backward compatibility).
    env: Environment = field(init=False)

    def __post_init__(self):
        """Compile ``template_str`` once so each call only has to render it."""
        # Jinja2 expects a loader *instance*. Passing the BaseLoader class made
        # any `{% include %}` / `{% extends %}` fail with an unrelated error
        # instead of a TemplateNotFound.
        self.env = Environment(loader=BaseLoader()).from_string(self.template_str)

    async def __call__(self, context: dict[str, Any]) -> str:
        """Render the CQL2 filter expression with ``context``'s keys as variables.

        Leading and trailing whitespace is stripped from the rendered text.
        """
        return self.env.render(**context).strip()
@@ -0,0 +1,7 @@
1
+ """Handlers to process requests."""
2
+
3
from .healthz import HealthzHandler
from .reverse_proxy import ReverseProxyHandler
from .swagger_ui import SwaggerUI

# Request handlers exposed by this subpackage.
__all__ = [
    "ReverseProxyHandler",
    "HealthzHandler",
    "SwaggerUI",
]
@@ -0,0 +1,31 @@
1
+ """Health check endpoints."""
2
+
3
+ from dataclasses import dataclass, field
4
+
5
+ from fastapi import APIRouter
6
+ from httpx import AsyncClient
7
+
8
+
9
@dataclass
class HealthzHandler:
    """Health-check routes for the proxy itself and for the upstream STAC API."""

    upstream_url: str
    router: APIRouter = field(init=False)

    def __post_init__(self):
        """Build the router.

        ``""`` reports proxy health and ``"/upstream"`` probes the upstream API;
        both are GET-only.
        """
        router = APIRouter()
        for path, endpoint in (("", self.healthz), ("/upstream", self.healthz_upstream)):
            router.add_api_route(path, endpoint, methods=["GET"])
        self.router = router

    async def healthz(self):
        """Report that this API is up."""
        return {"status": "ok"}

    async def healthz_upstream(self):
        """Probe the upstream STAC API with a GET to ``upstream_url``.

        :raises httpx.HTTPStatusError: If upstream responds with a 4xx/5xx status.
        """
        async with AsyncClient() as client:
            response = await client.get(self.upstream_url)
        response.raise_for_status()
        return {"status": "ok", "code": response.status_code}
@@ -0,0 +1,101 @@
1
+ """Tooling to manage the reverse proxying of requests to an upstream STAC API."""
2
+
3
+ import logging
4
+ import time
5
+ from dataclasses import dataclass, field
6
+
7
+ import httpx
8
+ from fastapi import Request
9
+ from starlette.datastructures import MutableHeaders
10
+ from starlette.responses import Response
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
@dataclass
class ReverseProxyHandler:
    """Reverse proxy functionality.

    Forwards incoming requests to ``upstream`` and returns the upstream response
    with its body fully read (and decoded by httpx if it was compressed), so
    downstream middleware work with plain content.
    """

    upstream: str  # Base URL of the upstream STAC API.
    client: httpx.AsyncClient | None = None  # Created in __post_init__ when not supplied.
    timeout: httpx.Timeout = field(default_factory=lambda: httpx.Timeout(timeout=15.0))

    proxy_name: str = "stac-auth-proxy"  # Advertised in the ``Via`` header.
    override_host: bool = True  # Send the upstream host as ``Host``.
    legacy_forwarded_headers: bool = False  # Also send ``X-Forwarded-*`` headers.

    # Upstream response headers that must not be copied to the client
    # (class attribute, not a dataclass field):
    # - connection-specific (hop-by-hop) headers describe the proxy<->upstream
    #   connection only;
    # - Content-Length / Content-Encoding describe the upstream wire body,
    #   which is replaced by the fully-read, decoded body. Starlette recomputes
    #   Content-Length from that body.
    _DROPPED_RESPONSE_HEADERS = frozenset(
        {
            b"connection",
            b"content-encoding",
            b"content-length",
            b"keep-alive",
            b"proxy-connection",
            b"te",
            b"trailer",
            b"transfer-encoding",
            b"upgrade",
        }
    )

    def __post_init__(self):
        """Initialize the HTTP client (HTTP/2 enabled, bound to ``upstream``) unless one was given."""
        self.client = self.client or httpx.AsyncClient(
            base_url=self.upstream,
            timeout=self.timeout,
            http2=True,
        )

    def _prepare_headers(self, request: Request) -> MutableHeaders:
        """Prepare headers for the proxied request.

        Starts from the incoming headers. Unless already present, it adds:
        - a ``Via`` header;
        - a ``Forwarded`` header describing the client address, host, scheme
          and base path of the original request;
        - the ``X-Forwarded-*`` equivalents, when ``legacy_forwarded_headers``
          is set.

        When ``override_host`` is set, ``Host`` becomes the upstream's host.
        """
        headers = MutableHeaders(request.headers)
        headers.setdefault("Via", f"1.1 {self.proxy_name}")

        proxy_client = request.client.host if request.client else "unknown"
        proxy_proto = request.url.scheme
        proxy_host = request.url.netloc
        proxy_path = request.base_url.path
        headers.setdefault(
            "Forwarded",
            f"for={proxy_client};host={proxy_host};proto={proxy_proto};path={proxy_path}",
        )
        if self.legacy_forwarded_headers:
            headers.setdefault("X-Forwarded-For", proxy_client)
            headers.setdefault("X-Forwarded-Host", proxy_host)
            headers.setdefault("X-Forwarded-Path", proxy_path)
            headers.setdefault("X-Forwarded-Proto", proxy_proto)

        # Set host to the upstream host
        if self.override_host:
            headers["Host"] = self.client.base_url.netloc.decode("utf-8")

        return headers

    async def proxy_request(self, request: Request) -> Response:
        """Proxy a request to the upstream STAC API.

        :param request: Incoming request. Its path, query string, headers and
            streamed body are forwarded.
        :returns: A response carrying the upstream status code, the upstream
            headers (minus the ones in ``_DROPPED_RESPONSE_HEADERS``) and the
            decoded body, plus ``X-Upstream-Time`` (upstream round-trip seconds).
        """
        headers = self._prepare_headers(request)

        # https://github.com/fastapi/fastapi/discussions/7382#discussioncomment-5136466
        rp_req = self.client.build_request(
            request.method,
            url=httpx.URL(
                path=request.url.path,
                query=request.url.query.encode("utf-8"),
            ),
            headers=headers,
            content=request.stream(),
        )

        # NOTE: HTTPX adds headers, so we need to trim them before sending request
        for h in list(rp_req.headers.keys()):
            if h not in headers:
                del rp_req.headers[h]

        logger.debug("Proxying request to %s", rp_req.url)

        start_time = time.perf_counter()
        rp_resp = await self.client.send(rp_req, stream=True)
        proxy_time = time.perf_counter() - start_time

        logger.debug(
            "Received response status %r from %s in %.3fs",
            rp_resp.status_code,
            rp_req.url,
            proxy_time,
        )
        rp_resp.headers["X-Upstream-Time"] = f"{proxy_time:.3f}"

        # We read the content here to make use of HTTPX's decompression, ensuring we have
        # non-compressed content for the middleware to work with.
        # NOTE(review): if upstream used an encoding httpx cannot decode, the body
        # stays encoded while Content-Encoding is still dropped (as before).
        content = await rp_resp.aread()

        # Starlette sets Content-Length from `content`. Upstream headers are then
        # appended pair by pair (not through a dict), so repeated headers such
        # as Set-Cookie are kept as separate lines.
        response = Response(content=content, status_code=rp_resp.status_code)
        response.raw_headers.extend(
            (name.lower(), value)
            for name, value in rp_resp.headers.raw
            if name.lower() not in self._DROPPED_RESPONSE_HEADERS
        )
        return response
@@ -0,0 +1,41 @@
1
+ """
2
+ In order to allow customization of the Swagger UI's OAuth2 configuration, we support
3
+ overriding the default handler. This is useful for adding custom parameters such as
4
+ `usePkceWithAuthorizationCodeGrant` or `clientId`.
5
+
6
+ See:
7
+ - https://swagger.io/docs/open-source-tools/swagger-ui/usage/oauth2/
8
+ """
9
+
10
+ from dataclasses import dataclass, field
11
+ from typing import Optional
12
+
13
+ from fastapi.openapi.docs import get_swagger_ui_html
14
+ from starlette.requests import Request
15
+ from starlette.responses import HTMLResponse
16
+
17
+
18
@dataclass
class SwaggerUI:
    """Serve a Swagger UI page whose URLs honour the request's ASGI ``root_path``."""

    openapi_url: str  # Path of the OpenAPI document, relative to the app root.
    title: Optional[str] = "STAC API"
    init_oauth: dict = field(default_factory=dict)  # Passed as ``init_oauth`` to get_swagger_ui_html.
    parameters: dict = field(default_factory=dict)  # Passed as ``swagger_ui_parameters``.
    oauth2_redirect_url: str = "/docs/oauth2-redirect"

    async def route(self, req: Request) -> HTMLResponse:
        """Render the Swagger UI HTML for ``req``.

        The spec URL and (when set) the OAuth2 redirect URL are prefixed with
        the request's ``root_path`` (trailing slash removed).
        """
        prefix = req.scope.get("root_path", "").rstrip("/")
        redirect_url = (
            prefix + self.oauth2_redirect_url
            if self.oauth2_redirect_url
            else self.oauth2_redirect_url
        )
        return get_swagger_ui_html(
            openapi_url=prefix + self.openapi_url,
            title=f"{self.title} - Swagger UI",
            oauth2_redirect_url=redirect_url,
            init_oauth=self.init_oauth,
            swagger_ui_parameters=self.parameters,
        )
@@ -0,0 +1,18 @@
1
+ """Middleware to add a header with the process time to the response."""
2
+
3
+ import time
4
+
5
+ from fastapi import Request, Response
6
+ from starlette.middleware.base import BaseHTTPMiddleware
7
+
8
+
9
class AddProcessTimeHeaderMiddleware(BaseHTTPMiddleware):
    """Time each request and report the duration in an ``X-Process-Time`` header."""

    async def dispatch(self, request: Request, call_next) -> Response:
        """Run the rest of the stack and stamp the elapsed seconds on the response.

        Elapsed time is measured with ``time.perf_counter`` and formatted with
        three decimals.
        """
        began = time.perf_counter()
        response = await call_next(request)
        elapsed = time.perf_counter() - began
        response.headers["X-Process-Time"] = f"{elapsed:.3f}"
        return response
+ return response