fastapi-factory-utilities 0.4.0__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fastapi-factory-utilities might be problematic. Click here for more details.
- fastapi_factory_utilities/core/api/__init__.py +1 -1
- fastapi_factory_utilities/core/api/v1/sys/health.py +1 -1
- fastapi_factory_utilities/core/app/__init__.py +4 -4
- fastapi_factory_utilities/core/app/application.py +22 -26
- fastapi_factory_utilities/core/app/builder.py +8 -32
- fastapi_factory_utilities/core/app/fastapi_builder.py +3 -2
- fastapi_factory_utilities/core/exceptions.py +64 -29
- fastapi_factory_utilities/core/plugins/__init__.py +2 -31
- fastapi_factory_utilities/core/plugins/abstracts.py +40 -0
- fastapi_factory_utilities/core/plugins/aiopika/__init__.py +25 -0
- fastapi_factory_utilities/core/plugins/aiopika/abstract.py +48 -0
- fastapi_factory_utilities/core/plugins/aiopika/configs.py +85 -0
- fastapi_factory_utilities/core/plugins/aiopika/depends.py +20 -0
- fastapi_factory_utilities/core/plugins/aiopika/exceptions.py +29 -0
- fastapi_factory_utilities/core/plugins/aiopika/exchange.py +69 -0
- fastapi_factory_utilities/core/plugins/aiopika/listener/__init__.py +7 -0
- fastapi_factory_utilities/core/plugins/aiopika/listener/abstract.py +72 -0
- fastapi_factory_utilities/core/plugins/aiopika/message.py +86 -0
- fastapi_factory_utilities/core/plugins/aiopika/plugins.py +84 -0
- fastapi_factory_utilities/core/plugins/aiopika/publisher/__init__.py +7 -0
- fastapi_factory_utilities/core/plugins/aiopika/publisher/abstract.py +66 -0
- fastapi_factory_utilities/core/plugins/aiopika/queue.py +88 -0
- fastapi_factory_utilities/core/plugins/odm_plugin/__init__.py +14 -157
- fastapi_factory_utilities/core/plugins/odm_plugin/builder.py +3 -3
- fastapi_factory_utilities/core/plugins/odm_plugin/configs.py +1 -1
- fastapi_factory_utilities/core/plugins/odm_plugin/documents.py +1 -1
- fastapi_factory_utilities/core/plugins/odm_plugin/helpers.py +16 -0
- fastapi_factory_utilities/core/plugins/odm_plugin/plugins.py +155 -0
- fastapi_factory_utilities/core/plugins/odm_plugin/repositories.py +1 -0
- fastapi_factory_utilities/core/plugins/opentelemetry_plugin/__init__.py +8 -121
- fastapi_factory_utilities/core/plugins/opentelemetry_plugin/instruments/__init__.py +85 -0
- fastapi_factory_utilities/core/plugins/opentelemetry_plugin/plugins.py +137 -0
- fastapi_factory_utilities/core/plugins/taskiq_plugins/__init__.py +31 -0
- fastapi_factory_utilities/core/plugins/taskiq_plugins/configs.py +12 -0
- fastapi_factory_utilities/core/plugins/taskiq_plugins/depends.py +51 -0
- fastapi_factory_utilities/core/plugins/taskiq_plugins/exceptions.py +13 -0
- fastapi_factory_utilities/core/plugins/taskiq_plugins/plugin.py +41 -0
- fastapi_factory_utilities/core/plugins/taskiq_plugins/schedulers.py +187 -0
- fastapi_factory_utilities/core/protocols.py +1 -54
- fastapi_factory_utilities/core/security/__init__.py +5 -0
- fastapi_factory_utilities/core/security/abstracts.py +42 -0
- fastapi_factory_utilities/core/security/jwt/__init__.py +45 -0
- fastapi_factory_utilities/core/security/jwt/configs.py +32 -0
- fastapi_factory_utilities/core/security/jwt/decoders.py +130 -0
- fastapi_factory_utilities/core/security/jwt/exceptions.py +23 -0
- fastapi_factory_utilities/core/security/jwt/objects.py +107 -0
- fastapi_factory_utilities/core/security/jwt/services.py +176 -0
- fastapi_factory_utilities/core/security/jwt/stores.py +43 -0
- fastapi_factory_utilities/core/security/jwt/types.py +9 -0
- fastapi_factory_utilities/core/security/jwt/verifiers.py +46 -0
- fastapi_factory_utilities/core/security/kratos.py +43 -43
- fastapi_factory_utilities/core/services/hydra/__init__.py +10 -3
- fastapi_factory_utilities/core/services/hydra/services.py +112 -34
- fastapi_factory_utilities/core/services/status/__init__.py +2 -2
- fastapi_factory_utilities/core/services/status/exceptions.py +1 -1
- fastapi_factory_utilities/core/utils/status.py +2 -1
- fastapi_factory_utilities/core/utils/yaml_reader.py +1 -1
- fastapi_factory_utilities/example/app.py +15 -5
- fastapi_factory_utilities/example/entities/books/__init__.py +1 -1
- fastapi_factory_utilities/example/models/books/__init__.py +1 -1
- {fastapi_factory_utilities-0.4.0.dist-info → fastapi_factory_utilities-0.8.3.dist-info}/METADATA +14 -8
- fastapi_factory_utilities-0.8.3.dist-info/RECORD +111 -0
- {fastapi_factory_utilities-0.4.0.dist-info → fastapi_factory_utilities-0.8.3.dist-info}/WHEEL +1 -1
- fastapi_factory_utilities/core/app/plugin_manager/__init__.py +0 -15
- fastapi_factory_utilities/core/app/plugin_manager/exceptions.py +0 -33
- fastapi_factory_utilities/core/app/plugin_manager/plugin_manager.py +0 -190
- fastapi_factory_utilities/core/plugins/example/__init__.py +0 -31
- fastapi_factory_utilities/core/plugins/httpx_plugin/__init__.py +0 -31
- fastapi_factory_utilities/core/security/jwt.py +0 -158
- fastapi_factory_utilities-0.4.0.dist-info/RECORD +0 -82
- {fastapi_factory_utilities-0.4.0.dist-info → fastapi_factory_utilities-0.8.3.dist-info}/entry_points.txt +0 -0
- {fastapi_factory_utilities-0.4.0.dist-info → fastapi_factory_utilities-0.8.3.dist-info/licenses}/LICENSE +0 -0
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Provides the exceptions for the Taskiq plugin."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from fastapi_factory_utilities.core.exceptions import FastAPIFactoryUtilitiesError
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TaskiqPluginBaseError(FastAPIFactoryUtilitiesError):
    """Common base class for exceptions belonging to the Taskiq plugin.

    Derive plugin-specific errors from this class so they can be caught together
    and still be recognised as ``FastAPIFactoryUtilitiesError`` instances.
    """

    def __init__(self, message: str, **kwargs: Any) -> None:
        """Build the exception.

        Args:
            message (str): Human-readable description of the failure.
            **kwargs (Any): Extra keyword arguments, forwarded untouched to
                ``FastAPIFactoryUtilitiesError.__init__``.
        """
        super().__init__(message, **kwargs)
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""Provides the Taskiq plugin."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
|
|
5
|
+
from fastapi_factory_utilities.core.plugins.abstracts import PluginAbstract
|
|
6
|
+
|
|
7
|
+
from .configs import RedisCredentialsConfig
|
|
8
|
+
from .depends import DEPENDS_SCHEDULER_COMPONENT_KEY
|
|
9
|
+
from .schedulers import SchedulerComponent
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TaskiqPlugin(PluginAbstract):
    """Plugin that ties a ``SchedulerComponent`` to the application lifecycle.

    The component is configured when the plugin is loaded, started on application
    startup and stopped on application shutdown.
    """

    def __init__(
        self, redis_credentials_config: RedisCredentialsConfig, register_hook: Callable[[SchedulerComponent], None]
    ) -> None:
        """Create the plugin.

        Args:
            redis_credentials_config (RedisCredentialsConfig): Source of the Redis URL (its ``url`` attribute)
                handed to the scheduler component.
            register_hook (Callable[[SchedulerComponent], None]): Callback invoked once during ``on_load`` with
                the configured scheduler component (e.g. to register tasks on it).
        """
        super().__init__()
        self._redis_credentials_config: RedisCredentialsConfig = redis_credentials_config
        self._register_hook: Callable[[SchedulerComponent], None] = register_hook
        self._scheduler_component: SchedulerComponent = SchedulerComponent()

    def on_load(self) -> None:
        """Configure the scheduler component, publish it in the app state, then run the register hook."""
        application = self._application
        assert application is not None
        component: SchedulerComponent = self._scheduler_component
        redis_url = self._redis_credentials_config.url
        component.configure(redis_connection_string=redis_url, app=application.get_asgi_app())
        self._add_to_state(key=DEPENDS_SCHEDULER_COMPONENT_KEY, value=component)
        # The hook runs only after ``configure`` so it receives a component whose broker exists;
        # ``SchedulerComponent.register_task`` raises ValueError when the broker is missing.
        self._register_hook(component)

    async def on_startup(self) -> None:
        """Start the scheduler component against the application's ASGI app."""
        application = self._application
        assert application is not None
        await self._scheduler_component.startup(app=application.get_asgi_app())

    async def on_shutdown(self) -> None:
        """Stop the scheduler component."""
        assert self._application is not None
        await self._scheduler_component.shutdown()
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
"""Scheduler module for fastapi_factory_utilities.
|
|
2
|
+
|
|
3
|
+
This module provides components and utilities for scheduling tasks using Taskiq, FastAPI, and Redis.
|
|
4
|
+
It enables registration, configuration, and management of scheduled tasks in FastAPI applications.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import asyncio
|
|
8
|
+
from collections.abc import Coroutine
|
|
9
|
+
from typing import Any, Self, cast
|
|
10
|
+
|
|
11
|
+
import taskiq_fastapi
|
|
12
|
+
from fastapi import FastAPI
|
|
13
|
+
from structlog.stdlib import get_logger
|
|
14
|
+
from taskiq import (
|
|
15
|
+
AsyncBroker,
|
|
16
|
+
AsyncTaskiqDecoratedTask,
|
|
17
|
+
ScheduleSource,
|
|
18
|
+
TaskiqScheduler,
|
|
19
|
+
)
|
|
20
|
+
from taskiq.api import run_receiver_task, run_scheduler_task
|
|
21
|
+
from taskiq.scheduler.created_schedule import CreatedSchedule
|
|
22
|
+
from taskiq.scheduler.scheduled_task import ScheduledTask
|
|
23
|
+
from taskiq_redis import (
|
|
24
|
+
ListRedisScheduleSource,
|
|
25
|
+
RedisAsyncResultBackend,
|
|
26
|
+
RedisStreamBroker,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
# Module-level structlog logger, named after the enclosing package (``__package__``).
_logger = get_logger(__package__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class SchedulerComponent:
    """Scheduler component built on Taskiq and Redis.

    Owns a Redis result backend, a Redis stream broker, a Redis list schedule source and a
    ``TaskiqScheduler``. It runs the Taskiq receiver (worker) loop and scheduler loop as background
    ``asyncio`` tasks inside the FastAPI process.

    Expected lifecycle: ``configure()`` -> ``register_task()`` (any number of times) ->
    ``startup()`` -> ``shutdown()``.
    """

    # Default suffix used to namespace every Redis key, stream queue and consumer group created by
    # ``configure``. Override per instance with ``__init__(name_suffix=...)`` or in a subclass.
    # NOTE(review): this default looks project-specific; callers should probably pass ``name_suffix``.
    NAME_SUFFIX: str = "tiktok_integration"

    # Task that ``startup`` puts on a cron schedule when no schedule for it exists yet in the source.
    HEARTBEAT_TASK_NAME: str = "heartbeat"
    # Cron expression used for that schedule (every minute).
    HEARTBEAT_CRON: str = "* * * * *"

    def __init__(self, name_suffix: str | None = None) -> None:
        """Initialize the scheduler component.

        Args:
            name_suffix (str | None): Suffix used to namespace the Redis result backend prefix, stream
                queue, consumer group and schedule source prefix. Defaults to ``NAME_SUFFIX``.
        """
        self._name_suffix: str = self.NAME_SUFFIX if name_suffix is None else name_suffix
        self._result_backend: RedisAsyncResultBackend[Any] | None = None
        self._stream_broker: RedisStreamBroker | None = None
        self._scheduler: TaskiqScheduler | None = None
        self._scheduler_source: ListRedisScheduleSource | None = None
        self._dyn_task: AsyncTaskiqDecoratedTask[Any, Any] | None = None
        self._schedule_cron: ScheduledTask | None = None
        self._schedulers_tasks: dict[str, AsyncTaskiqDecoratedTask[Any, Any]] = {}
        # Background loops created by ``startup``. They stay ``None`` until then, so ``shutdown``
        # is safe to call even when ``startup`` never ran or failed part-way.
        self._worker_task: asyncio.Task[None] | None = None
        self._scheduler_task: asyncio.Task[None] | None = None

    def register_task(self, task: Coroutine[Any, Any, Any], task_name: str) -> None:
        """Register a task on the stream broker under ``task_name``.

        Args:
            task: The task to register. NOTE(review): annotated as a ``Coroutine`` but handed to
                ``broker.register_task``, which registers a callable; confirm the intended type.
            task_name: The name of the task.

        Raises:
            ValueError: If the task is already registered.
            ValueError: If the stream broker is not initialized (``configure`` not called yet).
        """
        if self._stream_broker is None:
            raise ValueError("Stream broker is not initialized")

        if task_name in self._schedulers_tasks:
            raise ValueError(f"Task {task_name} already registered")

        self._schedulers_tasks[task_name] = self._stream_broker.register_task(task, task_name)  # type: ignore

    def get_task(self, task_name: str) -> AsyncTaskiqDecoratedTask[Any, Any]:
        """Get a previously registered task.

        Args:
            task_name: The name of the task.

        Returns:
            AsyncTaskiqDecoratedTask: The task.

        Raises:
            ValueError: If the task is not registered.
        """
        if task_name not in self._schedulers_tasks:
            raise ValueError(f"Task {task_name} not registered")
        return self._schedulers_tasks[task_name]

    def configure(self, redis_connection_string: str, app: FastAPI) -> Self:
        """Create the Redis-backed Taskiq objects and bind the broker to the FastAPI app.

        Args:
            redis_connection_string (str): Redis URL shared by the result backend, broker and schedule source.
            app (FastAPI): Application passed to ``taskiq_fastapi.populate_dependency_context``.

        Returns:
            Self: This component, to allow chaining.
        """
        suffix: str = self._name_suffix
        self._result_backend = RedisAsyncResultBackend(
            redis_url=redis_connection_string,
            prefix_str=f"velmios_taskiq_result_backend_{suffix}",
            # Result expiry; presumably seconds (Redis EX) — confirm against taskiq-redis.
            result_ex_time=120,
        )
        self._stream_broker = RedisStreamBroker(
            url=redis_connection_string,
            queue_name=f"velmios_taskiq_stream_broker_{suffix}",
            consumer_group_name=f"velmios_taskiq_consumer_group_{suffix}",
        ).with_result_backend(self._result_backend)

        taskiq_fastapi.populate_dependency_context(self._stream_broker, app)

        self._scheduler_source = ListRedisScheduleSource(
            url=redis_connection_string,
            prefix=f"velmios_taskiq_schedule_source_{suffix}",
        )

        self._scheduler = TaskiqScheduler(
            broker=self._stream_broker,
            sources=[self._scheduler_source],
        )

        return self

    async def startup(self, app: FastAPI) -> None:
        """Start the backend, broker and scheduler, ensure the heartbeat schedule, and launch the loops.

        Args:
            app (FastAPI): Application whose state is passed to ``taskiq_fastapi.populate_dependency_context``.

        Raises:
            ValueError: If ``configure`` has not been called.
        """
        if self._result_backend is None:
            raise ValueError("Result backend is not initialized")
        if self._stream_broker is None:
            raise ValueError("Stream broker is not initialized")
        if self._scheduler is None:
            raise ValueError("Scheduler is not initialized")
        if self._scheduler_source is None:
            raise ValueError("Scheduler source is not initialized")

        _logger.info("Starting scheduler")
        await self._result_backend.startup()
        await self._stream_broker.startup()
        await self._scheduler.startup()
        _logger.info("Scheduler started")

        await self._ensure_heartbeat_schedule(self._scheduler_source)

        _logger.info("Starting worker and scheduler tasks")
        taskiq_fastapi.populate_dependency_context(self._stream_broker, app, app.state)  # type: ignore
        self._worker_task = asyncio.create_task(run_receiver_task(self._stream_broker))
        self._scheduler_task = asyncio.create_task(run_scheduler_task(self._scheduler))
        _logger.info("Worker and scheduler tasks started")

    async def _ensure_heartbeat_schedule(self, source: ListRedisScheduleSource) -> None:
        """Put the heartbeat task on its cron schedule unless the source already holds one for it."""
        _logger.info("Scheduling task")
        schedules: list[ScheduledTask] = await source.get_schedules()
        _logger.info("Schedules retrieved", schedules=schedules)

        self._schedule_cron = next(
            (schedule for schedule in schedules if schedule.task_name == self.HEARTBEAT_TASK_NAME), None
        )
        if self._schedule_cron is not None:
            _logger.info("Schedules found, skipping scheduling")
            return

        if self.HEARTBEAT_TASK_NAME not in self._schedulers_tasks:
            # Not every application registers a heartbeat task; skip instead of failing startup
            # (``get_task`` would raise ValueError here).
            _logger.info("No heartbeat task registered, skipping scheduling", task_name=self.HEARTBEAT_TASK_NAME)
            return

        _logger.info("No schedules found, scheduling task")
        self._dyn_task = self.get_task(self.HEARTBEAT_TASK_NAME)
        task_created: CreatedSchedule[Any] = await self._dyn_task.schedule_by_cron(
            source=source, cron=self.HEARTBEAT_CRON, msg="every minute"
        )
        self._schedule_cron = task_created.task
        _logger.info("Task scheduled")

    async def shutdown(self) -> None:
        """Stop the background loops, then shut down the scheduler, broker and result backend.

        Safe to call when ``startup`` was never called or failed part-way.
        """
        _logger.info("Stopping worker")
        await self._stop_background_tasks()

        _logger.info("Stopping scheduler")
        if self._scheduler is not None:
            await self._scheduler.shutdown()
        if self._stream_broker is not None:
            await self._stream_broker.shutdown()
        if self._result_backend is not None:
            await self._result_backend.shutdown()
        _logger.info("Scheduler stopped")

    async def _stop_background_tasks(self) -> None:
        """Cancel the worker and scheduler tasks (if started) and wait until both have finished."""
        pending: list[tuple[asyncio.Task[None], str]] = [
            (task, message)
            for task, message in (
                (self._worker_task, "Worker task cancelled"),
                (self._scheduler_task, "Scheduler task cancelled"),
            )
            if task is not None
        ]
        # Request cancellation of both loops before awaiting either, so they wind down together.
        for task, _ in pending:
            task.cancel()
        for task, message in pending:
            try:
                await task
            except (asyncio.CancelledError, RuntimeError) as e:
                _logger.info(message, error=e)
            except Exception as e:  # noqa: BLE001
                # A loop that crashed earlier re-raises its error when awaited. Log it instead of
                # aborting, so the scheduler, broker and result backend are still shut down.
                _logger.warning("Background task failed before shutdown", error=e)
        self._worker_task = None
        self._scheduler_task = None

    @property
    def scheduler(self) -> TaskiqScheduler:
        """Get the scheduler (``None`` cast to the type if ``configure`` was not called)."""
        return cast(TaskiqScheduler, self._scheduler)

    @property
    def broker(self) -> AsyncBroker:
        """Get the broker (``None`` cast to the type if ``configure`` was not called)."""
        return cast(AsyncBroker, self._stream_broker)

    @property
    def scheduler_source(self) -> ScheduleSource:
        """Get the scheduler source (``None`` cast to the type if ``configure`` was not called)."""
        return cast(ScheduleSource, self._scheduler_source)
|
|
@@ -1,17 +1,15 @@
|
|
|
1
1
|
"""Protocols for the base application."""
|
|
2
2
|
|
|
3
3
|
from abc import abstractmethod
|
|
4
|
-
from typing import TYPE_CHECKING, ClassVar, Protocol
|
|
4
|
+
from typing import TYPE_CHECKING, ClassVar, Protocol
|
|
5
5
|
|
|
6
6
|
from beanie import Document
|
|
7
7
|
from fastapi import FastAPI
|
|
8
8
|
|
|
9
|
-
from fastapi_factory_utilities.core.plugins import PluginsEnum
|
|
10
9
|
from fastapi_factory_utilities.core.services.status.services import StatusService
|
|
11
10
|
|
|
12
11
|
if TYPE_CHECKING:
|
|
13
12
|
from fastapi_factory_utilities.core.app.config import RootConfig
|
|
14
|
-
from fastapi_factory_utilities.core.plugins import PluginState
|
|
15
13
|
|
|
16
14
|
|
|
17
15
|
class ApplicationAbstractProtocol(Protocol):
|
|
@@ -21,8 +19,6 @@ class ApplicationAbstractProtocol(Protocol):
|
|
|
21
19
|
|
|
22
20
|
ODM_DOCUMENT_MODELS: ClassVar[list[type[Document]]]
|
|
23
21
|
|
|
24
|
-
DEFAULT_PLUGINS_ACTIVATED: ClassVar[list[PluginsEnum]]
|
|
25
|
-
|
|
26
22
|
@abstractmethod
|
|
27
23
|
def get_config(self) -> "RootConfig":
|
|
28
24
|
"""Get the application configuration."""
|
|
@@ -34,52 +30,3 @@ class ApplicationAbstractProtocol(Protocol):
|
|
|
34
30
|
@abstractmethod
|
|
35
31
|
def get_status_service(self) -> StatusService:
|
|
36
32
|
"""Get the status service."""
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
@runtime_checkable
|
|
40
|
-
class PluginProtocol(Protocol):
|
|
41
|
-
"""Defines the protocol for the plugins."""
|
|
42
|
-
|
|
43
|
-
@abstractmethod
|
|
44
|
-
def pre_conditions_check(self, application: ApplicationAbstractProtocol) -> bool:
|
|
45
|
-
"""Check the pre-conditions for the plugin.
|
|
46
|
-
|
|
47
|
-
Args:
|
|
48
|
-
application (BaseApplicationProtocol): The application.
|
|
49
|
-
|
|
50
|
-
Returns:
|
|
51
|
-
bool: True if the pre-conditions are met, False otherwise.
|
|
52
|
-
"""
|
|
53
|
-
|
|
54
|
-
@abstractmethod
|
|
55
|
-
def on_load(self, application: ApplicationAbstractProtocol) -> list["PluginState"] | None:
|
|
56
|
-
"""The actions to perform on load for the plugin.
|
|
57
|
-
|
|
58
|
-
Args:
|
|
59
|
-
application (BaseApplicationProtocol): The application.
|
|
60
|
-
|
|
61
|
-
Returns:
|
|
62
|
-
None
|
|
63
|
-
"""
|
|
64
|
-
|
|
65
|
-
@abstractmethod
|
|
66
|
-
async def on_startup(self, application: ApplicationAbstractProtocol) -> list["PluginState"] | None:
|
|
67
|
-
"""The actions to perform on startup for the plugin.
|
|
68
|
-
|
|
69
|
-
Args:
|
|
70
|
-
application (BaseApplicationProtocol): The application.
|
|
71
|
-
|
|
72
|
-
Returns:
|
|
73
|
-
None
|
|
74
|
-
"""
|
|
75
|
-
|
|
76
|
-
@abstractmethod
|
|
77
|
-
async def on_shutdown(self, application: ApplicationAbstractProtocol) -> None:
|
|
78
|
-
"""The actions to perform on shutdown for the plugin.
|
|
79
|
-
|
|
80
|
-
Args:
|
|
81
|
-
application (BaseApplicationProtocol): The application.
|
|
82
|
-
|
|
83
|
-
Returns:
|
|
84
|
-
None
|
|
85
|
-
"""
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""Provides the security authentication abstract classes."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
|
|
5
|
+
from fastapi import Request
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class AuthenticationAbstract(ABC):
    """Base class for request authenticators.

    Subclasses implement ``authenticate``. Failures reported through ``raise_exception`` are either
    raised immediately or collected for later inspection with ``has_errors``; the
    ``raise_exception`` constructor flag chooses between the two.
    """

    def __init__(self, raise_exception: bool = True) -> None:
        """Set up the authenticator.

        Args:
            raise_exception (bool): When True, ``raise_exception`` raises the given error straight away;
                when False, the error is only recorded.
        """
        self._raise_exception: bool = raise_exception
        # Errors recorded while ``_raise_exception`` is False.
        self._errors: list[Exception] = []

    def has_errors(self) -> bool:
        """Tell whether at least one error has been recorded.

        Returns:
            bool: True when one or more errors were collected, False otherwise.
        """
        return bool(self._errors)

    def raise_exception(self, exception: Exception) -> None:
        """Raise ``exception`` or record it, depending on the constructor's ``raise_exception`` flag.

        Args:
            exception (Exception): The error to raise or record.
        """
        if not self._raise_exception:
            self._errors.append(exception)
            return
        raise exception

    @abstractmethod
    async def authenticate(self, request: Request) -> None:
        """Authenticate ``request``; must be implemented by subclasses."""
        raise NotImplementedError()
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""Provides security-related functions for the API."""
|
|
2
|
+
|
|
3
|
+
from .configs import JWTBearerAuthenticationConfig
|
|
4
|
+
from .decoders import (
|
|
5
|
+
JWTBearerTokenDecoder,
|
|
6
|
+
JWTBearerTokenDecoderAbstract,
|
|
7
|
+
)
|
|
8
|
+
from .exceptions import (
|
|
9
|
+
InvalidJWTError,
|
|
10
|
+
InvalidJWTPayploadError,
|
|
11
|
+
JWTAuthenticationError,
|
|
12
|
+
MissingJWTCredentialsError,
|
|
13
|
+
NotVerifiedJWTError,
|
|
14
|
+
)
|
|
15
|
+
from .objects import JWTPayload
|
|
16
|
+
from .services import (
|
|
17
|
+
JWTAuthenticationService,
|
|
18
|
+
JWTAuthenticationServiceAbstract,
|
|
19
|
+
)
|
|
20
|
+
from .stores import JWKStoreAbstract, JWKStoreMemory
|
|
21
|
+
from .types import JWTToken, OAuth2Audience, OAuth2Issuer, OAuth2Scope, OAuth2Subject
|
|
22
|
+
from .verifiers import JWTNoneVerifier, JWTVerifierAbstract
|
|
23
|
+
|
|
24
|
+
# Public API of the JWT security package (alphabetical order).
__all__: list[str] = [
    # Exceptions and stores
    "InvalidJWTError", "InvalidJWTPayploadError", "JWKStoreAbstract", "JWKStoreMemory",
    # Services, config and decoders
    "JWTAuthenticationError", "JWTAuthenticationService", "JWTAuthenticationServiceAbstract",
    "JWTBearerAuthenticationConfig", "JWTBearerTokenDecoder", "JWTBearerTokenDecoderAbstract",
    # Verifiers, payload and token type
    "JWTNoneVerifier", "JWTPayload", "JWTToken", "JWTVerifierAbstract",
    # Remaining exceptions
    "MissingJWTCredentialsError", "NotVerifiedJWTError",
    # OAuth2 types
    "OAuth2Audience", "OAuth2Issuer", "OAuth2Scope", "OAuth2Subject",
]
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""Provides the configurations for the JWT bearer token."""
|
|
2
|
+
|
|
3
|
+
from typing import ClassVar
|
|
4
|
+
|
|
5
|
+
from jwt.algorithms import get_default_algorithms, requires_cryptography
|
|
6
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class JWTBearerAuthenticationConfig(BaseModel):
    """JWT bearer token authentication configuration."""

    model_config: ClassVar[ConfigDict] = ConfigDict(frozen=True, extra="forbid")

    authorized_algorithms: list[str] = Field(
        # The default must satisfy the same rule as ``validate_authorized_algorithms`` (only
        # algorithms in ``requires_cryptography``), because pydantic does not validate defaults.
        # The previous default, every key of ``get_default_algorithms()``, also contained "none"
        # and the HMAC algorithms, which the validator rejects when passed explicitly.
        # NOTE: PyJWT may only list the cryptography-backed algorithms when ``cryptography`` is
        # installed; without it this default is empty and every token is rejected (fail closed).
        default_factory=lambda: [
            algorithm for algorithm in get_default_algorithms() if algorithm in requires_cryptography
        ],
        description="The authorized algorithms.",
    )

    audience: str = Field(description="The audience to be included in the JWT token.")
    authorized_audiences: list[str] | None = Field(default=None, description="The authorized audiences.")
    authorized_issuers: list[str] | None = Field(default=None, description="The authorized issuers.")

    @field_validator("authorized_algorithms")
    @classmethod
    def validate_authorized_algorithms(cls, v: list[str]) -> list[str]:
        """Validate the authorized algorithms.

        Only algorithms listed in PyJWT's ``requires_cryptography`` are accepted.

        Raises:
            ValueError: If any algorithm is not in ``requires_cryptography``; lists every offending algorithm.
        """
        invalid_algorithms: list[str] = [algorithm for algorithm in v if algorithm not in requires_cryptography]
        if invalid_algorithms:
            raise ValueError(f"Invalid algorithms: {invalid_algorithms}")
        return v
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
"""Provides the JWT bearer token decoders.
|
|
2
|
+
|
|
3
|
+
Can be implemented to support different JWT bearer token formats or additional claims.
|
|
4
|
+
https://www.iana.org/assignments/jwt/jwt.xhtml#claims
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from typing import Any, Generic, TypeVar
|
|
9
|
+
|
|
10
|
+
from jwt import InvalidTokenError, decode, get_unverified_header
|
|
11
|
+
from jwt.api_jwk import PyJWK
|
|
12
|
+
from pydantic import ValidationError
|
|
13
|
+
|
|
14
|
+
from .configs import JWTBearerAuthenticationConfig
|
|
15
|
+
from .exceptions import InvalidJWTError, InvalidJWTPayploadError
|
|
16
|
+
from .objects import JWTPayload
|
|
17
|
+
from .stores import JWKStoreAbstract
|
|
18
|
+
from .types import JWTToken, OAuth2Subject
|
|
19
|
+
|
|
20
|
+
JWTBearerPayloadGeneric = TypeVar("JWTBearerPayloadGeneric", bound=JWTPayload)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
async def decode_jwt_token_payload(
    jwt_token: JWTToken,
    public_key: PyJWK,
    jwt_bearer_authentication_config: JWTBearerAuthenticationConfig,
    subject: OAuth2Subject | None = None,
) -> dict[str, Any]:
    """Decode and verify the JWT bearer token payload.

    The signature is always verified with ``public_key``, restricted to the configured authorized
    algorithms. Issuer and audience are only checked when the corresponding configuration lists are
    non-empty. The subject is only checked when ``subject`` is given.

    Args:
        jwt_token (JWTToken): The JWT bearer token.
        public_key (PyJWK): The public key.
        jwt_bearer_authentication_config (JWTBearerAuthenticationConfig): The JWT bearer authentication configuration.
        subject (OAuth2Subject | None): The expected subject, if any.

    Returns:
        dict[str, Any]: The decoded JWT bearer token payload.

    Raises:
        InvalidJWTError: If ``jwt.decode`` rejects the token (raises ``InvalidTokenError``).
    """
    # Optional claim checks forwarded to ``jwt.decode``; left out entirely when not configured.
    kwargs: dict[str, Any] = {}
    if jwt_bearer_authentication_config.authorized_issuers:
        kwargs["issuer"] = jwt_bearer_authentication_config.authorized_issuers
    if jwt_bearer_authentication_config.authorized_audiences:
        kwargs["audience"] = jwt_bearer_authentication_config.authorized_audiences
    if subject:
        kwargs["subject"] = subject
    # Decode the JWT bearer token payload
    try:
        return decode(
            jwt=jwt_token,
            key=public_key,
            algorithms=jwt_bearer_authentication_config.authorized_algorithms,
            options={"verify_signature": True},
            **kwargs,
        )
    except InvalidTokenError as e:
        raise InvalidJWTError("Failed to decode the JWT bearer token payload") from e
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class JWTBearerTokenDecoderAbstract(ABC, Generic[JWTBearerPayloadGeneric]):
    """Base class for JWT bearer token decoders.

    Subclasses implement ``decode_payload`` and return a payload of type ``JWTBearerPayloadGeneric``.
    """

    def get_kid_from_jwt_unsafe_header(self, jwt_token: JWTToken) -> str:
        """Get the ``kid`` (key id) from the JWT header without verifying the token.

        The header is read before any signature check, so its content is untrusted.

        Args:
            jwt_token (JWTToken): The JWT bearer token.

        Returns:
            str: The kid.

        Raises:
            InvalidJWTError: If the header cannot be parsed, has no ``kid``, or ``kid`` is not a string.
        """
        try:
            jwt_unsafe_headers: dict[str, Any] = get_unverified_header(jwt_token)
            kid: Any = jwt_unsafe_headers["kid"]
        except (KeyError, InvalidTokenError) as e:
            raise InvalidJWTError("Failed to get the kid from the JWT header") from e
        # ``kid`` comes from attacker-controlled JSON and may be a number, null, a list, ...; reject
        # anything but a string so the key lookup never receives an unexpected type.
        if not isinstance(kid, str):
            raise InvalidJWTError("Failed to get the kid from the JWT header")
        return kid

    @abstractmethod
    async def decode_payload(self, jwt_token: JWTToken) -> JWTBearerPayloadGeneric:
        """Decode the JWT bearer token payload.

        Args:
            jwt_token (JWTToken): The JWT bearer token.

        Returns:
            JWTBearerPayloadGeneric: The decoded JWT bearer token payload.

        Raises:
            InvalidJWTError: If the JWT bearer token is invalid.
            InvalidJWTPayploadError: If the JWT bearer token payload is invalid.
        """
        raise NotImplementedError()
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class JWTBearerTokenDecoder(JWTBearerTokenDecoderAbstract[JWTPayload]):
    """Default decoder that verifies a token against the JWKS store and returns a ``JWTPayload``."""

    def __init__(
        self, jwt_bearer_authentication_config: JWTBearerAuthenticationConfig, jwks_store: JWKStoreAbstract
    ) -> None:
        """Create the decoder.

        Args:
            jwt_bearer_authentication_config (JWTBearerAuthenticationConfig): Settings (algorithms, issuers,
                audiences) applied when decoding.
            jwks_store (JWKStoreAbstract): Store from which the verification key is fetched by ``kid``.
        """
        self._jwt_bearer_authentication_config: JWTBearerAuthenticationConfig = jwt_bearer_authentication_config
        self._jwks_store: JWKStoreAbstract = jwks_store

    async def decode_payload(self, jwt_token: JWTToken) -> JWTPayload:
        """Look up the signing key by ``kid``, verify and decode the token, then validate its claims."""
        signing_key: PyJWK = await self._jwks_store.get_jwk(
            kid=self.get_kid_from_jwt_unsafe_header(jwt_token=jwt_token)
        )
        claims: dict[str, Any] = await decode_jwt_token_payload(
            jwt_token=jwt_token,
            public_key=signing_key,
            jwt_bearer_authentication_config=self._jwt_bearer_authentication_config,
        )
        try:
            return JWTPayload.model_validate(claims)
        except ValidationError as error:
            raise InvalidJWTPayploadError("Failed to validate the JWT bearer token payload") from error
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""Provides the exceptions for the JWT authentication."""
|
|
2
|
+
|
|
3
|
+
from fastapi_factory_utilities.core.exceptions import FastAPIFactoryUtilitiesError
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class JWTAuthenticationError(FastAPIFactoryUtilitiesError):
    """Root of the JWT authentication error hierarchy."""


class MissingJWTCredentialsError(JWTAuthenticationError):
    """The JWT authentication credentials are missing."""


class InvalidJWTError(JWTAuthenticationError):
    """The JWT could not be read or verified (e.g. unparsable header, failed decoding)."""


# NOTE: the "Payplo" spelling is part of the public class name and is kept for compatibility.
class InvalidJWTPayploadError(JWTAuthenticationError):
    """The decoded JWT claims failed payload validation."""


class NotVerifiedJWTError(JWTAuthenticationError):
    """The JWT has not been verified."""
|