fastapi-toolsets 3.1.1__tar.gz → 4.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/PKG-INFO +7 -2
  2. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/README.md +2 -0
  3. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/pyproject.toml +11 -2
  4. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/__init__.py +1 -1
  5. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/crud/factory.py +38 -15
  6. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/db.py +98 -20
  7. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/schemas.py +1 -1
  8. fastapi_toolsets-4.1.0/src/fastapi_toolsets/security/__init__.py +26 -0
  9. fastapi_toolsets-4.1.0/src/fastapi_toolsets/security/abc.py +55 -0
  10. fastapi_toolsets-4.1.0/src/fastapi_toolsets/security/oauth.py +197 -0
  11. fastapi_toolsets-4.1.0/src/fastapi_toolsets/security/sources/__init__.py +8 -0
  12. fastapi_toolsets-4.1.0/src/fastapi_toolsets/security/sources/bearer.py +120 -0
  13. fastapi_toolsets-4.1.0/src/fastapi_toolsets/security/sources/cookie.py +148 -0
  14. fastapi_toolsets-4.1.0/src/fastapi_toolsets/security/sources/header.py +67 -0
  15. fastapi_toolsets-4.1.0/src/fastapi_toolsets/security/sources/multi.py +71 -0
  16. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/LICENSE +0 -0
  17. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/_imports.py +0 -0
  18. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/cli/__init__.py +0 -0
  19. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/cli/app.py +0 -0
  20. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/cli/commands/__init__.py +0 -0
  21. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/cli/commands/fixtures.py +0 -0
  22. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/cli/config.py +0 -0
  23. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/cli/pyproject.py +0 -0
  24. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/cli/utils.py +0 -0
  25. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/crud/__init__.py +0 -0
  26. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/crud/search.py +0 -0
  27. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/dependencies.py +0 -0
  28. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/exceptions/__init__.py +0 -0
  29. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/exceptions/exceptions.py +0 -0
  30. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/exceptions/handler.py +0 -0
  31. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/fixtures/__init__.py +0 -0
  32. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/fixtures/enum.py +0 -0
  33. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/fixtures/registry.py +0 -0
  34. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/fixtures/utils.py +0 -0
  35. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/logger.py +0 -0
  36. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/metrics/__init__.py +0 -0
  37. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/metrics/handler.py +0 -0
  38. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/metrics/registry.py +0 -0
  39. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/models/__init__.py +0 -0
  40. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/models/columns.py +0 -0
  41. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/models/watched.py +0 -0
  42. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/py.typed +0 -0
  43. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/pytest/__init__.py +0 -0
  44. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/pytest/plugin.py +0 -0
  45. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/pytest/utils.py +0 -0
  46. {fastapi_toolsets-3.1.1 → fastapi_toolsets-4.1.0}/src/fastapi_toolsets/types.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: fastapi-toolsets
3
- Version: 3.1.1
3
+ Version: 4.1.0
4
4
  Summary: Production-ready utilities for FastAPI applications
5
5
  Keywords: fastapi,sqlalchemy,postgresql
6
6
  Author: d3vyce
@@ -29,12 +29,14 @@ Requires-Dist: asyncpg>=0.29.0
29
29
  Requires-Dist: fastapi>=0.100.0
30
30
  Requires-Dist: pydantic>=2.0
31
31
  Requires-Dist: sqlalchemy[asyncio]>=2.0
32
- Requires-Dist: fastapi-toolsets[cli,metrics,pytest] ; extra == 'all'
32
+ Requires-Dist: fastapi-toolsets[cli,metrics,pytest,security] ; extra == 'all'
33
33
  Requires-Dist: typer>=0.9.0 ; extra == 'cli'
34
34
  Requires-Dist: prometheus-client>=0.20.0 ; extra == 'metrics'
35
35
  Requires-Dist: httpx>=0.25.0 ; extra == 'pytest'
36
36
  Requires-Dist: pytest-xdist>=3.0.0 ; extra == 'pytest'
37
37
  Requires-Dist: pytest>=8.0.0 ; extra == 'pytest'
38
+ Requires-Dist: async-lru>=1.0 ; extra == 'security'
39
+ Requires-Dist: httpx>=0.25.0 ; extra == 'security'
38
40
  Requires-Python: >=3.11
39
41
  Project-URL: Homepage, https://github.com/d3vyce/fastapi-toolsets
40
42
  Project-URL: Documentation, https://fastapi-toolsets.d3vyce.fr/
@@ -44,6 +46,7 @@ Provides-Extra: all
44
46
  Provides-Extra: cli
45
47
  Provides-Extra: metrics
46
48
  Provides-Extra: pytest
49
+ Provides-Extra: security
47
50
  Description-Content-Type: text/markdown
48
51
 
49
52
  # FastAPI Toolsets
@@ -79,6 +82,7 @@ Install only the extras you need:
79
82
  ```bash
80
83
  uv add "fastapi-toolsets[cli]"
81
84
  uv add "fastapi-toolsets[metrics]"
85
+ uv add "fastapi-toolsets[security]"
82
86
  uv add "fastapi-toolsets[pytest]"
83
87
  ```
84
88
 
@@ -104,6 +108,7 @@ uv add "fastapi-toolsets[all]"
104
108
 
105
109
  ### Optional
106
110
 
111
+ - **Security**: Composable authentication sources (`BearerTokenAuth`, `CookieAuth`, `APIKeyHeaderAuth`, `MultiAuth`) with HMAC-signed cookies and OAuth 2.0 / OIDC helpers
107
112
  - **CLI**: Django-like command-line interface with fixture management and custom commands support
108
113
  - **Metrics**: Prometheus metrics endpoint with provider/collector registry
109
114
  - **Pytest Helpers**: Async test client, database session management, `pytest-xdist` support, and table cleanup utilities
@@ -31,6 +31,7 @@ Install only the extras you need:
31
31
  ```bash
32
32
  uv add "fastapi-toolsets[cli]"
33
33
  uv add "fastapi-toolsets[metrics]"
34
+ uv add "fastapi-toolsets[security]"
34
35
  uv add "fastapi-toolsets[pytest]"
35
36
  ```
36
37
 
@@ -56,6 +57,7 @@ uv add "fastapi-toolsets[all]"
56
57
 
57
58
  ### Optional
58
59
 
60
+ - **Security**: Composable authentication sources (`BearerTokenAuth`, `CookieAuth`, `APIKeyHeaderAuth`, `MultiAuth`) with HMAC-signed cookies and OAuth 2.0 / OIDC helpers
59
61
  - **CLI**: Django-like command-line interface with fixture management and custom commands support
60
62
  - **Metrics**: Prometheus metrics endpoint with provider/collector registry
61
63
  - **Pytest Helpers**: Async test client, database session management, `pytest-xdist` support, and table cleanup utilities
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "fastapi-toolsets"
3
- version = "3.1.1"
3
+ version = "4.1.0"
4
4
  description = "Production-ready utilities for FastAPI applications"
5
5
  readme = "README.md"
6
6
  license = "MIT"
@@ -50,13 +50,17 @@ cli = [
50
50
  metrics = [
51
51
  "prometheus_client>=0.20.0",
52
52
  ]
53
+ security = [
54
+ "async-lru>=1.0",
55
+ "httpx>=0.25.0",
56
+ ]
53
57
  pytest = [
54
58
  "httpx>=0.25.0",
55
59
  "pytest-xdist>=3.0.0",
56
60
  "pytest>=8.0.0",
57
61
  ]
58
62
  all = [
59
- "fastapi-toolsets[cli,metrics,pytest]",
63
+ "fastapi-toolsets[cli,metrics,pytest,security]",
60
64
  ]
61
65
 
62
66
  [project.scripts]
@@ -66,12 +70,14 @@ manager = "fastapi_toolsets.cli.app:cli"
66
70
  dev = [
67
71
  {include-group = "tests"},
68
72
  {include-group = "docs"},
73
+ {include-group = "docs-src"},
69
74
  "fastapi-toolsets[all]",
70
75
  "prek>=0.3.8",
71
76
  "ruff>=0.1.0",
72
77
  "ty>=0.0.1a0",
73
78
  ]
74
79
  tests = [
80
+ "async-lru>=1.0",
75
81
  "coverage>=7.0.0",
76
82
  "httpx>=0.25.0",
77
83
  "pytest-anyio>=0.0.0",
@@ -84,6 +90,9 @@ docs = [
84
90
  "mkdocstrings-python>=2.0.2",
85
91
  "zensical>=0.0.30",
86
92
  ]
93
+ docs-src = [
94
+ "bcrypt>=4.0.0",
95
+ ]
87
96
 
88
97
  [build-system]
89
98
  requires = ["uv_build>=0.10,<0.12.0"]
@@ -21,4 +21,4 @@ Example usage:
21
21
  return Response(data={"user": user.username}, message="Success")
22
22
  """
23
23
 
24
- __version__ = "3.1.1"
24
+ __version__ = "4.1.0"
@@ -10,7 +10,7 @@ from collections.abc import Awaitable, Callable, Sequence
10
10
  from datetime import date, datetime
11
11
  from decimal import Decimal
12
12
  from enum import Enum
13
- from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
13
+ from typing import Any, ClassVar, Generic, Literal, Self, TypeAlias, cast, overload
14
14
 
15
15
  from fastapi import Query
16
16
  from pydantic import BaseModel
@@ -52,6 +52,19 @@ from .search import (
52
52
  )
53
53
 
54
54
 
55
+ _ForUpdateMode: TypeAlias = bool | Literal["nowait", "skip_locked"]
56
+
57
+
58
+ def _apply_for_update(q: Any, mode: _ForUpdateMode) -> Any:
59
+ if not mode:
60
+ return q
61
+ if mode == "nowait":
62
+ return q.with_for_update(nowait=True)
63
+ if mode == "skip_locked":
64
+ return q.with_for_update(skip_locked=True)
65
+ return q.with_for_update()
66
+
67
+
55
68
  class _CursorDirection(str, Enum):
56
69
  NEXT = "next"
57
70
  PREV = "prev"
@@ -733,7 +746,7 @@ class AsyncCrud(Generic[ModelType]):
733
746
  *,
734
747
  joins: JoinType | None = None,
735
748
  outer_join: bool = False,
736
- with_for_update: bool = False,
749
+ with_for_update: _ForUpdateMode = False,
737
750
  load_options: Sequence[ExecutableOption] | None = None,
738
751
  schema: type[SchemaType],
739
752
  ) -> Response[SchemaType]: ...
@@ -747,7 +760,7 @@ class AsyncCrud(Generic[ModelType]):
747
760
  *,
748
761
  joins: JoinType | None = None,
749
762
  outer_join: bool = False,
750
- with_for_update: bool = False,
763
+ with_for_update: _ForUpdateMode = False,
751
764
  load_options: Sequence[ExecutableOption] | None = None,
752
765
  schema: None = ...,
753
766
  ) -> ModelType: ...
@@ -760,7 +773,7 @@ class AsyncCrud(Generic[ModelType]):
760
773
  *,
761
774
  joins: JoinType | None = None,
762
775
  outer_join: bool = False,
763
- with_for_update: bool = False,
776
+ with_for_update: _ForUpdateMode = False,
764
777
  load_options: Sequence[ExecutableOption] | None = None,
765
778
  schema: type[BaseModel] | None = None,
766
779
  ) -> ModelType | Response[Any]:
@@ -805,7 +818,7 @@ class AsyncCrud(Generic[ModelType]):
805
818
  *,
806
819
  joins: JoinType | None = None,
807
820
  outer_join: bool = False,
808
- with_for_update: bool = False,
821
+ with_for_update: _ForUpdateMode = False,
809
822
  load_options: Sequence[ExecutableOption] | None = None,
810
823
  schema: type[SchemaType],
811
824
  ) -> Response[SchemaType] | None: ...
@@ -819,7 +832,7 @@ class AsyncCrud(Generic[ModelType]):
819
832
  *,
820
833
  joins: JoinType | None = None,
821
834
  outer_join: bool = False,
822
- with_for_update: bool = False,
835
+ with_for_update: _ForUpdateMode = False,
823
836
  load_options: Sequence[ExecutableOption] | None = None,
824
837
  schema: None = ...,
825
838
  ) -> ModelType | None: ...
@@ -832,7 +845,7 @@ class AsyncCrud(Generic[ModelType]):
832
845
  *,
833
846
  joins: JoinType | None = None,
834
847
  outer_join: bool = False,
835
- with_for_update: bool = False,
848
+ with_for_update: _ForUpdateMode = False,
836
849
  load_options: Sequence[ExecutableOption] | None = None,
837
850
  schema: type[BaseModel] | None = None,
838
851
  ) -> ModelType | Response[Any] | None:
@@ -864,8 +877,7 @@ class AsyncCrud(Generic[ModelType]):
864
877
  q = q.where(and_(*filters))
865
878
  if resolved := cls._resolve_load_options(load_options):
866
879
  q = q.options(*resolved)
867
- if with_for_update:
868
- q = q.with_for_update()
880
+ q = _apply_for_update(q, with_for_update)
869
881
  result = await session.execute(q)
870
882
  item = result.unique().scalar_one_or_none()
871
883
  if item is None:
@@ -884,7 +896,7 @@ class AsyncCrud(Generic[ModelType]):
884
896
  *,
885
897
  joins: JoinType | None = None,
886
898
  outer_join: bool = False,
887
- with_for_update: bool = False,
899
+ with_for_update: _ForUpdateMode = False,
888
900
  load_options: Sequence[ExecutableOption] | None = None,
889
901
  schema: type[SchemaType],
890
902
  ) -> Response[SchemaType] | None: ...
@@ -898,7 +910,7 @@ class AsyncCrud(Generic[ModelType]):
898
910
  *,
899
911
  joins: JoinType | None = None,
900
912
  outer_join: bool = False,
901
- with_for_update: bool = False,
913
+ with_for_update: _ForUpdateMode = False,
902
914
  load_options: Sequence[ExecutableOption] | None = None,
903
915
  schema: None = ...,
904
916
  ) -> ModelType | None: ...
@@ -911,7 +923,7 @@ class AsyncCrud(Generic[ModelType]):
911
923
  *,
912
924
  joins: JoinType | None = None,
913
925
  outer_join: bool = False,
914
- with_for_update: bool = False,
926
+ with_for_update: _ForUpdateMode = False,
915
927
  load_options: Sequence[ExecutableOption] | None = None,
916
928
  schema: type[BaseModel] | None = None,
917
929
  ) -> ModelType | Response[Any] | None:
@@ -937,8 +949,7 @@ class AsyncCrud(Generic[ModelType]):
937
949
  q = q.where(and_(*filters))
938
950
  if resolved := cls._resolve_load_options(load_options):
939
951
  q = q.options(*resolved)
940
- if with_for_update:
941
- q = q.with_for_update()
952
+ q = _apply_for_update(q, with_for_update)
942
953
  result = await session.execute(q)
943
954
  item = result.unique().scalars().first()
944
955
  if item is None:
@@ -956,6 +967,7 @@ class AsyncCrud(Generic[ModelType]):
956
967
  filters: list[Any] | None = None,
957
968
  joins: JoinType | None = None,
958
969
  outer_join: bool = False,
970
+ with_for_update: _ForUpdateMode = False,
959
971
  load_options: Sequence[ExecutableOption] | None = None,
960
972
  order_by: OrderByClause | None = None,
961
973
  limit: int | None = None,
@@ -968,6 +980,9 @@ class AsyncCrud(Generic[ModelType]):
968
980
  filters: List of SQLAlchemy filter conditions
969
981
  joins: List of (model, condition) tuples for joining related tables
970
982
  outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
983
+ with_for_update: Lock rows for update. ``True`` for plain ``FOR UPDATE``,
984
+ ``"nowait"`` for ``FOR UPDATE NOWAIT``, ``"skip_locked"`` for
985
+ ``FOR UPDATE SKIP LOCKED``.
971
986
  load_options: SQLAlchemy loader options
972
987
  order_by: Column or list of columns to order by
973
988
  limit: Max number of rows to return
@@ -982,6 +997,7 @@ class AsyncCrud(Generic[ModelType]):
982
997
  q = q.where(and_(*filters))
983
998
  if resolved := cls._resolve_load_options(load_options):
984
999
  q = q.options(*resolved)
1000
+ q = _apply_for_update(q, with_for_update)
985
1001
  if order_by is not None:
986
1002
  q = q.order_by(order_by)
987
1003
  if offset is not None:
@@ -1001,6 +1017,7 @@ class AsyncCrud(Generic[ModelType]):
1001
1017
  *,
1002
1018
  exclude_unset: bool = True,
1003
1019
  exclude_none: bool = False,
1020
+ with_for_update: _ForUpdateMode = False,
1004
1021
  schema: type[SchemaType],
1005
1022
  ) -> Response[SchemaType]: ...
1006
1023
 
@@ -1014,6 +1031,7 @@ class AsyncCrud(Generic[ModelType]):
1014
1031
  *,
1015
1032
  exclude_unset: bool = True,
1016
1033
  exclude_none: bool = False,
1034
+ with_for_update: _ForUpdateMode = False,
1017
1035
  schema: None = ...,
1018
1036
  ) -> ModelType: ...
1019
1037
 
@@ -1026,6 +1044,7 @@ class AsyncCrud(Generic[ModelType]):
1026
1044
  *,
1027
1045
  exclude_unset: bool = True,
1028
1046
  exclude_none: bool = False,
1047
+ with_for_update: _ForUpdateMode = False,
1029
1048
  schema: type[BaseModel] | None = None,
1030
1049
  ) -> ModelType | Response[Any]:
1031
1050
  """Update a record in the database.
@@ -1036,6 +1055,9 @@ class AsyncCrud(Generic[ModelType]):
1036
1055
  filters: List of SQLAlchemy filter conditions
1037
1056
  exclude_unset: Exclude fields not explicitly set in the schema
1038
1057
  exclude_none: Exclude fields with None value
1058
+ with_for_update: Lock the row before updating. ``True`` for plain
1059
+ ``FOR UPDATE``, ``"nowait"`` for ``FOR UPDATE NOWAIT``,
1060
+ ``"skip_locked"`` for ``FOR UPDATE SKIP LOCKED``.
1039
1061
  schema: Pydantic schema to serialize the result into. When provided,
1040
1062
  the result is automatically wrapped in a ``Response[schema]``.
1041
1063
 
@@ -1059,6 +1081,7 @@ class AsyncCrud(Generic[ModelType]):
1059
1081
  db_model = await cls.get(
1060
1082
  session=session,
1061
1083
  filters=filters,
1084
+ with_for_update=with_for_update,
1062
1085
  load_options=m2m_load_options or None,
1063
1086
  )
1064
1087
  values = obj.model_dump(
@@ -1334,7 +1357,7 @@ class AsyncCrud(Generic[ModelType]):
1334
1357
  count_q = count_q.where(and_(*filters))
1335
1358
 
1336
1359
  count_result = await session.execute(count_q)
1337
- total_count: int | None = count_result.scalar_one()
1360
+ total_count: int = count_result.scalar_one()
1338
1361
  has_more = page * items_per_page < total_count
1339
1362
  else:
1340
1363
  # Fetch one extra row to detect if a next page exists without COUNT
@@ -16,6 +16,7 @@ from .exceptions import NotFoundError
16
16
 
17
17
  __all__ = [
18
18
  "LockMode",
19
+ "advisory_lock",
19
20
  "cleanup_tables",
20
21
  "create_database",
21
22
  "create_db_context",
@@ -151,52 +152,129 @@ class LockMode(str, Enum):
151
152
  ACCESS_EXCLUSIVE = "ACCESS EXCLUSIVE"
152
153
 
153
154
 
154
def lock_tables(
    session_maker: async_sessionmaker[_SessionT],
    tables: list[type[DeclarativeBase]],
    *,
    mode: LockMode = LockMode.SHARE_UPDATE_EXCLUSIVE,
    timeout: str = "5s",
) -> AbstractAsyncContextManager[_SessionT]:
    """Lock PostgreSQL tables for the duration of a transaction.

    A dedicated session is opened from *session_maker*. The ``LOCK`` is taken
    inside its transaction. That transaction is committed when the context
    exits normally and rolled back if the body raises (any ``BaseException``,
    including cancellation).

    Args:
        session_maker: Async session factory used to create the dedicated
            session.
        tables: List of SQLAlchemy model classes to lock.
        mode: Lock mode (default: SHARE UPDATE EXCLUSIVE).
        timeout: Lock timeout (default: "5s"), applied as PostgreSQL's
            ``lock_timeout`` for the dedicated transaction.

    Yields:
        The dedicated session, open within the locked transaction.

    Raises:
        SQLAlchemyError: If the lock cannot be acquired within *timeout*.

    Example:
        ```python
        from fastapi_toolsets.db import lock_tables, LockMode

        async with lock_tables(session_maker, [User, Account]) as session:
            # Tables are locked; changes are committed when the context exits.
            user = await UserCrud.get(session, [User.id == 1])
            user.balance += 100

        # With custom lock mode
        async with lock_tables(session_maker, [Order], mode=LockMode.EXCLUSIVE) as session:
            await process_order(session, order_id)
        ```
    """
    table_names = ",".join(table.__tablename__ for table in tables)

    @asynccontextmanager
    async def _lock() -> AsyncGenerator[_SessionT, None]:
        async with session_maker() as session:
            try:
                # set_config(..., true) is the function form of SET LOCAL.
                # Unlike SET, it accepts a bound parameter, so *timeout* is
                # never spliced into the SQL text.
                await session.execute(
                    text("SELECT set_config('lock_timeout', :value, true)"),
                    {"value": timeout},
                )
                await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
                yield session
                await session.commit()
            except BaseException:
                await session.rollback()
                raise

    return _lock()


@asynccontextmanager
async def advisory_lock(
    session: AsyncSession,
    key: int | tuple[int, int],
    *,
    shared: bool = False,
    nowait: bool = False,
    timeout: str | None = None,
) -> AsyncGenerator[bool, None]:
    """Acquire a PostgreSQL session-level advisory lock.

    On exit, the matching ``pg_advisory_unlock*`` function is called, but
    only if the lock was actually acquired.

    Args:
        session: AsyncSession instance.
        key: Lock key — a single ``int`` (bigint) or a ``(int, int)`` pair for namespacing.
        shared: Acquire a shared lock (multiple holders allowed). Default is exclusive.
        nowait: Return ``False`` immediately if the lock is unavailable instead of waiting.
        timeout: Maximum wait time for acquiring this lock (e.g. ``"5s"``,
            ``"500ms"``). Raises ``DBAPIError`` if exceeded. Ignored when
            *nowait* is ``True``. Once the lock is held, the previous
            ``lock_timeout`` is restored, so the limit does not apply to later
            statements of the same transaction.

    Yields:
        ``True`` if the lock was acquired, ``False`` if *nowait* is ``True`` and the lock
        is already held.

    Raises:
        sqlalchemy.exc.DBAPIError: If *timeout* is set and the lock cannot be acquired
            in time.

    Example:
        ```python
        from fastapi_toolsets.db import advisory_lock

        async with advisory_lock(session, 42):
            ...

        async with advisory_lock(session, 42, nowait=True) as acquired:
            if not acquired:
                raise HTTPException(409, "Resource is locked")

        async with advisory_lock(session, 42, timeout="5s"):
            ...

        async with advisory_lock(session, (1, user_id), shared=True):
            ...
        ```
    """
    suffix = "_shared" if shared else ""
    acquire_fn = f"{'pg_try_advisory_lock' if nowait else 'pg_advisory_lock'}{suffix}"
    release_fn = f"pg_advisory_unlock{suffix}"

    if isinstance(key, tuple):
        k1, k2 = key
        # The two-key advisory functions are declared for (integer, integer).
        # The explicit casts presumably keep resolution working for drivers
        # that bind Python ints as bigint.
        args = "CAST(:k1 AS integer), CAST(:k2 AS integer)"
        params: dict[str, int] = {"k1": k1, "k2": k2}
    else:
        args = ":k"
        params = {"k": key}

    acquire_sql = text(f"SELECT {acquire_fn}({args})")
    release_sql = text(f"SELECT {release_fn}({args})")

    if timeout is not None and not nowait:
        # set_config(..., true) is the function form of SET LOCAL. It accepts
        # a bound parameter, so *timeout* is not quoted into the SQL text.
        set_lock_timeout = text("SELECT set_config('lock_timeout', :value, true)")
        previous = (
            await session.execute(text("SELECT current_setting('lock_timeout')"))
        ).scalar_one()
        await session.execute(set_lock_timeout, {"value": timeout})
        await session.execute(acquire_sql, params)
        # Restore the previous value so *timeout* bounds only this acquisition,
        # not every later statement in the caller's transaction. If the
        # acquisition above failed, the transaction is aborted and its
        # rollback discards the local setting anyway.
        await session.execute(set_lock_timeout, {"value": previous})
        acquired = True
    else:
        result = await session.execute(acquire_sql, params)
        acquired = result.scalar() if nowait else True

    try:
        yield acquired
    finally:
        if acquired:
            # NOTE(review): if the body left the transaction aborted (a failed
            # statement without rollback), PostgreSQL will reject this unlock
            # and the session-level lock stays held on the connection.
            # Confirm callers roll back before leaving the context.
            await session.execute(release_sql, params)
200
278
 
201
279
 
202
280
  async def create_database(
@@ -179,7 +179,7 @@ class PaginatedResponse(BaseResponse, Generic[DataT]):
179
179
  ]
180
180
  cls._discriminated_union_cache[item] = cached
181
181
  return cached # ty:ignore[invalid-return-type]
182
- return super().__class_getitem__(item)
182
+ return super().__class_getitem__(item) # ty:ignore[invalid-return-type]
183
183
 
184
184
 
185
185
  class OffsetPaginatedResponse(PaginatedResponse[DataT]):
@@ -0,0 +1,26 @@
1
+ """Authentication helpers for FastAPI using Security()."""
2
+
3
+ from .abc import AuthSource
4
+ from .oauth import (
5
+ oauth_build_authorization_redirect,
6
+ oauth_decode_state,
7
+ oauth_encode_state,
8
+ oauth_fetch_userinfo,
9
+ oauth_generate_state_token,
10
+ oauth_resolve_provider_urls,
11
+ )
12
+ from .sources import APIKeyHeaderAuth, BearerTokenAuth, CookieAuth, MultiAuth
13
+
14
# Public API of ``fastapi_toolsets.security``: the auth sources and OAuth
# helpers re-exported from the ``abc``, ``oauth`` and ``sources`` submodules.
__all__ = [
    "APIKeyHeaderAuth",
    "AuthSource",
    "BearerTokenAuth",
    "CookieAuth",
    "MultiAuth",
    "oauth_build_authorization_redirect",
    "oauth_decode_state",
    "oauth_encode_state",
    "oauth_fetch_userinfo",
    "oauth_generate_state_token",
    "oauth_resolve_provider_urls",
]
@@ -0,0 +1,55 @@
1
+ """Abstract base class for authentication sources."""
2
+
3
+ import functools
4
+ import inspect
5
+ from abc import ABC, abstractmethod
6
+ from typing import Any, Callable
7
+
8
+ from fastapi import Request
9
+ from fastapi.security import SecurityScopes
10
+
11
+ from fastapi_toolsets.exceptions import UnauthorizedError
12
+
13
+
14
+ def _ensure_async(fn: Callable[..., Any]) -> Callable[..., Any]:
15
+ """Wrap *fn* so it can always be awaited, caching the coroutine check at init time."""
16
+ if inspect.iscoroutinefunction(fn):
17
+ return fn
18
+
19
+ @functools.wraps(fn)
20
+ async def wrapper(*args: Any, **kwargs: Any) -> Any:
21
+ return fn(*args, **kwargs)
22
+
23
+ return wrapper
24
+
25
+
26
class AuthSource(ABC):
    """Abstract base class for authentication sources.

    Subclasses supply :meth:`extract`, which pulls a raw credential out of the
    request, and :meth:`authenticate`, which turns it into an identity. An
    instance is itself usable as a FastAPI dependency. Its ``__signature__``
    advertises ``request`` and ``security_scopes``, and :meth:`__call__`
    forwards the resolved arguments to ``self._call_fn``.
    """

    def __init__(self) -> None:
        """Install the default dependency callable and its FastAPI signature."""
        owner = self

        async def _call(
            request: Request,
            security_scopes: SecurityScopes,  # noqa: ARG001
        ) -> Any:
            raw = await owner.extract(request)
            if raw is not None:
                return await owner.authenticate(raw)
            # No credential present in the request at all.
            raise UnauthorizedError()

        self._call_fn: Callable[..., Any] = _call
        self.__signature__ = inspect.signature(_call)

    @abstractmethod
    async def extract(self, request: Request) -> str | None:
        """Return the raw credential found in *request* (or ``None``) without validating it."""

    @abstractmethod
    async def authenticate(self, credential: str) -> Any:
        """Check *credential* and return the identity it authenticates."""

    async def __call__(self, **kwargs: Any) -> Any:
        """Dispatch a FastAPI dependency call to the installed ``_call_fn``."""
        return await self._call_fn(**kwargs)
@@ -0,0 +1,197 @@
1
+ """OAuth 2.0 / OIDC helper utilities."""
2
+
3
+ import base64
4
+ import binascii
5
+ import hmac
6
+ import json
7
+ import secrets
8
+ from typing import Any
9
+ from urllib.parse import urlencode
10
+
11
+ import httpx
12
+ from async_lru import alru_cache
13
+ from fastapi.responses import RedirectResponse
14
+
15
+
16
@alru_cache(maxsize=32)
async def oauth_resolve_provider_urls(
    discovery_url: str,
) -> tuple[str, str, str | None]:
    """Look up a provider's OAuth endpoints from its OIDC discovery document.

    Results are memoised per *discovery_url* by the ``alru_cache`` decorator.

    Args:
        discovery_url: URL of the provider's ``/.well-known/openid-configuration``.

    Returns:
        A ``(authorization_url, token_url, userinfo_url)`` tuple.
        *userinfo_url* is ``None`` when the provider does not advertise one.
    """
    async with httpx.AsyncClient() as http:
        response = await http.get(discovery_url)
        response.raise_for_status()
        document = response.json()

    authorization_url = document["authorization_endpoint"]
    token_url = document["token_endpoint"]
    userinfo_url = document.get("userinfo_endpoint")
    return authorization_url, token_url, userinfo_url
38
+
39
+
40
async def oauth_fetch_userinfo(
    *,
    token_url: str,
    userinfo_url: str,
    code: str,
    client_id: str,
    client_secret: str,
    redirect_uri: str,
    required_scopes: str | None = None,
) -> dict[str, Any]:
    """Exchange an authorization code for tokens and return the userinfo payload.

    Args:
        token_url: Provider's token endpoint.
        userinfo_url: Provider's userinfo endpoint.
        code: Authorization code received from the provider's callback.
        client_id: OAuth application client ID.
        client_secret: OAuth application client secret.
        redirect_uri: Redirect URI that was used in the authorization request.
        required_scopes: Space-separated scopes that must be present in the token
            response ``scope`` field (RFC 6749 §3.3). Raises ``ValueError`` if
            the provider granted fewer scopes than requested. When the token
            response omits ``scope``, RFC 6749 §5.1 defines the grant as
            identical to the requested scope, so the check is skipped. This
            assumes *required_scopes* is a subset of what was requested.

    Returns:
        The JSON payload returned by the userinfo endpoint as a plain ``dict``.

    Raises:
        ValueError: If the provider granted a different token type than ``bearer``
            or did not grant all ``required_scopes``.
        httpx.HTTPStatusError: If either endpoint answers with an error status.
    """
    async with httpx.AsyncClient() as client:
        token_resp = await client.post(
            token_url,
            data={
                "grant_type": "authorization_code",
                "code": code,
                "client_id": client_id,
                "client_secret": client_secret,
                "redirect_uri": redirect_uri,
            },
            headers={"Accept": "application/json"},
        )
        token_resp.raise_for_status()
        token_data = token_resp.json()

        # A missing token_type is tolerated as bearer. A present value must be
        # a string matching "bearer" case-insensitively (RFC 6749 §5.1); a
        # JSON null previously crashed with AttributeError.
        token_type = token_data.get("token_type", "bearer")
        if not isinstance(token_type, str) or token_type.lower() != "bearer":
            raise ValueError(
                f"unsupported token_type: {token_data.get('token_type')!r}"
            )

        # An omitted (or null) scope means "granted as requested" (RFC 6749
        # §5.1), not "granted nothing".
        granted_scope = token_data.get("scope")
        if required_scopes is not None and granted_scope is not None:
            granted = set(granted_scope.split())
            missing = set(required_scopes.split()) - granted
            if missing:
                raise ValueError(f"provider did not grant required scopes: {missing}")

        access_token = token_data["access_token"]

        userinfo_resp = await client.get(
            userinfo_url,
            headers={"Authorization": f"Bearer {access_token}"},
        )
        userinfo_resp.raise_for_status()
        return userinfo_resp.json()
104
+
105
+
106
def oauth_generate_state_token() -> str:
    """Return a fresh URL-safe random token for the OAuth ``state`` CSRF check.

    The token is drawn from :mod:`secrets` with 32 bytes of randomness.
    """
    entropy_bytes = 32
    token = secrets.token_urlsafe(entropy_bytes)
    return token
109
+
110
+
111
def oauth_build_authorization_redirect(
    authorization_url: str,
    *,
    client_id: str,
    scopes: str,
    redirect_uri: str,
    destination: str,
    state_token: str,
) -> RedirectResponse:
    """Return an OAuth 2.0 authorization ``RedirectResponse``.

    Args:
        authorization_url: Provider's authorization endpoint. Any query
            component it already carries (e.g. a tenant or policy parameter)
            is retained, as RFC 6749 §3.1 requires.
        client_id: OAuth application client ID.
        scopes: Space-separated list of requested scopes.
        redirect_uri: URI the provider should redirect back to after authorization.
        destination: URL the user should be sent to after the full OAuth flow
            completes (embedded in ``state``).
        state_token: CSRF token generated by :func:`oauth_generate_state_token`.
            Must be stored server-side (session or signed cookie) and verified via
            :func:`oauth_decode_state` on the callback endpoint (RFC 6749 §10.12).

    Returns:
        A :class:`~fastapi.responses.RedirectResponse` to the provider's
        authorization page.
    """
    from urllib.parse import urlsplit

    params = urlencode(
        {
            "client_id": client_id,
            "response_type": "code",
            "scope": scopes,
            "redirect_uri": redirect_uri,
            "state": oauth_encode_state(destination, state_token),
        }
    )
    # Append to an existing query rather than adding a second "?".
    if urlsplit(authorization_url).query:
        separator = "&"
    elif authorization_url.endswith("?"):
        separator = ""
    else:
        separator = "?"
    return RedirectResponse(f"{authorization_url}{separator}{params}")
147
+
148
+
149
def oauth_encode_state(url: str, state_token: str) -> str:
    """Pack a post-login destination and CSRF token into an OAuth ``state`` value.

    The result is compact JSON (``{"n": token, "d": url}``) encoded with
    URL-safe base64, padding included.

    Args:
        url: Post-login destination URL.
        state_token: CSRF token from :func:`oauth_generate_state_token`.
    """
    state_fields = {"n": state_token, "d": url}
    compact_json = json.dumps(state_fields, separators=(",", ":"))
    raw_bytes = compact_json.encode()
    encoded = base64.urlsafe_b64encode(raw_bytes)
    return encoded.decode()
158
+
159
+
160
+ def oauth_decode_state(
161
+ state: str | None, *, expected_state_token: str, fallback: str
162
+ ) -> str:
163
+ """Decode and CSRF-verify an OAuth ``state`` parameter.
164
+
165
+ Uses a constant-time comparison for the CSRF token to prevent timing attacks.
166
+
167
+ Args:
168
+ state: Raw ``state`` query parameter from the provider's callback.
169
+ expected_state_token: The token stored before the authorization redirect.
170
+ If it does not match the decoded value, ``fallback`` is returned.
171
+ fallback: URL to return when ``state`` is absent, malformed, or fails
172
+ CSRF verification.
173
+
174
+ Returns:
175
+ The destination URL embedded in ``state``, or ``fallback``.
176
+
177
+ Important:
178
+ **Single-use**: delete the stored token from the session immediately
179
+ after calling this function — whether it matched or not — so that a
180
+ captured callback URL cannot be replayed.
181
+
182
+ **Open-redirect**: validate the returned URL against a known-good
183
+ origin or relative-path allowlist before issuing the final redirect.
184
+ Do not forward arbitrary URLs to ``RedirectResponse``.
185
+ """
186
+ if not state or state == "null": # "null" guards against JS JSON.stringify(null)
187
+ return fallback
188
+ try:
189
+ padded = state + "=" * (-len(state) % 4)
190
+ payload = json.loads(base64.urlsafe_b64decode(padded).decode("utf-8"))
191
+ if not isinstance(payload, dict) or not hmac.compare_digest(
192
+ payload.get("n", "").encode(), expected_state_token.encode()
193
+ ):
194
+ return fallback
195
+ return str(payload["d"])
196
+ except (UnicodeDecodeError, ValueError, binascii.Error, KeyError):
197
+ return fallback
@@ -0,0 +1,8 @@
1
+ """Built-in authentication source implementations."""
2
+
3
+ from .header import APIKeyHeaderAuth
4
+ from .bearer import BearerTokenAuth
5
+ from .cookie import CookieAuth
6
+ from .multi import MultiAuth
7
+
8
# Names re-exported by a star import of this subpackage.
__all__ = [
    "APIKeyHeaderAuth",
    "BearerTokenAuth",
    "CookieAuth",
    "MultiAuth",
]
@@ -0,0 +1,120 @@
1
+ """Bearer token authentication source."""
2
+
3
+ import inspect
4
+ import secrets
5
+ from typing import Annotated, Any, Callable
6
+
7
+ from fastapi import Depends, Request
8
+ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer, SecurityScopes
9
+
10
+ from fastapi_toolsets.exceptions import UnauthorizedError
11
+
12
+ from ..abc import AuthSource, _ensure_async
13
+
14
+
15
class BearerTokenAuth(AuthSource):
    """Bearer token authentication source.

    Wraps :class:`fastapi.security.HTTPBearer` for OpenAPI documentation.
    The validator is called as ``await validator(credential, **kwargs)``
    where ``kwargs`` are the extra keyword arguments provided at instantiation.

    Args:
        validator: Sync or async callable that receives the credential and any
            extra keyword arguments, and returns the authenticated identity
            (e.g. a ``User`` model). Should raise
            :class:`~fastapi_toolsets.exceptions.UnauthorizedError` on failure.
        prefix: Optional token prefix (e.g. ``"user_"``). If set, only tokens
            whose value starts with this prefix are matched. The prefix is
            **kept** in the value passed to the validator — store and compare
            tokens with their prefix included. Use :meth:`generate_token` to
            create correctly-prefixed tokens. This enables multiple
            ``BearerTokenAuth`` instances in the same app (e.g. ``"user_"``
            for user tokens, ``"org_"`` for org tokens).
        **kwargs: Extra keyword arguments forwarded to the validator on every
            call (e.g. ``role=Role.ADMIN``).
    """

    def __init__(
        self,
        validator: Callable[..., Any],
        *,
        prefix: str | None = None,
        **kwargs: Any,
    ) -> None:
        self._validator = _ensure_async(validator)
        self._prefix = prefix
        self._kwargs = kwargs
        # auto_error=False: a missing/malformed header yields None so that this
        # class raises UnauthorizedError itself instead of FastAPI's 403.
        self._scheme = HTTPBearer(auto_error=False)

        async def _call(
            security_scopes: SecurityScopes,  # noqa: ARG001
            credentials: Annotated[
                HTTPAuthorizationCredentials | None, Depends(self._scheme)
            ] = None,
        ) -> Any:
            if credentials is None:
                raise UnauthorizedError()
            return await self._validate(credentials.credentials)

        self._call_fn = _call
        # Expose _call's parameters so FastAPI resolves the HTTPBearer
        # dependency (and documents it in OpenAPI); the actual dispatch to
        # _call_fn presumably lives in AuthSource.__call__.
        self.__signature__ = inspect.signature(_call)

    async def _validate(self, token: str) -> Any:
        """Check prefix and call the validator.

        Raises:
            UnauthorizedError: If a prefix is configured and *token* lacks it.
        """
        if self._prefix is not None and not token.startswith(self._prefix):
            raise UnauthorizedError()
        return await self._validator(token, **self._kwargs)

    async def extract(self, request: Request) -> str | None:
        """Extract the raw credential from the request without validating.

        Returns ``None`` if no ``Authorization: Bearer`` header is present,
        the token is empty, or the token does not match the configured prefix.
        The prefix is included in the returned value.
        """
        auth = request.headers.get("Authorization", "")
        # Split exactly like FastAPI's HTTPBearer (scheme before the first
        # space, compared case-insensitively as RFC 7235 requires) so this
        # path agrees with the dependency path used by ``_call``.
        scheme, _, token = auth.partition(" ")
        if scheme.lower() != "bearer" or not token:
            return None
        if self._prefix is not None and not token.startswith(self._prefix):
            return None
        return token

    async def authenticate(self, credential: str) -> Any:
        """Validate a credential and return the identity.

        Calls ``await validator(credential, **kwargs)`` where ``kwargs`` are
        the extra keyword arguments provided at instantiation.
        """
        return await self._validate(credential)

    def require(self, **kwargs: Any) -> "BearerTokenAuth":
        """Return a new instance with additional (or overriding) validator kwargs."""
        return BearerTokenAuth(
            self._validator,
            prefix=self._prefix,
            **{**self._kwargs, **kwargs},
        )

    def generate_token(self, nbytes: int = 32) -> str:
        """Generate a secure random token for this auth source.

        Returns a URL-safe random token. If a prefix is configured it is
        prepended — the returned value is what you store in your database
        and return to the client as-is.

        Args:
            nbytes: Number of random bytes before base64 encoding. The
                resulting string is ``ceil(nbytes * 4 / 3)`` characters
                (43 chars for the default 32 bytes). Defaults to 32.

        Returns:
            A ready-to-use token string (e.g. ``"user_Xk3..."``).
        """
        token = secrets.token_urlsafe(nbytes)
        if self._prefix is not None:
            return f"{self._prefix}{token}"
        return token
@@ -0,0 +1,148 @@
1
+ """Cookie-based authentication source."""
2
+
3
+ import base64
4
+ import hashlib
5
+ import hmac
6
+ import inspect
7
+ import json
8
+ import time
9
+ from typing import Annotated, Any, Callable
10
+
11
+ from fastapi import Depends, Request, Response
12
+ from fastapi.security import APIKeyCookie, SecurityScopes
13
+
14
+ from fastapi_toolsets.exceptions import UnauthorizedError
15
+
16
+ from ..abc import AuthSource, _ensure_async
17
+
18
+
19
class CookieAuth(AuthSource):
    """Cookie-based authentication source.

    Wraps :class:`fastapi.security.APIKeyCookie` for OpenAPI documentation.
    Optionally signs the cookie with HMAC-SHA256 to provide stateless, tamper-
    proof sessions without any database entry.

    Args:
        name: Cookie name.
        validator: Sync or async callable that receives the cookie value
            (plain, after signature verification when ``secret_key`` is set)
            and any extra keyword arguments, and returns the authenticated
            identity.
        secret_key: When provided, the cookie is HMAC-SHA256 signed.
            :meth:`set_cookie` embeds an expiry and signs the payload;
            :meth:`extract` verifies the signature and expiry before handing
            the plain value to the validator. When ``None`` (default), the raw
            cookie value is passed to the validator as-is.
        ttl: Cookie lifetime in seconds (default 24 h). Only used when
            ``secret_key`` is set.
        secure: Set the ``Secure`` flag on the cookie so it is only transmitted
            over HTTPS (default ``True``). Set to ``False`` only in local
            development environments where HTTPS is unavailable.
        **kwargs: Extra keyword arguments forwarded to the validator on every
            call (e.g. ``role=Role.ADMIN``).
    """

    def __init__(
        self,
        name: str,
        validator: Callable[..., Any],
        *,
        secret_key: str | None = None,
        ttl: int = 86400,
        secure: bool = True,
        **kwargs: Any,
    ) -> None:
        self._name = name
        self._validator = _ensure_async(validator)
        self._secret_key = secret_key
        self._ttl = ttl
        self._secure = secure
        self._kwargs = kwargs
        # auto_error=False: an absent cookie yields None so this class raises
        # UnauthorizedError itself.
        self._scheme = APIKeyCookie(name=name, auto_error=False)

        async def _call(
            security_scopes: SecurityScopes,  # noqa: ARG001
            value: Annotated[str | None, Depends(self._scheme)] = None,
        ) -> Any:
            if value is None:
                raise UnauthorizedError()
            plain = self._verify(value)
            return await self._validator(plain, **self._kwargs)

        self._call_fn = _call
        # Expose _call's parameters so FastAPI resolves the cookie scheme
        # dependency (and documents it in OpenAPI).
        self.__signature__ = inspect.signature(_call)

    def _hmac(self, data: str) -> str:
        """Return the hex HMAC-SHA256 of *data* keyed with ``secret_key``."""
        if self._secret_key is None:
            raise RuntimeError("_hmac called without secret_key configured")
        return hmac.new(
            self._secret_key.encode(), data.encode(), hashlib.sha256
        ).hexdigest()

    def _sign(self, value: str) -> str:
        """Return ``<base64url JSON {"v", "exp"}>.<hex HMAC>`` for *value*."""
        data = base64.urlsafe_b64encode(
            json.dumps({"v": value, "exp": int(time.time()) + self._ttl}).encode()
        ).decode()
        return f"{data}.{self._hmac(data)}"

    def _verify(self, cookie_value: str) -> str:
        """Return the plain value, verifying HMAC + expiry when signed.

        Raises:
            UnauthorizedError: If the cookie is malformed, its signature does
                not match, or its embedded expiry has passed.
        """
        if not self._secret_key:
            return cookie_value

        try:
            data, sig = cookie_value.rsplit(".", 1)
        except ValueError:
            raise UnauthorizedError() from None

        # Compare bytes, not str: hmac.compare_digest raises TypeError for str
        # arguments containing non-ASCII characters, and ``sig`` is
        # client-controlled.
        if not hmac.compare_digest(self._hmac(data).encode(), sig.encode()):
            raise UnauthorizedError()

        try:
            payload = json.loads(base64.urlsafe_b64decode(data))
            value: str = payload["v"]
            exp: int = payload["exp"]
        except (ValueError, TypeError, KeyError):
            # ValueError covers base64/JSON/UTF-8 errors; TypeError/KeyError a
            # payload that is not a dict with the expected keys.
            raise UnauthorizedError() from None

        if exp < int(time.time()):
            raise UnauthorizedError()

        return value

    async def extract(self, request: Request) -> str | None:
        """Return the raw cookie value, or ``None`` when absent or empty.

        An empty cookie is treated as missing so that ``MultiAuth`` falls
        through to its next source instead of failing on this one.
        """
        return request.cookies.get(self._name) or None

    async def authenticate(self, credential: str) -> Any:
        """Verify *credential* (when signed) and return the validator's identity."""
        plain = self._verify(credential)
        return await self._validator(plain, **self._kwargs)

    def require(self, **kwargs: Any) -> "CookieAuth":
        """Return a new instance with additional (or overriding) validator kwargs."""
        return CookieAuth(
            self._name,
            self._validator,
            secret_key=self._secret_key,
            ttl=self._ttl,
            secure=self._secure,
            **{**self._kwargs, **kwargs},
        )

    def set_cookie(self, response: Response, value: str) -> None:
        """Attach the cookie to *response*, signing it when ``secret_key`` is set."""
        cookie_value = self._sign(value) if self._secret_key else value
        response.set_cookie(
            self._name,
            cookie_value,
            httponly=True,
            samesite="lax",
            secure=self._secure,
            max_age=self._ttl,
        )

    def delete_cookie(self, response: Response) -> None:
        """Clear the session cookie (logout)."""
        response.delete_cookie(
            self._name, httponly=True, samesite="lax", secure=self._secure
        )
@@ -0,0 +1,67 @@
1
+ """API key header authentication source."""
2
+
3
+ import inspect
4
+ from typing import Annotated, Any, Callable
5
+
6
+ from fastapi import Depends, Request
7
+ from fastapi.security import APIKeyHeader, SecurityScopes
8
+
9
+ from fastapi_toolsets.exceptions import UnauthorizedError
10
+
11
+ from ..abc import AuthSource, _ensure_async
12
+
13
+
14
class APIKeyHeaderAuth(AuthSource):
    """Authenticate requests by an API key carried in an HTTP header.

    The header is declared through :class:`fastapi.security.APIKeyHeader` so
    that it appears in the OpenAPI schema. Each key is handed to the validator
    as ``await validator(api_key, **kwargs)``, where ``kwargs`` are the extra
    keyword arguments captured when the instance was built.

    Args:
        name: Name of the HTTP header holding the key (e.g. ``"X-API-Key"``).
        validator: Sync or async callable receiving the API key plus the extra
            keyword arguments; it returns the authenticated identity and
            should raise :class:`~fastapi_toolsets.exceptions.UnauthorizedError`
            when the key is not acceptable.
        **kwargs: Extra keyword arguments passed to the validator on every
            call (e.g. ``role=Role.ADMIN``).
    """

    def __init__(
        self,
        name: str,
        validator: Callable[..., Any],
        **kwargs: Any,
    ) -> None:
        self._scheme = APIKeyHeader(name=name, auto_error=False)
        self._name = name
        self._kwargs = kwargs
        self._validator = _ensure_async(validator)

        async def _call(
            security_scopes: SecurityScopes,  # noqa: ARG001
            api_key: Annotated[str | None, Depends(self._scheme)] = None,
        ) -> Any:
            # A present key goes to the validator; anything else is a 401.
            if api_key is not None:
                return await self._validator(api_key, **self._kwargs)
            raise UnauthorizedError()

        self._call_fn = _call
        # FastAPI reads this signature to wire the header scheme dependency.
        self.__signature__ = inspect.signature(_call)

    async def extract(self, request: Request) -> str | None:
        """Return the header's value, treating a missing or empty header as ``None``."""
        raw = request.headers.get(self._name)
        return raw if raw else None

    async def authenticate(self, credential: str) -> Any:
        """Run the validator on *credential* and return the resulting identity."""
        validate = self._validator
        return await validate(credential, **self._kwargs)

    def require(self, **kwargs: Any) -> "APIKeyHeaderAuth":
        """Clone this source, layering *kwargs* over the existing validator kwargs."""
        merged_kwargs = self._kwargs | kwargs
        return APIKeyHeaderAuth(self._name, self._validator, **merged_kwargs)
@@ -0,0 +1,71 @@
1
+ """MultiAuth: combine multiple authentication sources into a single callable."""
2
+
3
+ import inspect
4
+ from typing import Any, cast
5
+
6
+ from fastapi import Request
7
+ from fastapi.security import SecurityScopes
8
+
9
+ from fastapi_toolsets.exceptions import UnauthorizedError
10
+
11
+ from ..abc import AuthSource
12
+
13
+
14
class MultiAuth:
    """Combine multiple authentication sources into a single callable.

    Sources are tried in order. The first source whose ``extract`` returns a
    credential (anything other than ``None``) authenticates the request, and
    its result or exception is final. If no source yields a credential,
    :class:`~fastapi_toolsets.exceptions.UnauthorizedError` is raised.

    Args:
        *sources: Auth source instances to try in order.
    """

    def __init__(self, *sources: AuthSource) -> None:
        self._sources = sources

        async def _call(
            request: Request,
            security_scopes: SecurityScopes,  # noqa: ARG001
            **kwargs: Any,  # noqa: ARG001 — absorbs scheme values injected by FastAPI
        ) -> Any:
            for source in self._sources:
                credential = await source.extract(request)
                if credential is not None:
                    return await source.authenticate(credential)
            raise UnauthorizedError()

        self._call_fn = _call

        # Build a merged signature that includes the security-scheme Depends()
        # parameters from every source so FastAPI registers them in OpenAPI docs.
        # "request"/"security_scopes" are supplied by the wrapper itself.
        reserved = frozenset({"request", "security_scopes"})
        merged: list[inspect.Parameter] = [
            inspect.Parameter(
                "request",
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=Request,
            ),
            inspect.Parameter(
                "security_scopes",
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=SecurityScopes,
            ),
        ]
        for i, source in enumerate(sources):
            for name, param in inspect.signature(source).parameters.items():
                if name in reserved:
                    continue
                # The "_s{i}_" prefix makes names unique per source, so two
                # sources exposing the same parameter name (e.g. two
                # APIKeyHeaderAuth -> "api_key") are both kept. KEYWORD_ONLY
                # avoids "non-default argument follows default argument"
                # errors when a source parameter has no default.
                merged.append(
                    param.replace(
                        name=f"_s{i}_{name}",
                        kind=inspect.Parameter.KEYWORD_ONLY,
                    )
                )
        self.__signature__ = inspect.Signature(merged, return_annotation=Any)

    async def __call__(self, **kwargs: Any) -> Any:
        return await self._call_fn(**kwargs)

    def require(self, **kwargs: Any) -> "MultiAuth":
        """Return a new :class:`MultiAuth` with kwargs forwarded to each source.

        Sources without a ``require`` method are reused unchanged.
        """
        new_sources = tuple(
            cast(Any, source).require(**kwargs)
            if hasattr(source, "require")
            else source
            for source in self._sources
        )
        return MultiAuth(*new_sources)