stac-auth-proxy 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stac_auth_proxy/__init__.py +12 -0
- stac_auth_proxy/__main__.py +18 -0
- stac_auth_proxy/app.py +178 -0
- stac_auth_proxy/config.py +91 -0
- stac_auth_proxy/filters/__init__.py +9 -0
- stac_auth_proxy/filters/opa.py +44 -0
- stac_auth_proxy/filters/template.py +22 -0
- stac_auth_proxy/handlers/__init__.py +7 -0
- stac_auth_proxy/handlers/healthz.py +31 -0
- stac_auth_proxy/handlers/reverse_proxy.py +101 -0
- stac_auth_proxy/handlers/swagger_ui.py +41 -0
- stac_auth_proxy/middleware/AddProcessTimeHeaderMiddleware.py +18 -0
- stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py +202 -0
- stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py +101 -0
- stac_auth_proxy/middleware/BuildCql2FilterMiddleware.py +112 -0
- stac_auth_proxy/middleware/EnforceAuthMiddleware.py +180 -0
- stac_auth_proxy/middleware/ProcessLinksMiddleware.py +73 -0
- stac_auth_proxy/middleware/RemoveRootPathMiddleware.py +45 -0
- stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py +76 -0
- stac_auth_proxy/middleware/__init__.py +21 -0
- stac_auth_proxy/utils/__init__.py +1 -0
- stac_auth_proxy/utils/cache.py +92 -0
- stac_auth_proxy/utils/filters.py +47 -0
- stac_auth_proxy/utils/lifespan.py +93 -0
- stac_auth_proxy/utils/middleware.py +114 -0
- stac_auth_proxy/utils/requests.py +82 -0
- stac_auth_proxy/utils/stac.py +18 -0
- stac_auth_proxy-0.6.1.dist-info/METADATA +511 -0
- stac_auth_proxy-0.6.1.dist-info/RECORD +31 -0
- stac_auth_proxy-0.6.1.dist-info/WHEEL +4 -0
- stac_auth_proxy-0.6.1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""
|
|
2
|
+
STAC Auth Proxy package.
|
|
3
|
+
|
|
4
|
+
This package contains the components for the STAC authentication and proxying system.
|
|
5
|
+
It includes FastAPI routes for handling authentication, authorization, and interaction
|
|
6
|
+
with some internal STAC API.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from .app import create_app
|
|
10
|
+
from .config import Settings
|
|
11
|
+
|
|
12
|
+
# Public names re-exported at package level: the application factory and the
# settings model it consumes.
__all__ = ["create_app", "Settings"]
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""Entry point for running the module without customized code."""
|
|
2
|
+
|
|
3
|
+
import uvicorn
|
|
4
|
+
from uvicorn.config import LOGGING_CONFIG
|
|
5
|
+
|
|
6
|
+
def main() -> None:
    """Serve the proxy application with uvicorn on 0.0.0.0:8000.

    The app is built through the ``create_app`` factory (``factory=True``),
    auto-reload is enabled, and this package's logger is set to DEBUG on
    uvicorn's ``default`` handler.
    """
    import copy

    # Work on a private copy so uvicorn's module-level default logging config
    # is not mutated as a side effect of running this entry point.
    log_config = copy.deepcopy(LOGGING_CONFIG)
    log_config["loggers"][__package__] = {
        "level": "DEBUG",
        "handlers": ["default"],
    }

    # NOTE(review): reload=True is a development convenience; consider making it
    # configurable before relying on this entry point in production.
    uvicorn.run(
        f"{__package__}.app:create_app",
        host="0.0.0.0",
        port=8000,
        log_config=log_config,
        reload=True,
        factory=True,
    )


# `python -m stac_auth_proxy` runs this module as __main__; merely importing it
# must not start a server.
if __name__ == "__main__":
    main()
|
stac_auth_proxy/app.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
"""
|
|
2
|
+
STAC Auth Proxy API.
|
|
3
|
+
|
|
4
|
+
This module defines the FastAPI application for the STAC Auth Proxy, which handles
|
|
5
|
+
authentication, authorization, and proxying of requests to some internal STAC API.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
from contextlib import asynccontextmanager
|
|
10
|
+
from typing import Optional
|
|
11
|
+
|
|
12
|
+
from fastapi import FastAPI
|
|
13
|
+
from starlette_cramjam.middleware import CompressionMiddleware
|
|
14
|
+
|
|
15
|
+
from .config import Settings
|
|
16
|
+
from .handlers import HealthzHandler, ReverseProxyHandler, SwaggerUI
|
|
17
|
+
from .middleware import (
|
|
18
|
+
AddProcessTimeHeaderMiddleware,
|
|
19
|
+
ApplyCql2FilterMiddleware,
|
|
20
|
+
AuthenticationExtensionMiddleware,
|
|
21
|
+
BuildCql2FilterMiddleware,
|
|
22
|
+
EnforceAuthMiddleware,
|
|
23
|
+
OpenApiMiddleware,
|
|
24
|
+
ProcessLinksMiddleware,
|
|
25
|
+
RemoveRootPathMiddleware,
|
|
26
|
+
)
|
|
27
|
+
from .utils.lifespan import check_conformance, check_server_health
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)


def create_app(settings: Optional[Settings] = None) -> FastAPI:
    """FastAPI Application Factory.

    Build the proxy application from ``settings``; when none is given a fresh
    ``Settings()`` is constructed (pydantic-settings reads it from the
    environment).

    The returned app serves optional Swagger UI and health-check routes,
    forwards every other path to ``settings.upstream_url`` through
    ``ReverseProxyHandler``, and wraps everything in a middleware stack whose
    composition is driven by ``settings``.

    Args:
        settings: Proxy configuration. Defaults to ``Settings()``.

    Returns:
        The configured ``FastAPI`` application.

    Raises:
        AssertionError: if ``swagger_ui_endpoint`` is set without
            ``openapi_spec_endpoint`` (only when assertions are enabled).
    """
    settings = settings or Settings()

    #
    # Application
    #

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Startup work done before the app begins serving requests: optional
        # upstream health checks, logging the middleware stack, and an optional
        # conformance check against the upstream API.
        # `settings` is always non-None here; the assert presumably exists to
        # narrow the Optional type for type checkers.
        assert settings

        # Wait for upstream servers to become available
        if settings.wait_for_upstream:
            logger.info("Running upstream server health checks...")
            urls = [settings.upstream_url, settings.oidc_discovery_internal_url]
            for url in urls:
                # NOTE(review): these are pydantic HttpUrl objects, not str —
                # confirm check_server_health accepts them.
                await check_server_health(url=url)
            logger.info(
                "Upstream servers are healthy:\n%s",
                "\n".join([f" - {url}" for url in urls]),
            )

        # Log all middleware connected to the app
        logger.info(
            "Connected middleware:\n%s",
            "\n".join([f" - {m.cls.__name__}" for m in app.user_middleware]),
        )

        if settings.check_conformance:
            # Receives the registered middleware list, presumably so it can
            # verify the upstream supports what each middleware needs.
            await check_conformance(
                app.user_middleware,
                str(settings.upstream_url),
            )

        yield

    app = FastAPI(
        openapi_url=None,  # Disable OpenAPI schema endpoint, we want to serve upstream's schema
        lifespan=lifespan,
        root_path=settings.root_path,
    )
    if app.root_path:
        logger.debug("Mounted app at %s", app.root_path)

    #
    # Handlers (place catch-all proxy handler last)
    #

    if settings.swagger_ui_endpoint:
        # NOTE(review): `assert` is removed under `python -O`, so this config
        # check silently disappears there; an explicit raise would be unconditional.
        assert (
            settings.openapi_spec_endpoint
        ), "openapi_spec_endpoint must be set when using swagger_ui_endpoint"
        app.add_route(
            settings.swagger_ui_endpoint,
            SwaggerUI(
                openapi_url=settings.openapi_spec_endpoint,
                init_oauth=settings.swagger_ui_init_oauth,
            ).route,
            include_in_schema=False,
        )
    if settings.healthz_prefix:
        app.include_router(
            HealthzHandler(upstream_url=str(settings.upstream_url)).router,
            prefix=settings.healthz_prefix,
        )

    # Catch-all: any path not matched by a route above is forwarded upstream.
    app.add_api_route(
        "/{path:path}",
        ReverseProxyHandler(
            upstream=str(settings.upstream_url),
            override_host=settings.override_host,
        ).proxy_request,
        methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
    )

    #
    # Middleware (order is important, last added = first to run)
    #
    # Resulting request path, outermost first (each entry only when enabled):
    #   Compression -> RemoveRootPath -> ProcessLinks -> EnforceAuth
    #   -> AddProcessTimeHeader -> BuildCql2Filter -> ApplyCql2Filter
    #   -> OpenApi -> AuthenticationExtension -> routes
    #

    if settings.enable_authentication_extension:
        app.add_middleware(
            AuthenticationExtensionMiddleware,
            default_public=settings.default_public,
            public_endpoints=settings.public_endpoints,
            private_endpoints=settings.private_endpoints,
            oidc_discovery_url=str(settings.oidc_discovery_url),
        )

    if settings.openapi_spec_endpoint:
        app.add_middleware(
            OpenApiMiddleware,
            openapi_spec_path=settings.openapi_spec_endpoint,
            oidc_discovery_url=str(settings.oidc_discovery_url),
            public_endpoints=settings.public_endpoints,
            private_endpoints=settings.private_endpoints,
            default_public=settings.default_public,
            root_path=settings.root_path,
            auth_scheme_name=settings.openapi_auth_scheme_name,
            auth_scheme_override=settings.openapi_auth_scheme_override,
        )

    if settings.items_filter or settings.collections_filter:
        # Registered before BuildCql2FilterMiddleware, so it runs after it
        # (inner) on the request path.
        app.add_middleware(
            ApplyCql2FilterMiddleware,
        )
        app.add_middleware(
            BuildCql2FilterMiddleware,
            # Each ClassInput is called here, i.e. filter objects are built once
            # at app-creation time.
            items_filter=settings.items_filter() if settings.items_filter else None,
            collections_filter=(
                settings.collections_filter() if settings.collections_filter else None
            ),
            collections_filter_path=settings.collections_filter_path,
            items_filter_path=settings.items_filter_path,
        )

    app.add_middleware(
        AddProcessTimeHeaderMiddleware,
    )

    app.add_middleware(
        EnforceAuthMiddleware,
        public_endpoints=settings.public_endpoints,
        private_endpoints=settings.private_endpoints,
        default_public=settings.default_public,
        # NOTE(review): unlike the other middleware, this receives the HttpUrl
        # object rather than str(...) — confirm EnforceAuthMiddleware handles it.
        oidc_discovery_url=settings.oidc_discovery_internal_url,
    )

    # Only added when the proxy's path and the upstream's path can differ.
    if settings.root_path or settings.upstream_url.path != "/":
        app.add_middleware(
            ProcessLinksMiddleware,
            upstream_url=str(settings.upstream_url),
            root_path=settings.root_path,
        )

    if settings.root_path:
        app.add_middleware(
            RemoveRootPathMiddleware,
            root_path=settings.root_path,
        )

    if settings.enable_compression:
        app.add_middleware(
            CompressionMiddleware,
        )

    return app
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
"""Configuration for the STAC Auth Proxy."""
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
from typing import Any, Literal, Optional, Sequence, TypeAlias, Union
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, Field, model_validator
|
|
7
|
+
from pydantic.networks import HttpUrl
|
|
8
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
9
|
+
|
|
10
|
+
# HTTP methods that may appear in endpoint access rules.
METHODS = Literal["GET", "POST", "PUT", "DELETE", "PATCH"]
# Mapping of path regex -> HTTP methods (the type of `public_endpoints`).
EndpointMethodsNoScope: TypeAlias = dict[str, Sequence[METHODS]]
# Mapping of path regex -> HTTP methods, where each method may instead be a
# (method, str) pair (the type of `private_endpoints`). The string is
# presumably a required scope — confirm against EnforceAuthMiddleware.
EndpointMethods: TypeAlias = dict[str, Sequence[Union[METHODS, tuple[METHODS, str]]]]

# Regex requiring a value to start with "/"; used to validate path settings.
_PREFIX_PATTERN = r"^/.*$"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ClassInput(BaseModel):
    """Input model for dynamically loading a class or function.

    ``cls`` is an import reference of the form ``"package.module:Name"``.
    Calling the model imports ``Name`` from ``package.module`` and calls it with
    ``args`` and ``kwargs``. Argument values keep the types they were configured
    with (e.g. JSON numbers stay numbers), so targets with non-string parameters
    can be configured directly.
    """

    cls: str
    args: Sequence[Any] = Field(default_factory=list)
    kwargs: dict[str, Any] = Field(default_factory=dict)

    def __call__(self):
        """Dynamically load a class and instantiate it with args & kwargs.

        Returns:
            Whatever the referenced callable returns (typically an instance).

        Raises:
            ValueError: if ``cls`` is not of the form ``"module:attribute"``.
            ImportError: if the module cannot be imported.
            AttributeError: if the module has no such attribute.
        """
        # Split on the last ':' so module paths containing ':' still resolve to
        # their final attribute, matching the previous rsplit(":", 1) behaviour.
        module_path, sep, class_name = self.cls.rpartition(":")
        if not sep or not module_path or not class_name:
            # Explicit check instead of `assert`, which is stripped under -O.
            raise ValueError(
                f"Invalid class reference {self.cls!r}; expected 'module.path:Name'"
            )
        module = importlib.import_module(module_path)
        target = getattr(module, class_name)
        return target(*self.args, **self.kwargs)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Settings(BaseSettings):
    """Configuration settings for the STAC Auth Proxy.

    Loaded through pydantic-settings; ``env_nested_delimiter="_"`` allows nested
    models such as ``items_filter`` to be populated from flattened names.
    """

    # External URLs
    upstream_url: HttpUrl
    oidc_discovery_url: HttpUrl
    # URL the proxy itself uses to reach the OIDC discovery document. When it
    # is not provided, the validator below falls back to `oidc_discovery_url`.
    oidc_discovery_internal_url: HttpUrl

    root_path: str = ""
    override_host: bool = True
    healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz")
    wait_for_upstream: bool = True
    check_conformance: bool = True
    enable_compression: bool = True

    # OpenAPI / Swagger UI
    openapi_spec_endpoint: Optional[str] = Field(pattern=_PREFIX_PATTERN, default=None)
    openapi_auth_scheme_name: str = "oidcAuth"
    openapi_auth_scheme_override: Optional[dict] = None
    swagger_ui_endpoint: Optional[str] = None
    swagger_ui_init_oauth: dict = Field(default_factory=dict)

    # Auth
    enable_authentication_extension: bool = True
    default_public: bool = False
    # Path regex -> HTTP methods treated as public.
    public_endpoints: EndpointMethodsNoScope = {
        r"^/api.html$": ["GET"],
        r"^/api$": ["GET"],
        r"^/docs/oauth2-redirect": ["GET"],
        r"^/healthz": ["GET"],
    }
    # Path regex -> HTTP methods treated as private.
    private_endpoints: EndpointMethods = {
        # https://github.com/stac-api-extensions/collection-transaction/blob/v1.0.0-beta.1/README.md#methods
        r"^/collections$": ["POST"],
        r"^/collections/([^/]+)$": ["PUT", "PATCH", "DELETE"],
        # https://github.com/stac-api-extensions/transaction/blob/v1.0.0-rc.3/README.md#methods
        r"^/collections/([^/]+)/items$": ["POST"],
        r"^/collections/([^/]+)/items/([^/]+)$": ["PUT", "PATCH", "DELETE"],
        # https://stac-utils.github.io/stac-fastapi/api/stac_fastapi/extensions/third_party/bulk_transactions/#bulktransactionextension
        r"^/collections/([^/]+)/bulk_items$": ["POST"],
    }

    # Filters
    items_filter: Optional[ClassInput] = None
    items_filter_path: str = r"^(/collections/([^/]+)/items(/[^/]+)?$|/search$)"
    collections_filter: Optional[ClassInput] = None
    collections_filter_path: str = r"^/collections(/[^/]+)?$"

    model_config = SettingsConfigDict(
        env_nested_delimiter="_",
    )

    @model_validator(mode="before")
    @classmethod
    def default_oidc_discovery_internal_url(cls, data: Any) -> Any:
        """Set the internal OIDC discovery URL to the public URL if not set.

        Runs before field validation, so ``data`` is the raw input. Only dict
        input is handled; any other input is passed through unchanged for
        pydantic to validate (previously it raised ``AttributeError``). The
        incoming dict is copied rather than mutated, so a caller's dict passed
        to ``model_validate`` is left untouched.
        """
        if isinstance(data, dict) and not data.get("oidc_discovery_internal_url"):
            data = {
                **data,
                "oidc_discovery_internal_url": data.get("oidc_discovery_url"),
            }
        return data
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""Integration with Open Policy Agent (OPA) to generate CQL2 filters for requests to a STAC API."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import httpx
|
|
7
|
+
|
|
8
|
+
from ..utils.cache import MemoryCache, get_value_by_path
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class Opa:
    """Ask Open Policy Agent (OPA) for a CQL2 filter matching the request context.

    Decisions are POSTed to ``/v1/data/{decision}`` on ``host`` and memoized in
    a ``MemoryCache`` built with ``ttl=cache_ttl``. The cache is keyed by the
    value ``get_value_by_path`` resolves for ``cache_key`` within the context.
    """

    host: str
    decision: str

    client: httpx.AsyncClient = field(init=False)
    cache: MemoryCache = field(init=False)
    # Context path whose value keys the cache (by default, presumably the
    # request's Authorization header).
    cache_key: str = "req.headers.authorization"
    cache_ttl: float = 5.0

    def __post_init__(self):
        """Create the decision cache and the HTTP client pointed at OPA."""
        self.cache = MemoryCache(ttl=self.cache_ttl)
        self.client = httpx.AsyncClient(base_url=self.host)

    async def __call__(self, context: dict[str, Any]) -> str:
        """Return the CQL2 filter for ``context``, consulting the cache first."""
        lookup_key = get_value_by_path(context, self.cache_key)
        try:
            return self.cache[lookup_key]
        except KeyError:
            pass  # cache miss: ask OPA below

        filter_expr = await self._fetch(context)
        self.cache[lookup_key] = filter_expr
        return filter_expr

    async def _fetch(self, context: dict[str, Any]) -> str:
        """Evaluate the OPA decision for ``context`` and return its ``result``.

        Raises ``httpx.HTTPStatusError`` on a non-success status and
        ``KeyError`` if the OPA response has no ``result`` field.
        """
        opa_response = await self.client.post(
            f"/v1/data/{self.decision}",
            json={"input": context},
        )
        opa_response.raise_for_status()
        return opa_response.json()["result"]
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""Generate CQL2 filter expressions via Jinja2 templating."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from jinja2 import BaseLoader, Environment
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class Template:
    """Generate CQL2 filter expressions via Jinja2 templating.

    ``template_str`` is compiled once at construction. Each call renders it with
    the supplied context as keyword variables and strips surrounding whitespace.
    """

    template_str: str
    # Holds the *compiled* template returned by `Environment.from_string`
    # (a jinja2 Template), not an Environment; the field name is kept for
    # backward compatibility.
    env: Any = field(init=False)

    def __post_init__(self):
        """Compile ``template_str`` into a Jinja2 template."""
        # Pass a loader *instance*: Jinja2 calls methods on the loader when a
        # template uses `{% include %}`/`{% extends %}`. A BaseLoader instance
        # then raises a clear TemplateNotFound instead of a confusing
        # AttributeError from the bare class.
        self.env = Environment(loader=BaseLoader()).from_string(self.template_str)

    async def __call__(self, context: dict[str, Any]) -> str:
        """Render a CQL2 filter expression with the request and auth token."""
        return self.env.render(**context).strip()
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""Health check endpoints."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
|
|
5
|
+
from fastapi import APIRouter
|
|
6
|
+
from httpx import AsyncClient
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class HealthzHandler:
    """Provide health-check routes for the proxy and its upstream STAC API.

    Routes, relative to the prefix the router is mounted under:
      - ``""``          -> the proxy's own status
      - ``"/upstream"`` -> the outcome of a GET against ``upstream_url``
    """

    upstream_url: str
    router: APIRouter = field(init=False)

    def __post_init__(self):
        """Create the router and register both GET health routes on it."""
        health_router = APIRouter()
        for route_path, endpoint in (
            ("", self.healthz),
            ("/upstream", self.healthz_upstream),
        ):
            health_router.add_api_route(route_path, endpoint, methods=["GET"])
        self.router = health_router

    async def healthz(self):
        """Report that the proxy process is serving requests."""
        return {"status": "ok"}

    async def healthz_upstream(self):
        """Probe the upstream API and echo its HTTP status code.

        A non-success upstream status raises ``httpx.HTTPStatusError`` via
        ``raise_for_status``.
        """
        async with AsyncClient() as http:
            upstream_reply = await http.get(self.upstream_url)
            upstream_reply.raise_for_status()
            return {"status": "ok", "code": upstream_reply.status_code}
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
"""Tooling to manage the reverse proxying of requests to an upstream STAC API."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import time
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
|
|
7
|
+
import httpx
|
|
8
|
+
from fastapi import Request
|
|
9
|
+
from starlette.datastructures import MutableHeaders
|
|
10
|
+
from starlette.responses import Response
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)


@dataclass
class ReverseProxyHandler:
    """Reverse proxy functionality.

    Forwards an incoming request (method, path, query, headers, streamed body)
    to ``upstream`` and returns the upstream response with its body fully read
    and decoded, so downstream middleware can inspect it.
    """

    upstream: str
    client: httpx.AsyncClient | None = None
    timeout: httpx.Timeout = field(default_factory=lambda: httpx.Timeout(timeout=15.0))

    proxy_name: str = "stac-auth-proxy"
    override_host: bool = True
    legacy_forwarded_headers: bool = False

    def __post_init__(self):
        """Initialize the HTTP client (HTTP/2 enabled) unless one was supplied."""
        self.client = self.client or httpx.AsyncClient(
            base_url=self.upstream,
            timeout=self.timeout,
            http2=True,
        )

    def _prepare_headers(self, request: Request) -> MutableHeaders:
        """Prepare headers for the proxied request.

        Copies the incoming headers and adds the following, without overwriting
        values the client already sent:
          - a ``Via`` header;
          - a ``Forwarded`` header (client address, host, scheme, base path);
          - when ``legacy_forwarded_headers`` is set, the ``X-Forwarded-*``
            equivalents.

        If ``override_host`` is set, ``Host`` is replaced with the upstream's host.
        """
        headers = MutableHeaders(request.headers)
        headers.setdefault("Via", f"1.1 {self.proxy_name}")

        proxy_client = request.client.host if request.client else "unknown"
        proxy_proto = request.url.scheme
        proxy_host = request.url.netloc
        proxy_path = request.base_url.path
        headers.setdefault(
            "Forwarded",
            f"for={proxy_client};host={proxy_host};proto={proxy_proto};path={proxy_path}",
        )
        if self.legacy_forwarded_headers:
            headers.setdefault("X-Forwarded-For", proxy_client)
            headers.setdefault("X-Forwarded-Host", proxy_host)
            headers.setdefault("X-Forwarded-Path", proxy_path)
            headers.setdefault("X-Forwarded-Proto", proxy_proto)

        # Set host to the upstream host
        if self.override_host:
            headers["Host"] = self.client.base_url.netloc.decode("utf-8")

        return headers

    async def proxy_request(self, request: Request) -> Response:
        """Proxy a request to the upstream STAC API.

        Adds an ``X-Upstream-Time`` header (seconds spent waiting for the
        upstream's response headers) to the returned response.
        """
        headers = self._prepare_headers(request)

        # https://github.com/fastapi/fastapi/discussions/7382#discussioncomment-5136466
        rp_req = self.client.build_request(
            request.method,
            url=httpx.URL(
                path=request.url.path,
                query=request.url.query.encode("utf-8"),
            ),
            headers=headers,
            content=request.stream(),
        )

        # NOTE: HTTPX adds headers, so we need to trim them before sending request.
        # Iterate over a snapshot of the names because we delete while looping.
        for h in list(rp_req.headers):
            if h not in headers:
                del rp_req.headers[h]

        logger.debug("Proxying request to %s", rp_req.url)

        start_time = time.perf_counter()
        rp_resp = await self.client.send(rp_req, stream=True)
        proxy_time = time.perf_counter() - start_time

        logger.debug(
            "Received response status %r from %s in %.3fs",
            rp_resp.status_code,
            rp_req.url,
            proxy_time,
        )
        rp_resp.headers["X-Upstream-Time"] = f"{proxy_time:.3f}"

        # We read the content here to make use of HTTPX's decompression, ensuring we have
        # non-compressed content for the middleware to work with. Always release the
        # streamed response, even if reading it fails part-way.
        try:
            content = await rp_resp.aread()
        finally:
            await rp_resp.aclose()

        # The body is now fully buffered and decoded, so the upstream's wire-level
        # framing headers no longer describe it. A forwarded Content-Length would
        # still carry the *compressed* size. Drop these headers and let Starlette
        # compute a correct Content-Length for `content`.
        for stale_header in ("Content-Encoding", "Content-Length", "Transfer-Encoding"):
            rp_resp.headers.pop(stale_header, None)

        # NOTE(review): dict() joins repeated headers (e.g. multiple Set-Cookie)
        # into one comma-separated value.
        return Response(
            content=content,
            status_code=rp_resp.status_code,
            headers=dict(rp_resp.headers),
        )
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""
|
|
2
|
+
In order to allow customization of the Swagger UI's OAuth2 configuration, we support
|
|
3
|
+
overriding the default handler. This is useful for adding custom parameters such as
|
|
4
|
+
`usePkceWithAuthorizationCodeGrant` or `clientId`.
|
|
5
|
+
|
|
6
|
+
See:
|
|
7
|
+
- https://swagger.io/docs/open-source-tools/swagger-ui/usage/oauth2/
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from typing import Optional
|
|
12
|
+
|
|
13
|
+
from fastapi.openapi.docs import get_swagger_ui_html
|
|
14
|
+
from starlette.requests import Request
|
|
15
|
+
from starlette.responses import HTMLResponse
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class SwaggerUI:
    """Serve a Swagger UI page with configurable OAuth2 initialisation."""

    openapi_url: str
    title: Optional[str] = "STAC API"
    init_oauth: dict = field(default_factory=dict)
    parameters: dict = field(default_factory=dict)
    oauth2_redirect_url: str = "/docs/oauth2-redirect"

    async def route(self, req: Request) -> HTMLResponse:
        """Render the Swagger UI HTML for ``req``.

        The request's ASGI ``root_path`` (trailing ``/`` removed) is prefixed to
        ``openapi_url`` and, when it is non-empty, to ``oauth2_redirect_url``,
        so the page works when the app is mounted under a sub-path.
        """
        mount_prefix = req.scope.get("root_path", "").rstrip("/")
        redirect_url = (
            mount_prefix + self.oauth2_redirect_url
            if self.oauth2_redirect_url
            else self.oauth2_redirect_url
        )
        return get_swagger_ui_html(
            openapi_url=mount_prefix + self.openapi_url,
            title=f"{self.title} - Swagger UI",
            oauth2_redirect_url=redirect_url,
            init_oauth=self.init_oauth,
            swagger_ui_parameters=self.parameters,
        )
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""Middleware to add a header with the process time to the response."""
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
|
|
5
|
+
from fastapi import Request, Response
|
|
6
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class AddProcessTimeHeaderMiddleware(BaseHTTPMiddleware):
    """Expose the wrapped app's handling time in an ``X-Process-Time`` header."""

    async def dispatch(self, request: Request, call_next) -> Response:
        """Time the downstream call and record the elapsed seconds (3 decimals)."""
        began_at = time.perf_counter()
        downstream_response = await call_next(request)
        downstream_response.headers["X-Process-Time"] = "{:.3f}".format(
            time.perf_counter() - began_at
        )
        return downstream_response
|