cadwyn 5.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cadwyn/applications.py ADDED
@@ -0,0 +1,484 @@
1
+ import dataclasses
2
+ import warnings
3
+ from collections.abc import Awaitable, Callable, Coroutine, Sequence
4
+ from datetime import date
5
+ from logging import getLogger
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
8
+ from urllib.parse import quote
9
+
10
+ import fastapi
11
+ from fastapi import APIRouter, FastAPI, HTTPException, routing
12
+ from fastapi.datastructures import Default
13
+ from fastapi.openapi.docs import (
14
+ get_redoc_html,
15
+ get_swagger_ui_html,
16
+ get_swagger_ui_oauth2_redirect_html,
17
+ )
18
+ from fastapi.openapi.utils import get_openapi
19
+ from fastapi.params import Depends
20
+ from fastapi.responses import HTMLResponse
21
+ from fastapi.templating import Jinja2Templates
22
+ from fastapi.utils import generate_unique_id
23
+ from starlette.middleware import Middleware
24
+ from starlette.requests import Request
25
+ from starlette.responses import JSONResponse, Response
26
+ from starlette.routing import BaseRoute, Route
27
+ from starlette.types import Lifespan
28
+ from typing_extensions import Self, assert_never, deprecated
29
+
30
+ from cadwyn._utils import DATACLASS_SLOTS, same_definition_as_in
31
+ from cadwyn.changelogs import CadwynChangelogResource, _generate_changelog
32
+ from cadwyn.exceptions import CadwynStructureError
33
+ from cadwyn.middleware import (
34
+ APIVersionFormat,
35
+ APIVersionLocation,
36
+ HeaderVersionManager,
37
+ URLVersionManager,
38
+ VersionPickingMiddleware,
39
+ _generate_api_version_dependency,
40
+ )
41
+ from cadwyn.route_generation import generate_versioned_routers
42
+ from cadwyn.routing import _RootCadwynAPIRouter
43
+ from cadwyn.structure import VersionBundle
44
+
45
+ if TYPE_CHECKING:
46
+ from cadwyn.structure.common import VersionType
47
+
48
# Absolute, symlink-resolved path of *this module file* (despite the name it is not a
# directory); use ``CURR_DIR.parent`` to reach the package directory.
CURR_DIR = Path(__file__).resolve()
# Module-scoped logger, named after this module.
logger = getLogger(name=__name__)
50
+
51
+
52
@dataclasses.dataclass(**DATACLASS_SLOTS)
class FakeDependencyOverridesProvider:
    """Lightweight holder that only exposes a ``dependency_overrides`` mapping.

    ``Cadwyn`` passes an instance of this class (instead of the application itself) as the
    ``dependency_overrides_provider`` of its latest-version router, and its own
    ``dependency_overrides`` property reads and writes this object's mapping.
    """

    # Original dependency callable -> callable that should be used in its place.
    dependency_overrides: dict[Callable[..., Any], Callable[..., Any]]
55
+
56
+
57
class Cadwyn(FastAPI):
    """FastAPI application that serves several API versions side by side.

    Routes registered through :meth:`generate_and_include_versioned_routers` are collected
    on ``_latest_version_router``. On the first ASGI call they are turned into one router
    per version by ``generate_versioned_routers``. The version of each request is picked
    by ``versioning_middleware_class`` (from a custom header or from the URL path). This
    class serves per-version OpenAPI documents, Swagger UI and ReDoc pages in place of
    FastAPI's single-schema endpoints.
    """

    # Jinja templates from the package's ``static`` directory (``docs.html`` is rendered
    # by ``_render_docs_dashboard``).
    _templates = Jinja2Templates(directory=CURR_DIR.parent / "static")

    def __init__(
        self,
        *,
        versions: VersionBundle,
        api_version_header_name: Annotated[
            Union[str, None],
            deprecated(
                "api_version_header_name is deprecated and will be removed in the future. "
                "Use api_version_parameter_name instead."
            ),
        ] = None,
        api_version_location: APIVersionLocation = "custom_header",
        api_version_format: APIVersionFormat = "date",
        api_version_parameter_name: str = "x-api-version",
        api_version_default_value: Union[str, None, Callable[[Request], Awaitable[str]]] = None,
        api_version_title: Optional[str] = None,
        api_version_description: Optional[str] = None,
        versioning_middleware_class: type[VersionPickingMiddleware] = VersionPickingMiddleware,
        changelog_url: Union[str, None] = "/changelog",
        include_changelog_url_in_schema: bool = True,
        debug: bool = False,
        title: str = "FastAPI",
        summary: Union[str, None] = None,
        description: str = "",
        version: str = "0.1.0",
        openapi_url: Union[str, None] = "/openapi.json",
        openapi_tags: Union[list[dict[str, Any]], None] = None,
        servers: Union[list[dict[str, Union[str, Any]]], None] = None,
        dependencies: Union[Sequence[Depends], None] = None,
        default_response_class: type[Response] = JSONResponse,
        redirect_slashes: bool = True,
        routes: Union[list[BaseRoute], None] = None,
        docs_url: Union[str, None] = "/docs",
        redoc_url: Union[str, None] = "/redoc",
        swagger_ui_oauth2_redirect_url: Union[str, None] = "/docs/oauth2-redirect",
        swagger_ui_init_oauth: Union[dict[str, Any], None] = None,
        middleware: Union[Sequence[Middleware], None] = None,
        exception_handlers: (
            Union[
                dict[
                    Union[int, type[Exception]],
                    Callable[[Request, Any], Coroutine[Any, Any, Response]],
                ],
                None,
            ]
        ) = None,
        on_startup: Union[Sequence[Callable[[], Any]], None] = None,
        on_shutdown: Union[Sequence[Callable[[], Any]], None] = None,
        lifespan: Union[Lifespan[Self], None] = None,
        terms_of_service: Union[str, None] = None,
        contact: Union[dict[str, Union[str, Any]], None] = None,
        license_info: Union[dict[str, Union[str, Any]], None] = None,
        openapi_prefix: str = "",
        root_path: str = "",
        root_path_in_servers: bool = True,
        responses: Union[dict[Union[int, str], dict[str, Any]], None] = None,
        callbacks: Union[list[BaseRoute], None] = None,
        webhooks: Union[APIRouter, None] = None,
        deprecated: Union[bool, None] = None,
        include_in_schema: bool = True,
        swagger_ui_parameters: Union[dict[str, Any], None] = None,
        generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(  # noqa: B008
            generate_unique_id
        ),
        separate_input_output_schemas: bool = True,
        **extra: Any,
    ) -> None:
        """Create a versioned application.

        Args:
            versions: The bundle of API versions this app serves.
            api_version_header_name: Deprecated alias of ``api_version_parameter_name``.
                When given, it emits a ``DeprecationWarning`` and overrides that parameter.
            api_version_location: ``"custom_header"`` reads the version from a request
                header; ``"path"`` reads it from the URL path.
            api_version_format: ``"date"`` or ``"string"``. It decides the validation type
                of the version parameter and whether version ordering is checked.
            api_version_parameter_name: Header/path parameter that carries the version.
            api_version_default_value: Version used when the request carries none. It can
                be a string or an async callable taking the request. Not allowed with
                ``api_version_location="path"``.
            api_version_title: OpenAPI title of the version parameter.
            api_version_description: OpenAPI description of the version parameter.
            versioning_middleware_class: Middleware that picks the version per request.
            changelog_url: Path of the changelog endpoint, or ``None`` to disable it.
            include_changelog_url_in_schema: Whether the changelog endpoint is shown in
                the OpenAPI schema.
            All other arguments are forwarded to :class:`fastapi.FastAPI`.

        Raises:
            CadwynStructureError: If a default version is combined with path versioning,
                or if ``api_version_format == "date"`` and the versions are not sorted
                in descending order.
        """
        self.versions = versions
        # These must exist before FastAPI.__init__ runs. If FastAPI assigns
        # ``self.dependency_overrides`` during init, the assignment goes through the
        # property setter below, which writes into ``_dependency_overrides_provider``.
        self._dependency_overrides_provider = FakeDependencyOverridesProvider({})
        self._cadwyn_initialized = False

        if api_version_header_name is not None:
            warnings.warn(
                "api_version_header_name is deprecated and will be removed in the future. "
                "Use api_version_parameter_name instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            api_version_parameter_name = api_version_header_name
        if api_version_default_value is not None and api_version_location == "path":
            raise CadwynStructureError(
                "You tried to pass an api_version_default_value while putting the API version in Path. "
                "This is not currently supported by Cadwyn. "
                "Please, open an issue on our github if you'd like to have it."
            )

        # openapi/docs/redoc URLs are passed as None so FastAPI does not register its own
        # single-version endpoints. Cadwyn registers versioned ones in _add_utility_endpoints.
        super().__init__(
            debug=debug,
            title=title,
            summary=summary,
            description=description,
            version=version,
            openapi_tags=openapi_tags,
            servers=servers,
            dependencies=dependencies,
            default_response_class=default_response_class,
            redirect_slashes=redirect_slashes,
            openapi_url=None,
            docs_url=None,
            redoc_url=None,
            swagger_ui_oauth2_redirect_url=swagger_ui_oauth2_redirect_url,
            swagger_ui_init_oauth=swagger_ui_init_oauth,
            middleware=middleware,
            exception_handlers=exception_handlers,
            on_startup=on_startup,
            on_shutdown=on_shutdown,
            lifespan=lifespan,
            terms_of_service=terms_of_service,
            contact=contact,
            license_info=license_info,
            openapi_prefix=openapi_prefix,
            root_path=root_path,
            root_path_in_servers=root_path_in_servers,
            responses=responses,
            callbacks=callbacks,
            webhooks=webhooks,
            deprecated=deprecated,
            include_in_schema=include_in_schema,
            swagger_ui_parameters=swagger_ui_parameters,
            generate_unique_id_function=generate_unique_id_function,
            separate_input_output_schemas=separate_input_output_schemas,
            **extra,
        )

        self._versioned_webhook_routers: dict[VersionType, APIRouter] = {}
        # Collects the user's routes; per-version copies are generated from it lazily.
        self._latest_version_router = APIRouter(dependency_overrides_provider=self._dependency_overrides_provider)

        self.changelog_url = changelog_url
        self.include_changelog_url_in_schema = include_changelog_url_in_schema

        # Restore the URLs hidden from FastAPI above; Cadwyn's own endpoints use them.
        self.docs_url = docs_url
        self.redoc_url = redoc_url
        self.openapi_url = openapi_url

        # Shared keyword arguments for every APIRouter this app creates.
        self._kwargs_to_router: dict[str, Any] = {
            "routes": routes,
            "redirect_slashes": redirect_slashes,
            "dependency_overrides_provider": self,
            "on_startup": on_startup,
            "on_shutdown": on_shutdown,
            "lifespan": lifespan,
            "default_response_class": default_response_class,
            "dependencies": dependencies,
            "callbacks": callbacks,
            "deprecated": deprecated,
            "include_in_schema": include_in_schema,
            "responses": responses,
            "generate_unique_id_function": generate_unique_id_function,
        }
        self.api_version_format = api_version_format
        self.api_version_parameter_name = api_version_parameter_name
        # Python identifiers cannot contain "-", e.g. "x-api-version" -> "x_api_version".
        self.api_version_pythonic_parameter_name = api_version_parameter_name.replace("-", "_")
        self.api_version_title = api_version_title
        self.api_version_description = api_version_description
        if api_version_location == "custom_header":
            self._api_version_manager = HeaderVersionManager(api_version_parameter_name=api_version_parameter_name)
            self._api_version_fastapi_depends_class = fastapi.Header
        elif api_version_location == "path":
            self._api_version_manager = URLVersionManager(possible_version_values=self.versions._version_values_set)
            self._api_version_fastapi_depends_class = fastapi.Path
        else:
            assert_never(api_version_location)
        # NOTE(review): evaluated only for its side effect. It raises StopIteration when
        # the bundle has no versions, and is kept to preserve that behavior.
        # TODO: Add a test validating the error message when there are no versions
        next(iter(self.versions._version_values_set))
        if api_version_format == "date":
            self.api_version_validation_data_type = date
        elif api_version_format == "string":
            self.api_version_validation_data_type = str
        else:
            # Exhaustiveness check on the value actually being matched.
            assert_never(api_version_format)
        self.router: _RootCadwynAPIRouter = _RootCadwynAPIRouter(  # pyright: ignore[reportIncompatibleVariableOverride]
            **self._kwargs_to_router,
            api_version_parameter_name=api_version_parameter_name,
            api_version_var=self.versions.api_version_var,
            api_version_format=api_version_format,
        )
        unversioned_router = APIRouter(**self._kwargs_to_router)
        self._add_utility_endpoints(unversioned_router)
        self._add_default_versioned_routers()
        self.include_router(unversioned_router)
        self.add_middleware(
            versioning_middleware_class,
            api_version_parameter_name=api_version_parameter_name,
            api_version_manager=self._api_version_manager,
            api_version_var=self.versions.api_version_var,
            api_version_default_value=api_version_default_value,
        )
        if self.api_version_format == "date" and (
            sorted(self.versions.versions, key=lambda v: v.value, reverse=True) != list(self.versions.versions)
        ):
            raise CadwynStructureError(
                "Versions are not sorted correctly. Please sort them in descending order.",
            )

    @same_definition_as_in(FastAPI.__call__)
    async def __call__(self, scope: Any, receive: Any, send: Any) -> None:
        """ASGI entry point that builds the versioned routers on the first call.

        Generation is deferred to the first call, not done in ``__init__``, so that routes
        added after construction via :meth:`generate_and_include_versioned_routers` are
        included. The instance attribute ``__call__`` is then bound to
        ``FastAPI.__call__``. ``app(...)`` still enters this method, because Python looks
        ``__call__`` up on the type, but the explicit ``self.__call__(...)`` below goes
        straight to FastAPI.
        """
        if not self._cadwyn_initialized:
            self._cadwyn_initialize()
            self.__call__ = super().__call__
        await self.__call__(scope, receive, send)

    def _cadwyn_initialize(self) -> None:
        """Generate per-version endpoint and webhook routers and register them."""
        generated_routers = generate_versioned_routers(
            self._latest_version_router,
            webhooks=self.webhooks,
            versions=self.versions,
        )
        for version, router in generated_routers.endpoints.items():
            self._add_versioned_routers(router, version=version)

        for version, router in generated_routers.webhooks.items():
            self._versioned_webhook_routers[version] = router
        # Set last so that a failure during generation is retried on the next call.
        self._cadwyn_initialized = True

    def _add_default_versioned_routers(self) -> None:
        """Register an empty router for every known version."""
        for version in self.versions:
            self.router.versioned_routers[version.value] = APIRouter(**self._kwargs_to_router)

    @property
    def dependency_overrides(self) -> dict[Callable[..., Any], Callable[..., Any]]:
        """Dependency overrides, stored on ``_dependency_overrides_provider``."""
        # TODO: Remove this approach as it is no longer necessary
        # This is only necessary because we cannot send self to versioned router generator
        # because it takes a deepcopy of the router and self.versions.head_schemas_package was a module
        # which couldn't be copied.
        return self._dependency_overrides_provider.dependency_overrides

    @dependency_overrides.setter
    def dependency_overrides(  # pyright: ignore[reportIncompatibleVariableOverride]
        self,
        value: dict[Callable[..., Any], Callable[..., Any]],
    ) -> None:
        self._dependency_overrides_provider.dependency_overrides = value

    def generate_changelog(self) -> CadwynChangelogResource:
        """Build the changelog resource for this app's versions and routes."""
        return _generate_changelog(self.versions, self.router)

    def _add_utility_endpoints(self, unversioned_router: APIRouter):
        """Register the changelog, OpenAPI, Swagger UI and ReDoc endpoints.

        The docs dashboards build ``{openapi_url}?version=...`` links, so they are only
        registered when an OpenAPI URL is configured.
        """
        if self.changelog_url is not None:
            unversioned_router.add_api_route(
                path=self.changelog_url,
                endpoint=self.generate_changelog,
                response_model=CadwynChangelogResource,
                methods=["GET"],
                include_in_schema=self.include_changelog_url_in_schema,
            )

        if self.openapi_url is not None:
            unversioned_router.add_route(
                path=self.openapi_url,
                endpoint=self.openapi_jsons,
                include_in_schema=False,
            )
            if self.docs_url is not None:
                unversioned_router.add_route(
                    path=self.docs_url,
                    endpoint=self.swagger_dashboard,
                    include_in_schema=False,
                )
                if self.swagger_ui_oauth2_redirect_url:

                    async def swagger_ui_redirect(req: Request) -> HTMLResponse:
                        return (
                            get_swagger_ui_oauth2_redirect_html()  # pragma: no cover # unimportant right now but # TODO
                        )

                    self.add_route(
                        self.swagger_ui_oauth2_redirect_url,
                        swagger_ui_redirect,
                        include_in_schema=False,
                    )
            if self.redoc_url is not None:
                unversioned_router.add_route(
                    path=self.redoc_url,
                    endpoint=self.redoc_dashboard,
                    include_in_schema=False,
                )

    def generate_and_include_versioned_routers(self, *routers: APIRouter) -> None:
        """Add routers to the latest version; older versions are derived from it lazily."""
        for router in routers:
            self._latest_version_router.include_router(router)

    async def openapi_jsons(self, req: Request) -> JSONResponse:
        """Serve the OpenAPI document for one version.

        The version comes from the ``version`` query parameter, falling back to the
        version header. ``"unversioned"`` selects the unversioned routes when any of them
        are public.

        Raises:
            HTTPException: 404 when the requested version is unknown.
        """
        version = req.query_params.get("version") or req.headers.get(self.router.api_version_parameter_name)

        if version in self.router.versioned_routers:
            routes = self.router.versioned_routers[version].routes
            formatted_version = version
        elif version == "unversioned" and self._there_are_public_unversioned_routes():
            routes = self.router.unversioned_routes
            formatted_version = "unversioned"
        else:
            raise HTTPException(
                status_code=404,
                detail=f"OpenApi file of with version `{version}` not found",
            )

        # Add root path to servers when mounted as sub-app or proxy is used
        urls = (server_data.get("url") for server_data in self.servers)
        server_urls = {url for url in urls if url}
        root_path = self._extract_root_path(req)
        if root_path and root_path not in server_urls and self.root_path_in_servers:
            # NOTE(review): this mutates the shared self.servers list, once per distinct root path.
            self.servers.insert(0, {"url": root_path})

        webhook_routes = None
        if version in self._versioned_webhook_routers:
            webhook_routes = self._versioned_webhook_routers[version].routes

        return JSONResponse(
            get_openapi(
                title=self.title,
                version=formatted_version,
                openapi_version=self.openapi_version,
                description=self.description,
                summary=self.summary,
                terms_of_service=self.terms_of_service,
                contact=self.contact,
                license_info=self.license_info,
                routes=routes,
                webhooks=webhook_routes,
                tags=self.openapi_tags,
                servers=self.servers,
            )
        )

    def _there_are_public_unversioned_routes(self):
        """Return True if any unversioned route is included in the OpenAPI schema."""
        return any(isinstance(route, Route) and route.include_in_schema for route in self.router.unversioned_routes)

    async def swagger_dashboard(self, req: Request) -> Response:
        """Serve Swagger UI for ``?version=...``, or the version picker when absent."""
        version = req.query_params.get("version")

        if version:
            root_path = self._extract_root_path(req)
            openapi_url = root_path + f"{self.openapi_url}?version={quote(version, safe='')}"
            oauth2_redirect_url = self.swagger_ui_oauth2_redirect_url
            if oauth2_redirect_url:
                oauth2_redirect_url = root_path + oauth2_redirect_url
            return get_swagger_ui_html(
                openapi_url=openapi_url,
                title=f"{self.title} - Swagger UI",
                oauth2_redirect_url=oauth2_redirect_url,
                init_oauth=self.swagger_ui_init_oauth,
                swagger_ui_parameters=self.swagger_ui_parameters,
            )
        return self._render_docs_dashboard(req, cast("str", self.docs_url))

    async def redoc_dashboard(self, req: Request) -> Response:
        """Serve ReDoc for ``?version=...``, or the version picker when absent."""
        version = req.query_params.get("version")

        if version:
            root_path = self._extract_root_path(req)
            openapi_url = root_path + f"{self.openapi_url}?version={quote(version, safe='')}"
            return get_redoc_html(openapi_url=openapi_url, title=f"{self.title} - ReDoc")

        return self._render_docs_dashboard(req, docs_url=cast("str", self.redoc_url))

    def _extract_root_path(self, req: Request):
        """Return the ASGI ``root_path`` of the request without a trailing slash."""
        return req.scope.get("root_path", "").rstrip("/")

    def _render_docs_dashboard(self, req: Request, docs_url: str):
        """Render ``docs.html`` with one link per version (plus "unversioned" if public)."""
        base_host = str(req.base_url).rstrip("/")
        root_path = self._extract_root_path(req)
        base_url = base_host + root_path
        # Quote the version the same way swagger_dashboard/redoc_dashboard do, so that
        # string versions containing reserved characters still produce valid links.
        table = {
            version: f"{base_url}{docs_url}?version={quote(str(version), safe='')}"
            for version in self.router.versions
        }
        if self._there_are_public_unversioned_routes():
            table |= {"unversioned": f"{base_url}{docs_url}?version=unversioned"}
        return self._templates.TemplateResponse(
            "docs.html",
            {"request": req, "table": table},
        )

    @deprecated("Use generate_and_include_versioned_routers and VersionBundle versions instead")
    def add_header_versioned_routers(
        self,
        first_router: APIRouter,
        *other_routers: APIRouter,
        header_value: str,
    ) -> list[BaseRoute]:
        """Add all routes from routers to be routed using header_value and return the added routes"""
        try:
            date.fromisoformat(header_value)
        except ValueError as e:
            raise ValueError("header_value should be in ISO 8601 format") from e

        return self._add_versioned_routers(first_router, *other_routers, version=header_value)

    def _add_versioned_routers(
        self, first_router: APIRouter, *other_routers: APIRouter, version: str
    ) -> list[BaseRoute]:
        """Include routers into the router of ``version`` and return the routes added.

        Each included router gets a dependency that exposes the API version parameter,
        with ``version`` as its default. The added routes are also appended to the root
        router's ``routes``.
        """
        added_routes: list[BaseRoute] = []
        if version not in self.router.versioned_routers:  # pragma: no branch
            self.router.versioned_routers[version] = APIRouter(**self._kwargs_to_router)

        versioned_router = self.router.versioned_routers[version]
        if self.openapi_url is not None:  # pragma: no branch
            versioned_router.add_route(
                path=self.openapi_url,
                endpoint=self.openapi_jsons,
                include_in_schema=False,
            )
            added_routes.append(versioned_router.routes[-1])

        # Snapshot the length so we take exactly the routes include_router appended.
        # Summing len(router.routes) and slicing with a negative index breaks in two
        # ways: a count of 0 makes routes[-0:] return *every* route, and include_router
        # may not add one route per input route.
        routes_before_include = len(versioned_router.routes)
        for router in (first_router, *other_routers):
            versioned_router.include_router(
                router,
                dependencies=[
                    Depends(
                        _generate_api_version_dependency(
                            api_version_pythonic_parameter_name=self.api_version_pythonic_parameter_name,
                            default_value=version,
                            fastapi_depends_class=self._api_version_fastapi_depends_class,
                            validation_data_type=self.api_version_validation_data_type,
                            title=self.api_version_title,
                            description=self.api_version_description,
                        )
                    )
                ],
            )

        added_routes.extend(versioned_router.routes[routes_before_include:])
        self.router.routes.extend(added_routes)

        return added_routes