fastapi-toolsets 3.1.0__tar.gz → 4.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/PKG-INFO +7 -2
  2. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/README.md +2 -0
  3. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/pyproject.toml +11 -2
  4. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/__init__.py +1 -1
  5. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/crud/factory.py +1 -1
  6. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/db.py +26 -21
  7. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/fixtures/utils.py +68 -29
  8. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/schemas.py +1 -1
  9. fastapi_toolsets-4.0.0/src/fastapi_toolsets/security/__init__.py +26 -0
  10. fastapi_toolsets-4.0.0/src/fastapi_toolsets/security/abc.py +55 -0
  11. fastapi_toolsets-4.0.0/src/fastapi_toolsets/security/oauth.py +197 -0
  12. fastapi_toolsets-4.0.0/src/fastapi_toolsets/security/sources/__init__.py +8 -0
  13. fastapi_toolsets-4.0.0/src/fastapi_toolsets/security/sources/bearer.py +120 -0
  14. fastapi_toolsets-4.0.0/src/fastapi_toolsets/security/sources/cookie.py +148 -0
  15. fastapi_toolsets-4.0.0/src/fastapi_toolsets/security/sources/header.py +67 -0
  16. fastapi_toolsets-4.0.0/src/fastapi_toolsets/security/sources/multi.py +71 -0
  17. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/LICENSE +0 -0
  18. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/_imports.py +0 -0
  19. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/cli/__init__.py +0 -0
  20. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/cli/app.py +0 -0
  21. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/cli/commands/__init__.py +0 -0
  22. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/cli/commands/fixtures.py +0 -0
  23. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/cli/config.py +0 -0
  24. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/cli/pyproject.py +0 -0
  25. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/cli/utils.py +0 -0
  26. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/crud/__init__.py +0 -0
  27. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/crud/search.py +0 -0
  28. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/dependencies.py +0 -0
  29. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/exceptions/__init__.py +0 -0
  30. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/exceptions/exceptions.py +0 -0
  31. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/exceptions/handler.py +0 -0
  32. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/fixtures/__init__.py +0 -0
  33. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/fixtures/enum.py +0 -0
  34. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/fixtures/registry.py +0 -0
  35. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/logger.py +0 -0
  36. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/metrics/__init__.py +0 -0
  37. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/metrics/handler.py +0 -0
  38. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/metrics/registry.py +0 -0
  39. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/models/__init__.py +0 -0
  40. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/models/columns.py +0 -0
  41. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/models/watched.py +0 -0
  42. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/py.typed +0 -0
  43. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/pytest/__init__.py +0 -0
  44. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/pytest/plugin.py +0 -0
  45. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/pytest/utils.py +0 -0
  46. {fastapi_toolsets-3.1.0 → fastapi_toolsets-4.0.0}/src/fastapi_toolsets/types.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: fastapi-toolsets
3
- Version: 3.1.0
3
+ Version: 4.0.0
4
4
  Summary: Production-ready utilities for FastAPI applications
5
5
  Keywords: fastapi,sqlalchemy,postgresql
6
6
  Author: d3vyce
@@ -29,12 +29,14 @@ Requires-Dist: asyncpg>=0.29.0
29
29
  Requires-Dist: fastapi>=0.100.0
30
30
  Requires-Dist: pydantic>=2.0
31
31
  Requires-Dist: sqlalchemy[asyncio]>=2.0
32
- Requires-Dist: fastapi-toolsets[cli,metrics,pytest] ; extra == 'all'
32
+ Requires-Dist: fastapi-toolsets[cli,metrics,pytest,security] ; extra == 'all'
33
33
  Requires-Dist: typer>=0.9.0 ; extra == 'cli'
34
34
  Requires-Dist: prometheus-client>=0.20.0 ; extra == 'metrics'
35
35
  Requires-Dist: httpx>=0.25.0 ; extra == 'pytest'
36
36
  Requires-Dist: pytest-xdist>=3.0.0 ; extra == 'pytest'
37
37
  Requires-Dist: pytest>=8.0.0 ; extra == 'pytest'
38
+ Requires-Dist: async-lru>=1.0 ; extra == 'security'
39
+ Requires-Dist: httpx>=0.25.0 ; extra == 'security'
38
40
  Requires-Python: >=3.11
39
41
  Project-URL: Homepage, https://github.com/d3vyce/fastapi-toolsets
40
42
  Project-URL: Documentation, https://fastapi-toolsets.d3vyce.fr/
@@ -44,6 +46,7 @@ Provides-Extra: all
44
46
  Provides-Extra: cli
45
47
  Provides-Extra: metrics
46
48
  Provides-Extra: pytest
49
+ Provides-Extra: security
47
50
  Description-Content-Type: text/markdown
48
51
 
49
52
  # FastAPI Toolsets
@@ -79,6 +82,7 @@ Install only the extras you need:
79
82
  ```bash
80
83
  uv add "fastapi-toolsets[cli]"
81
84
  uv add "fastapi-toolsets[metrics]"
85
+ uv add "fastapi-toolsets[security]"
82
86
  uv add "fastapi-toolsets[pytest]"
83
87
  ```
84
88
 
@@ -104,6 +108,7 @@ uv add "fastapi-toolsets[all]"
104
108
 
105
109
  ### Optional
106
110
 
111
+ - **Security**: Composable authentication sources (`BearerTokenAuth`, `CookieAuth`, `APIKeyHeaderAuth`, `MultiAuth`) with HMAC-signed cookies and OAuth 2.0 / OIDC helpers
107
112
  - **CLI**: Django-like command-line interface with fixture management and custom commands support
108
113
  - **Metrics**: Prometheus metrics endpoint with provider/collector registry
109
114
  - **Pytest Helpers**: Async test client, database session management, `pytest-xdist` support, and table cleanup utilities
@@ -31,6 +31,7 @@ Install only the extras you need:
31
31
  ```bash
32
32
  uv add "fastapi-toolsets[cli]"
33
33
  uv add "fastapi-toolsets[metrics]"
34
+ uv add "fastapi-toolsets[security]"
34
35
  uv add "fastapi-toolsets[pytest]"
35
36
  ```
36
37
 
@@ -56,6 +57,7 @@ uv add "fastapi-toolsets[all]"
56
57
 
57
58
  ### Optional
58
59
 
60
+ - **Security**: Composable authentication sources (`BearerTokenAuth`, `CookieAuth`, `APIKeyHeaderAuth`, `MultiAuth`) with HMAC-signed cookies and OAuth 2.0 / OIDC helpers
59
61
  - **CLI**: Django-like command-line interface with fixture management and custom commands support
60
62
  - **Metrics**: Prometheus metrics endpoint with provider/collector registry
61
63
  - **Pytest Helpers**: Async test client, database session management, `pytest-xdist` support, and table cleanup utilities
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "fastapi-toolsets"
3
- version = "3.1.0"
3
+ version = "4.0.0"
4
4
  description = "Production-ready utilities for FastAPI applications"
5
5
  readme = "README.md"
6
6
  license = "MIT"
@@ -50,13 +50,17 @@ cli = [
50
50
  metrics = [
51
51
  "prometheus_client>=0.20.0",
52
52
  ]
53
+ security = [
54
+ "async-lru>=1.0",
55
+ "httpx>=0.25.0",
56
+ ]
53
57
  pytest = [
54
58
  "httpx>=0.25.0",
55
59
  "pytest-xdist>=3.0.0",
56
60
  "pytest>=8.0.0",
57
61
  ]
58
62
  all = [
59
- "fastapi-toolsets[cli,metrics,pytest]",
63
+ "fastapi-toolsets[cli,metrics,pytest,security]",
60
64
  ]
61
65
 
62
66
  [project.scripts]
@@ -66,12 +70,14 @@ manager = "fastapi_toolsets.cli.app:cli"
66
70
  dev = [
67
71
  {include-group = "tests"},
68
72
  {include-group = "docs"},
73
+ {include-group = "docs-src"},
69
74
  "fastapi-toolsets[all]",
70
75
  "prek>=0.3.8",
71
76
  "ruff>=0.1.0",
72
77
  "ty>=0.0.1a0",
73
78
  ]
74
79
  tests = [
80
+ "async-lru>=1.0",
75
81
  "coverage>=7.0.0",
76
82
  "httpx>=0.25.0",
77
83
  "pytest-anyio>=0.0.0",
@@ -84,6 +90,9 @@ docs = [
84
90
  "mkdocstrings-python>=2.0.2",
85
91
  "zensical>=0.0.30",
86
92
  ]
93
+ docs-src = [
94
+ "bcrypt>=4.0.0",
95
+ ]
87
96
 
88
97
  [build-system]
89
98
  requires = ["uv_build>=0.10,<0.12.0"]
@@ -21,4 +21,4 @@ Example usage:
21
21
  return Response(data={"user": user.username}, message="Success")
22
22
  """
23
23
 
24
- __version__ = "3.1.0"
24
+ __version__ = "4.0.0"
@@ -1334,7 +1334,7 @@ class AsyncCrud(Generic[ModelType]):
1334
1334
  count_q = count_q.where(and_(*filters))
1335
1335
 
1336
1336
  count_result = await session.execute(count_q)
1337
- total_count: int | None = count_result.scalar_one()
1337
+ total_count: int = count_result.scalar_one()
1338
1338
  has_more = page * items_per_page < total_count
1339
1339
  else:
1340
1340
  # Fetch one extra row to detect if a next page exists without COUNT
@@ -151,52 +151,57 @@ class LockMode(str, Enum):
151
151
  ACCESS_EXCLUSIVE = "ACCESS EXCLUSIVE"
152
152
 
153
153
 
154
def lock_tables(
    session_maker: async_sessionmaker[_SessionT],
    tables: list[type[DeclarativeBase]],
    *,
    mode: LockMode = LockMode.SHARE_UPDATE_EXCLUSIVE,
    timeout: str = "5s",
) -> AbstractAsyncContextManager[_SessionT]:
    """Lock PostgreSQL tables for the duration of a transaction.

    A dedicated session is created from *session_maker*. Inside its
    transaction the lock timeout is set (transaction-local) and the tables
    are locked before the session is yielded. When the ``async with`` body
    completes, the transaction is committed, which releases the locks. If the
    body raises (or the commit fails), the transaction is rolled back and the
    exception propagates.

    Args:
        session_maker: Async session factory used to create the dedicated
            session.
        tables: List of SQLAlchemy model classes to lock. Table names are
            quoted when required and schema-qualified, so reserved names such
            as ``user`` produce valid SQL.
        mode: Lock mode (default: SHARE UPDATE EXCLUSIVE).
        timeout: Lock timeout (default: "5s"), in any format PostgreSQL
            accepts for ``lock_timeout``. It is sent as a bound parameter,
            never interpolated into the SQL text.

    Yields:
        The dedicated session, open within the locked transaction.

    Raises:
        SQLAlchemyError: If the lock cannot be acquired within *timeout*.

    Example:
        ```python
        from fastapi_toolsets.db import lock_tables, LockMode

        async with lock_tables(session_maker, [User, Account]) as session:
            # Tables are locked; changes are committed when the context exits.
            user = await UserCrud.get(session, [User.id == 1])
            user.balance += 100

        # With custom lock mode
        async with lock_tables(session_maker, [Order], mode=LockMode.EXCLUSIVE) as session:
            await process_order(session, order_id)
        ```
    """
    from sqlalchemy.dialects import postgresql

    # Quote with PostgreSQL's identifier rules (reserved words, mixed case,
    # schema qualification) instead of splicing raw ``__tablename__`` values.
    preparer = postgresql.dialect().identifier_preparer
    table_names = ",".join(preparer.format_table(table.__table__) for table in tables)

    @asynccontextmanager
    async def _lock() -> AsyncGenerator[_SessionT, None]:
        async with session_maker() as session:
            try:
                # ``set_config(name, value, true)`` is the transaction-local
                # equivalent of ``SET LOCAL`` but accepts a bound parameter.
                await session.execute(
                    text("SELECT set_config('lock_timeout', :timeout, true)"),
                    {"timeout": timeout},
                )
                await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
                yield session
                await session.commit()
            except BaseException:
                await session.rollback()
                raise

    return _lock()
200
205
 
201
206
 
202
207
  async def create_database(
@@ -40,6 +40,32 @@ def _instance_to_dict(instance: DeclarativeBase) -> dict[str, Any]:
40
40
  return result
41
41
 
42
42
 
43
def _get_table_chain(model_cls: type[DeclarativeBase]) -> list[type[DeclarativeBase]]:
    """Return the table-owning classes on *model_cls*'s mapper inheritance chain.

    The list is ordered from the root mapped class down towards *model_cls*.
    A class whose ``__table__`` is the same object as an ancestor's (single-table
    inheritance) is dropped, keeping only the first class seen for each table.
    A model with no mapped parent yields ``[model_cls]``.
    """
    lineage: list[type[DeclarativeBase]] = []
    mapper = sa_inspect(model_cls)
    while mapper is not None:
        # Walk upwards, prepending so the root ends up first.
        lineage.insert(0, mapper.class_)
        mapper = mapper.inherits

    table_ids: set[int] = set()
    distinct: list[type[DeclarativeBase]] = []
    for klass in lineage:
        key = id(klass.__table__)
        if key not in table_ids:  # pragma: no branch
            table_ids.add(key)
            distinct.append(klass)
    return distinct
59
+
60
+
61
def _instance_to_dict_for_cls(
    instance: DeclarativeBase, cls: type[DeclarativeBase]
) -> dict[str, Any]:
    """Serialize *instance* with ``_instance_to_dict``, keeping only *cls*'s own table columns.

    Keys are filtered by the ``key`` of each column in ``cls.__table__``.
    """
    table_keys = frozenset(column.key for column in cls.__table__.columns)
    full = _instance_to_dict(instance)
    return {key: value for key, value in full.items() if key in table_keys}
67
+
68
+
43
69
  def _group_by_type(
44
70
  instances: list[DeclarativeBase],
45
71
  ) -> list[tuple[type[DeclarativeBase], list[DeclarativeBase]]]:
@@ -73,9 +99,11 @@ async def _batch_insert(
73
99
  instances: list[DeclarativeBase],
74
100
  ) -> None:
75
101
  """INSERT all instances — raises on conflict (no duplicate handling)."""
76
- dicts = [_instance_to_dict(i) for i in instances]
77
- for group_dicts, _ in _group_by_column_set(dicts, instances):
78
- await session.execute(pg_insert(model_cls).values(group_dicts))
102
+ for cls in _get_table_chain(model_cls):
103
+ dicts = [_instance_to_dict_for_cls(i, cls) for i in instances]
104
+ for group_dicts, _ in _group_by_column_set(dicts, instances):
105
+ if group_dicts and group_dicts[0]: # pragma: no branch
106
+ await session.execute(pg_insert(cls).values(group_dicts))
79
107
 
80
108
 
81
109
  async def _batch_merge(
@@ -84,31 +112,30 @@ async def _batch_merge(
84
112
  instances: list[DeclarativeBase],
85
113
  ) -> None:
86
114
  """UPSERT: insert new rows, update existing ones with the provided values."""
87
- mapper = model_cls.__mapper__
88
- pk_names = [col.name for col in mapper.primary_key]
89
- pk_names_set = set(pk_names)
90
- non_pk_cols = [
91
- prop.key
92
- for prop in mapper.column_attrs
93
- if not any(col.name in pk_names_set for col in prop.columns)
94
- ]
95
-
96
- dicts = [_instance_to_dict(i) for i in instances]
97
- for group_dicts, _ in _group_by_column_set(dicts, instances):
98
- stmt = pg_insert(model_cls).values(group_dicts)
99
-
100
- inserted_keys = set(group_dicts[0])
101
- update_cols = [col for col in non_pk_cols if col in inserted_keys]
102
-
103
- if update_cols:
104
- stmt = stmt.on_conflict_do_update(
105
- index_elements=pk_names,
106
- set_={col: stmt.excluded[col] for col in update_cols},
107
- )
108
- else:
109
- stmt = stmt.on_conflict_do_nothing(index_elements=pk_names)
115
+ for cls in _get_table_chain(model_cls):
116
+ pk_names = [col.name for col in cls.__table__.primary_key]
117
+ pk_names_set = set(pk_names)
118
+ own_col_keys = {col.key for col in cls.__table__.columns}
119
+ non_pk_cols = [k for k in own_col_keys if k not in pk_names_set]
120
+
121
+ dicts = [_instance_to_dict_for_cls(i, cls) for i in instances]
122
+ for group_dicts, _ in _group_by_column_set(dicts, instances):
123
+ if not group_dicts or not group_dicts[0]: # pragma: no cover
124
+ continue
125
+ stmt = pg_insert(cls).values(group_dicts)
110
126
 
111
- await session.execute(stmt)
127
+ inserted_keys = set(group_dicts[0])
128
+ update_cols = [col for col in non_pk_cols if col in inserted_keys]
129
+
130
+ if update_cols:
131
+ stmt = stmt.on_conflict_do_update(
132
+ index_elements=pk_names,
133
+ set_={col: stmt.excluded[col] for col in update_cols},
134
+ )
135
+ else:
136
+ stmt = stmt.on_conflict_do_nothing(index_elements=pk_names)
137
+
138
+ await session.execute(stmt)
112
139
 
113
140
 
114
141
  async def _batch_skip_existing(
@@ -117,6 +144,16 @@ async def _batch_skip_existing(
117
144
  instances: list[DeclarativeBase],
118
145
  ) -> list[DeclarativeBase]:
119
146
  """INSERT only rows that do not already exist; return the inserted ones."""
147
+ if len(_get_table_chain(model_cls)) > 1:
148
+ loaded: list[DeclarativeBase] = []
149
+ for inst in instances:
150
+ pk = _get_primary_key(inst)
151
+ if pk is None or not await session.get(model_cls, pk):
152
+ session.add(inst)
153
+ loaded.append(inst)
154
+ await session.flush()
155
+ return loaded
156
+
120
157
  mapper = model_cls.__mapper__
121
158
  pk_names = [col.name for col in mapper.primary_key]
122
159
 
@@ -129,7 +166,7 @@ async def _batch_skip_existing(
129
166
  else:
130
167
  with_pk_pairs.append((inst, pk))
131
168
 
132
- loaded: list[DeclarativeBase] = list(no_pk)
169
+ loaded = list(no_pk)
133
170
  if no_pk:
134
171
  no_pk_dicts = [_instance_to_dict(i) for i in no_pk]
135
172
  for group_dicts, _ in _group_by_column_set(no_pk_dicts, no_pk):
@@ -179,7 +216,7 @@ async def _load_ordered(
179
216
  if contexts is not None and not variants:
180
217
  variants = registry.get_variants(name)
181
218
 
182
- if not variants:
219
+ if not variants: # pragma: no cover
183
220
  results[name] = []
184
221
  continue
185
222
 
@@ -204,6 +241,8 @@ async def _load_ordered(
204
241
  case LoadStrategy.SKIP_EXISTING:
205
242
  inserted = await _batch_skip_existing(session, model_cls, group)
206
243
  loaded.extend(inserted)
244
+ case _: # pragma: no cover
245
+ pass
207
246
 
208
247
  results[name] = loaded
209
248
  logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
@@ -179,7 +179,7 @@ class PaginatedResponse(BaseResponse, Generic[DataT]):
179
179
  ]
180
180
  cls._discriminated_union_cache[item] = cached
181
181
  return cached # ty:ignore[invalid-return-type]
182
- return super().__class_getitem__(item)
182
+ return super().__class_getitem__(item) # ty:ignore[invalid-return-type]
183
183
 
184
184
 
185
185
  class OffsetPaginatedResponse(PaginatedResponse[DataT]):
@@ -0,0 +1,26 @@
1
+ """Authentication helpers for FastAPI using Security()."""
2
+
3
+ from .abc import AuthSource
4
+ from .oauth import (
5
+ oauth_build_authorization_redirect,
6
+ oauth_decode_state,
7
+ oauth_encode_state,
8
+ oauth_fetch_userinfo,
9
+ oauth_generate_state_token,
10
+ oauth_resolve_provider_urls,
11
+ )
12
+ from .sources import APIKeyHeaderAuth, BearerTokenAuth, CookieAuth, MultiAuth
13
+
14
+ __all__ = [
15
+ "APIKeyHeaderAuth",
16
+ "AuthSource",
17
+ "BearerTokenAuth",
18
+ "CookieAuth",
19
+ "MultiAuth",
20
+ "oauth_build_authorization_redirect",
21
+ "oauth_decode_state",
22
+ "oauth_encode_state",
23
+ "oauth_fetch_userinfo",
24
+ "oauth_generate_state_token",
25
+ "oauth_resolve_provider_urls",
26
+ ]
@@ -0,0 +1,55 @@
1
+ """Abstract base class for authentication sources."""
2
+
3
+ import functools
4
+ import inspect
5
+ from abc import ABC, abstractmethod
6
+ from typing import Any, Callable
7
+
8
+ from fastapi import Request
9
+ from fastapi.security import SecurityScopes
10
+
11
+ from fastapi_toolsets.exceptions import UnauthorizedError
12
+
13
+
14
+ def _ensure_async(fn: Callable[..., Any]) -> Callable[..., Any]:
15
+ """Wrap *fn* so it can always be awaited, caching the coroutine check at init time."""
16
+ if inspect.iscoroutinefunction(fn):
17
+ return fn
18
+
19
+ @functools.wraps(fn)
20
+ async def wrapper(*args: Any, **kwargs: Any) -> Any:
21
+ return fn(*args, **kwargs)
22
+
23
+ return wrapper
24
+
25
+
26
class AuthSource(ABC):
    """Abstract base class for authentication sources.

    An instance is used directly as a FastAPI dependency. ``__call__``
    forwards its keyword arguments to ``self._call_fn``, and
    ``self.__signature__`` advertises that callable's parameters. Subclasses
    implement :meth:`extract` and :meth:`authenticate`, and may install their
    own ``_call_fn`` / ``__signature__`` pair.
    """

    def __init__(self) -> None:
        """Install the default dependency callable and publish its signature.

        The default callable pulls a credential out of the request with
        :meth:`extract` and raises :class:`UnauthorizedError` if there is
        none. Otherwise it returns the result of :meth:`authenticate`.
        """
        auth_source = self

        async def _call(
            request: Request,
            security_scopes: SecurityScopes,  # noqa: ARG001
        ) -> Any:
            raw_credential = await auth_source.extract(request)
            if raw_credential is not None:
                return await auth_source.authenticate(raw_credential)
            raise UnauthorizedError()

        self._call_fn: Callable[..., Any] = _call
        self.__signature__ = inspect.signature(_call)

    @abstractmethod
    async def extract(self, request: Request) -> str | None:
        """Return the raw credential carried by *request* (not validated), or ``None``."""

    @abstractmethod
    async def authenticate(self, credential: str) -> Any:
        """Validate *credential* and return the identity it belongs to."""

    async def __call__(self, **kwargs: Any) -> Any:
        """Dispatch a FastAPI dependency call to the installed callable."""
        return await self._call_fn(**kwargs)
@@ -0,0 +1,197 @@
1
+ """OAuth 2.0 / OIDC helper utilities."""
2
+
3
+ import base64
4
+ import binascii
5
+ import hmac
6
+ import json
7
+ import secrets
8
+ from typing import Any
9
+ from urllib.parse import urlencode
10
+
11
+ import httpx
12
+ from async_lru import alru_cache
13
+ from fastapi.responses import RedirectResponse
14
+
15
+
16
@alru_cache(maxsize=32)
async def oauth_resolve_provider_urls(
    discovery_url: str,
) -> tuple[str, str, str | None]:
    """Resolve a provider's OAuth endpoints from its OIDC discovery document.

    Results are memoised per ``discovery_url`` by ``alru_cache`` (at most 32
    entries), so repeated calls with the same URL are served from the cache
    until evicted.

    Args:
        discovery_url: URL of the provider's ``/.well-known/openid-configuration``.

    Returns:
        A ``(authorization_url, token_url, userinfo_url)`` tuple.
        *userinfo_url* is ``None`` when the provider does not advertise one.

    Raises:
        httpx.HTTPStatusError: If the discovery endpoint returns an error status.
        KeyError: If the document lacks ``authorization_endpoint`` or
            ``token_endpoint``.
    """
    async with httpx.AsyncClient() as client:
        response = await client.get(discovery_url)
        response.raise_for_status()
        document = response.json()
        authorization_endpoint = document["authorization_endpoint"]
        token_endpoint = document["token_endpoint"]
        userinfo_endpoint = document.get("userinfo_endpoint")
        return authorization_endpoint, token_endpoint, userinfo_endpoint
38
+
39
+
40
async def oauth_fetch_userinfo(
    *,
    token_url: str,
    userinfo_url: str,
    code: str,
    client_id: str,
    client_secret: str,
    redirect_uri: str,
    required_scopes: str | None = None,
) -> dict[str, Any]:
    """Exchange an authorization code for tokens and return the userinfo payload.

    Args:
        token_url: Provider's token endpoint.
        userinfo_url: Provider's userinfo endpoint.
        code: Authorization code received from the provider's callback.
        client_id: OAuth application client ID.
        client_secret: OAuth application client secret.
        redirect_uri: Redirect URI that was used in the authorization request.
        required_scopes: Space-separated scopes that must be present in the token
            response ``scope`` field (RFC 6749 §3.3). The granted ``scope`` may
            be space- or comma-separated. If ``scope`` is missing or the
            provider granted fewer scopes than required, ``ValueError`` is
            raised.

    Returns:
        The JSON payload returned by the userinfo endpoint as a plain ``dict``.

    Raises:
        ValueError: If the token response has no ``access_token`` (e.g. an
            ``error`` payload returned with HTTP 200), the provider granted a
            token type other than ``bearer``, or did not grant all
            ``required_scopes``.
        httpx.HTTPStatusError: If the token or userinfo endpoint responds with
            an error status.
    """
    async with httpx.AsyncClient() as client:
        token_resp = await client.post(
            token_url,
            data={
                "grant_type": "authorization_code",
                "code": code,
                "client_id": client_id,
                "client_secret": client_secret,
                "redirect_uri": redirect_uri,
            },
            headers={"Accept": "application/json"},
        )
        token_resp.raise_for_status()
        token_data = token_resp.json()

        # Some providers answer a failed exchange with HTTP 200 and an
        # ``error`` field; report that instead of a KeyError further down.
        access_token = token_data.get("access_token")
        if not access_token:
            raise ValueError(
                f"token response has no access_token: {token_data.get('error')!r}"
            )

        # Treat a missing *or null* token_type as bearer; compare case-insensitively.
        token_type = token_data.get("token_type") or "bearer"
        if str(token_type).lower() != "bearer":
            raise ValueError(
                f"unsupported token_type: {token_data.get('token_type')!r}"
            )

        if required_scopes is not None:
            # ``or ""`` also covers an explicit null; commas are accepted as
            # separators because some providers return comma-joined scopes.
            granted_raw = token_data.get("scope") or ""
            granted = set(str(granted_raw).replace(",", " ").split())
            missing = set(required_scopes.split()) - granted
            if missing:
                raise ValueError(f"provider did not grant required scopes: {missing}")

        userinfo_resp = await client.get(
            userinfo_url,
            headers={"Authorization": f"Bearer {access_token}"},
        )
        userinfo_resp.raise_for_status()
        return userinfo_resp.json()
104
+
105
+
106
def oauth_generate_state_token() -> str:
    """Return a fresh CSRF token for the OAuth ``state`` parameter.

    The token is URL-safe base64 text built from 32 random bytes supplied by
    the :mod:`secrets` module.
    """
    random_bytes_count = 32
    return secrets.token_urlsafe(nbytes=random_bytes_count)
109
+
110
+
111
def oauth_build_authorization_redirect(
    authorization_url: str,
    *,
    client_id: str,
    scopes: str,
    redirect_uri: str,
    destination: str,
    state_token: str,
) -> RedirectResponse:
    """Return an OAuth 2.0 authorization ``RedirectResponse``.

    The authorization request parameters are appended to *authorization_url*.
    Any query component already present on the endpoint is kept, as RFC 6749
    §3.1 requires.

    Args:
        authorization_url: Provider's authorization endpoint.
        client_id: OAuth application client ID.
        scopes: Space-separated list of requested scopes.
        redirect_uri: URI the provider should redirect back to after authorization.
        destination: URL the user should be sent to after the full OAuth flow
            completes (embedded in ``state``).
        state_token: CSRF token generated by :func:`oauth_generate_state_token`.
            Must be stored server-side (session or signed cookie) and verified via
            :func:`oauth_decode_state` on the callback endpoint (RFC 6749 §10.12).

    Returns:
        A :class:`~fastapi.responses.RedirectResponse` to the provider's
        authorization page.
    """
    params = urlencode(
        {
            "client_id": client_id,
            "response_type": "code",
            "scope": scopes,
            "redirect_uri": redirect_uri,
            "state": oauth_encode_state(destination, state_token),
        }
    )
    if "?" not in authorization_url:
        separator = "?"
    elif authorization_url.endswith(("?", "&")):
        # The existing query is empty or already terminated; append directly.
        separator = ""
    else:
        # Keep the provider's own query parameters (e.g. tenant or policy).
        separator = "&"
    return RedirectResponse(f"{authorization_url}{separator}{params}")
147
+
148
+
149
def oauth_encode_state(url: str, state_token: str) -> str:
    """Pack a post-login destination and CSRF token into an OAuth ``state`` value.

    The pair is serialised as compact JSON, ``{"n": <token>, "d": <url>}``,
    then encoded with URL-safe base64 (padding included).

    Args:
        url: Post-login destination URL.
        state_token: CSRF token from :func:`oauth_generate_state_token`.

    Returns:
        The encoded ``state`` string.
    """
    document = {"n": state_token, "d": url}
    compact_json = json.dumps(document, separators=(",", ":"))
    encoded = base64.urlsafe_b64encode(compact_json.encode())
    return encoded.decode()
158
+
159
+
160
+ def oauth_decode_state(
161
+ state: str | None, *, expected_state_token: str, fallback: str
162
+ ) -> str:
163
+ """Decode and CSRF-verify an OAuth ``state`` parameter.
164
+
165
+ Uses a constant-time comparison for the CSRF token to prevent timing attacks.
166
+
167
+ Args:
168
+ state: Raw ``state`` query parameter from the provider's callback.
169
+ expected_state_token: The token stored before the authorization redirect.
170
+ If it does not match the decoded value, ``fallback`` is returned.
171
+ fallback: URL to return when ``state`` is absent, malformed, or fails
172
+ CSRF verification.
173
+
174
+ Returns:
175
+ The destination URL embedded in ``state``, or ``fallback``.
176
+
177
+ Important:
178
+ **Single-use**: delete the stored token from the session immediately
179
+ after calling this function — whether it matched or not — so that a
180
+ captured callback URL cannot be replayed.
181
+
182
+ **Open-redirect**: validate the returned URL against a known-good
183
+ origin or relative-path allowlist before issuing the final redirect.
184
+ Do not forward arbitrary URLs to ``RedirectResponse``.
185
+ """
186
+ if not state or state == "null": # "null" guards against JS JSON.stringify(null)
187
+ return fallback
188
+ try:
189
+ padded = state + "=" * (-len(state) % 4)
190
+ payload = json.loads(base64.urlsafe_b64decode(padded).decode("utf-8"))
191
+ if not isinstance(payload, dict) or not hmac.compare_digest(
192
+ payload.get("n", "").encode(), expected_state_token.encode()
193
+ ):
194
+ return fallback
195
+ return str(payload["d"])
196
+ except (UnicodeDecodeError, ValueError, binascii.Error, KeyError):
197
+ return fallback
@@ -0,0 +1,8 @@
1
+ """Built-in authentication source implementations."""
2
+
3
+ from .header import APIKeyHeaderAuth
4
+ from .bearer import BearerTokenAuth
5
+ from .cookie import CookieAuth
6
+ from .multi import MultiAuth
7
+
8
+ __all__ = ["APIKeyHeaderAuth", "BearerTokenAuth", "CookieAuth", "MultiAuth"]
@@ -0,0 +1,120 @@
1
+ """Bearer token authentication source."""
2
+
3
+ import inspect
4
+ import secrets
5
+ from typing import Annotated, Any, Callable
6
+
7
+ from fastapi import Depends, Request
8
+ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer, SecurityScopes
9
+
10
+ from fastapi_toolsets.exceptions import UnauthorizedError
11
+
12
+ from ..abc import AuthSource, _ensure_async
13
+
14
+
15
class BearerTokenAuth(AuthSource):
    """Bearer token authentication source.

    Wraps :class:`fastapi.security.HTTPBearer` for OpenAPI documentation.
    The validator is called as ``await validator(credential, **kwargs)``
    where ``kwargs`` are the extra keyword arguments provided at instantiation.

    Args:
        validator: Sync or async callable that receives the credential and any
            extra keyword arguments, and returns the authenticated identity
            (e.g. a ``User`` model). Should raise
            :class:`~fastapi_toolsets.exceptions.UnauthorizedError` on failure.
        prefix: Optional token prefix (e.g. ``"user_"``). If set, only tokens
            whose value starts with this prefix are matched. The prefix is
            **kept** in the value passed to the validator — store and compare
            tokens with their prefix included. Use :meth:`generate_token` to
            create correctly-prefixed tokens. This enables multiple
            ``BearerTokenAuth`` instances in the same app (e.g. ``"user_"``
            for user tokens, ``"org_"`` for org tokens).
        **kwargs: Extra keyword arguments forwarded to the validator on every
            call (e.g. ``role=Role.ADMIN``).
    """

    def __init__(
        self,
        validator: Callable[..., Any],
        *,
        prefix: str | None = None,
        **kwargs: Any,
    ) -> None:
        self._validator = _ensure_async(validator)
        self._prefix = prefix
        self._kwargs = kwargs
        # auto_error=False: a missing header reaches ``_call`` as ``None`` and is
        # reported with the project's own UnauthorizedError.
        self._scheme = HTTPBearer(auto_error=False)

        async def _call(
            security_scopes: SecurityScopes,  # noqa: ARG001
            credentials: Annotated[
                HTTPAuthorizationCredentials | None, Depends(self._scheme)
            ] = None,
        ) -> Any:
            if credentials is None:
                raise UnauthorizedError()
            return await self._validate(credentials.credentials)

        # Replace the base-class dependency so FastAPI resolves the header via
        # HTTPBearer (and documents the scheme in OpenAPI).
        self._call_fn = _call
        self.__signature__ = inspect.signature(_call)

    async def _validate(self, token: str) -> Any:
        """Check prefix and call the validator."""
        if self._prefix is not None and not token.startswith(self._prefix):
            raise UnauthorizedError()
        return await self._validator(token, **self._kwargs)

    async def extract(self, request: Request) -> str | None:
        """Extract the raw credential from the request without validating.

        Returns ``None`` in any of these cases:

        - no ``Authorization`` header uses the ``Bearer`` scheme (the scheme
          is matched case-insensitively);
        - the token is empty;
        - the token does not match the configured prefix.

        The prefix is included in the returned value.
        """
        scheme, _, token = request.headers.get("Authorization", "").partition(" ")
        # Auth schemes are case-insensitive (RFC 7235 §2.1), matching how the
        # HTTPBearer-based dependency path parses the same header.
        if scheme.lower() != "bearer" or not token:
            return None
        if self._prefix is not None and not token.startswith(self._prefix):
            return None
        return token

    async def authenticate(self, credential: str) -> Any:
        """Validate a credential and return the identity.

        Calls ``await validator(credential, **kwargs)`` where ``kwargs`` are
        the extra keyword arguments provided at instantiation.
        """
        return await self._validate(credential)

    def require(self, **kwargs: Any) -> "BearerTokenAuth":
        """Return a new instance with additional (or overriding) validator kwargs."""
        return BearerTokenAuth(
            self._validator,
            prefix=self._prefix,
            **{**self._kwargs, **kwargs},
        )

    def generate_token(self, nbytes: int = 32) -> str:
        """Generate a secure random token for this auth source.

        Returns a URL-safe random token. If a prefix is configured, it is
        prepended — the returned value is what you store in your database
        and return to the client as-is.

        Args:
            nbytes: Number of random bytes before base64 encoding. The
                resulting string is ``ceil(nbytes * 4 / 3)`` characters
                (43 chars for the default 32 bytes). Defaults to 32.

        Returns:
            A ready-to-use token string (e.g. ``"user_Xk3..."``).
        """
        token = secrets.token_urlsafe(nbytes)
        if self._prefix is not None:
            return f"{self._prefix}{token}"
        return token
@@ -0,0 +1,148 @@
1
+ """Cookie-based authentication source."""
2
+
3
+ import base64
4
+ import hashlib
5
+ import hmac
6
+ import inspect
7
+ import json
8
+ import time
9
+ from typing import Annotated, Any, Callable
10
+
11
+ from fastapi import Depends, Request, Response
12
+ from fastapi.security import APIKeyCookie, SecurityScopes
13
+
14
+ from fastapi_toolsets.exceptions import UnauthorizedError
15
+
16
+ from ..abc import AuthSource, _ensure_async
17
+
18
+
19
class CookieAuth(AuthSource):
    """Cookie-based authentication source.

    Wraps :class:`fastapi.security.APIKeyCookie` for OpenAPI documentation.
    Optionally signs the cookie with HMAC-SHA256 to provide stateless, tamper-
    proof sessions without any database entry.

    Args:
        name: Cookie name.
        validator: Sync or async callable that receives the cookie value
            (plain, after signature verification when ``secret_key`` is set)
            and any extra keyword arguments, and returns the authenticated
            identity.
        secret_key: When provided (non-empty), the cookie is HMAC-SHA256
            signed. :meth:`set_cookie` embeds an expiry and signs the payload.
            :meth:`authenticate` and the dependency call verify the signature
            and expiry before handing the plain value to the validator.
            :meth:`extract` only reads the raw cookie. When ``None``
            (default), the raw cookie value is passed to the validator as-is.
        ttl: Cookie lifetime in seconds (default 24 h). Always sent as the
            cookie's ``max_age`` by :meth:`set_cookie`. When ``secret_key``
            is set, it is also embedded as the signed expiry.
        secure: Set the ``Secure`` flag on the cookie so it is only transmitted
            over HTTPS (default ``True``). Set to ``False`` only in local
            development environments where HTTPS is unavailable.
        **kwargs: Extra keyword arguments forwarded to the validator on every
            call (e.g. ``role=Role.ADMIN``).
    """

    def __init__(
        self,
        name: str,
        validator: Callable[..., Any],
        *,
        secret_key: str | None = None,
        ttl: int = 86400,
        secure: bool = True,
        **kwargs: Any,
    ) -> None:
        self._name = name
        self._validator = _ensure_async(validator)
        self._secret_key = secret_key
        self._ttl = ttl
        self._secure = secure
        self._kwargs = kwargs
        # auto_error=False: a missing cookie resolves to None, and _call then
        # raises the project's UnauthorizedError itself.
        self._scheme = APIKeyCookie(name=name, auto_error=False)

        async def _call(
            security_scopes: SecurityScopes,  # noqa: ARG001
            value: Annotated[str | None, Depends(self._scheme)] = None,
        ) -> Any:
            if value is None:
                raise UnauthorizedError()
            plain = self._verify(value)
            return await self._validator(plain, **self._kwargs)

        self._call_fn = _call
        # Publish _call's parameters on the instance so signature-based
        # introspection (e.g. MultiAuth's merged signature) sees the cookie
        # scheme dependency.
        self.__signature__ = inspect.signature(_call)

    def _hmac(self, data: str) -> str:
        """Return the hex HMAC-SHA256 of *data* keyed with ``secret_key``."""
        if self._secret_key is None:
            raise RuntimeError("_hmac called without secret_key configured")
        return hmac.new(
            self._secret_key.encode(), data.encode(), hashlib.sha256
        ).hexdigest()

    def _sign(self, value: str) -> str:
        """Wrap *value* with an expiry ``ttl`` seconds from now and sign it.

        Format: ``<urlsafe-base64(json {"v": value, "exp": ts})>.<hex hmac>``.
        """
        data = base64.urlsafe_b64encode(
            json.dumps({"v": value, "exp": int(time.time()) + self._ttl}).encode()
        ).decode()
        return f"{data}.{self._hmac(data)}"

    def _verify(self, cookie_value: str) -> str:
        """Return the plain value, verifying HMAC + expiry when signed.

        Without a (non-empty) ``secret_key`` the cookie value is returned
        unchanged.

        Raises:
            UnauthorizedError: If the signed cookie is malformed, contains
                non-ASCII characters, has a bad signature, has an
                undecodable payload, or has expired.
        """
        if not self._secret_key:
            return cookie_value

        try:
            data, sig = cookie_value.rsplit(".", 1)
        except ValueError:
            raise UnauthorizedError() from None

        # The cookie is client-controlled, and hmac.compare_digest raises
        # TypeError for non-ASCII str input, which would turn into a 500.
        # A genuine cookie is pure ASCII (base64 payload + hex digest), so
        # anything else is rejected before comparing.
        if not (data.isascii() and sig.isascii()):
            raise UnauthorizedError()
        if not hmac.compare_digest(self._hmac(data), sig):
            raise UnauthorizedError()

        try:
            payload = json.loads(base64.urlsafe_b64decode(data))
            value: str = payload["v"]
            exp: int = payload["exp"]
        except Exception:
            # Any decoding/shape problem is an authentication failure.
            raise UnauthorizedError() from None

        if exp < int(time.time()):
            raise UnauthorizedError()

        return value

    async def extract(self, request: Request) -> str | None:
        """Return the raw (unverified) cookie value, or ``None`` if missing or empty.

        An empty cookie is treated as absent. This matches ``APIKeyCookie``
        and the other sources, so MultiAuth moves on to the next source.
        """
        return request.cookies.get(self._name) or None

    async def authenticate(self, credential: str) -> Any:
        """Verify *credential* (when signing is enabled) and return the identity."""
        plain = self._verify(credential)
        return await self._validator(plain, **self._kwargs)

    def require(self, **kwargs: Any) -> "CookieAuth":
        """Return a new instance with additional (or overriding) validator kwargs."""
        return CookieAuth(
            self._name,
            self._validator,
            secret_key=self._secret_key,
            ttl=self._ttl,
            secure=self._secure,
            **{**self._kwargs, **kwargs},
        )

    def set_cookie(self, response: Response, value: str) -> None:
        """Attach the cookie to *response*, signing it when ``secret_key`` is set."""
        cookie_value = self._sign(value) if self._secret_key else value
        response.set_cookie(
            self._name,
            cookie_value,
            httponly=True,
            samesite="lax",
            secure=self._secure,
            max_age=self._ttl,
        )

    def delete_cookie(self, response: Response) -> None:
        """Clear the session cookie (logout)."""
        response.delete_cookie(
            self._name, httponly=True, samesite="lax", secure=self._secure
        )
@@ -0,0 +1,67 @@
1
+ """API key header authentication source."""
2
+
3
+ import inspect
4
+ from typing import Annotated, Any, Callable
5
+
6
+ from fastapi import Depends, Request
7
+ from fastapi.security import APIKeyHeader, SecurityScopes
8
+
9
+ from fastapi_toolsets.exceptions import UnauthorizedError
10
+
11
+ from ..abc import AuthSource, _ensure_async
12
+
13
+
14
class APIKeyHeaderAuth(AuthSource):
    """Authenticate requests by an API key carried in an HTTP header.

    The header is declared through :class:`fastapi.security.APIKeyHeader`, so
    it shows up in the OpenAPI schema. Each key is checked with
    ``await validator(api_key, **kwargs)``. Here ``kwargs`` are the extra
    keyword arguments captured at construction time.

    Args:
        name: Name of the header holding the key (for example ``"X-API-Key"``).
        validator: Sync or async callable receiving the key plus the extra
            keyword arguments. It returns the authenticated identity and
            should raise :class:`~fastapi_toolsets.exceptions.UnauthorizedError`
            when the key is rejected.
        **kwargs: Extra keyword arguments passed to the validator on every
            call (for example ``role=Role.ADMIN``).
    """

    def __init__(
        self,
        name: str,
        validator: Callable[..., Any],
        **kwargs: Any,
    ) -> None:
        self._name = name
        self._validator = _ensure_async(validator)
        self._kwargs = kwargs
        # auto_error=False: a missing header resolves to None, which the
        # dependency reports as UnauthorizedError.
        scheme = APIKeyHeader(name=name, auto_error=False)
        self._scheme = scheme

        async def _call(
            security_scopes: SecurityScopes,  # noqa: ARG001
            api_key: Annotated[str | None, Depends(scheme)] = None,
        ) -> Any:
            if api_key is not None:
                return await self._validator(api_key, **self._kwargs)
            raise UnauthorizedError()

        self._call_fn = _call
        self.__signature__ = inspect.signature(_call)

    async def extract(self, request: Request) -> str | None:
        """Return the key found in the configured header; missing or empty gives ``None``."""
        header_value = request.headers.get(self._name)
        if not header_value:
            return None
        return header_value

    async def authenticate(self, credential: str) -> Any:
        """Run the validator on *credential* and return the resulting identity."""
        extra = self._kwargs
        return await self._validator(credential, **extra)

    def require(self, **kwargs: Any) -> "APIKeyHeaderAuth":
        """Build a copy of this source whose validator gets extra or overriding kwargs."""
        merged_kwargs = dict(self._kwargs, **kwargs)
        return APIKeyHeaderAuth(self._name, self._validator, **merged_kwargs)
@@ -0,0 +1,71 @@
1
+ """MultiAuth: combine multiple authentication sources into a single callable."""
2
+
3
+ import inspect
4
+ from typing import Any, cast
5
+
6
+ from fastapi import Request
7
+ from fastapi.security import SecurityScopes
8
+
9
+ from fastapi_toolsets.exceptions import UnauthorizedError
10
+
11
+ from ..abc import AuthSource
12
+
13
+
14
class MultiAuth:
    """Combine multiple authentication sources into a single callable.

    On each call the sources are tried in order. The first one whose
    ``extract`` returns a credential authenticates it, and that result is
    returned. An exception raised by that source's ``authenticate``
    propagates, and later sources are not consulted. If no source yields a
    credential, :class:`~fastapi_toolsets.exceptions.UnauthorizedError` is
    raised.

    Args:
        *sources: Auth source instances to try in order.
    """

    def __init__(self, *sources: AuthSource) -> None:
        self._sources = sources

        async def _call(
            request: Request,
            security_scopes: SecurityScopes,  # noqa: ARG001
            **kwargs: Any,  # noqa: ARG001 — absorbs scheme values injected by FastAPI
        ) -> Any:
            for source in self._sources:
                credential = await source.extract(request)
                if credential is not None:
                    return await source.authenticate(credential)
            raise UnauthorizedError()

        self._call_fn = _call

        # Build a merged signature that includes the security-scheme Depends()
        # parameters from every source so FastAPI registers them in OpenAPI docs.
        reserved = {"request", "security_scopes"}
        merged: list[inspect.Parameter] = [
            inspect.Parameter(
                "request",
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=Request,
            ),
            inspect.Parameter(
                "security_scopes",
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=SecurityScopes,
            ),
        ]
        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        for i, source in enumerate(sources):
            for name, param in inspect.signature(source).parameters.items():
                # request/security_scopes are declared once above. Variadic
                # parameters cannot be supplied as named dependency values.
                if name in reserved or param.kind in variadic:
                    continue
                # The ``_s{i}_`` prefix already makes each name unique, so do
                # not dedupe on the bare name. Two sources sharing a parameter
                # name (e.g. two APIKeyHeaderAuth, both ``api_key``) must both
                # be kept, or the later source's scheme is silently dropped.
                # Keyword-only lets a required parameter follow a defaulted one
                # without inspect.Signature raising ValueError.
                merged.append(
                    param.replace(
                        name=f"_s{i}_{name}",
                        kind=inspect.Parameter.KEYWORD_ONLY,
                    )
                )
        self.__signature__ = inspect.Signature(merged, return_annotation=Any)

    async def __call__(self, **kwargs: Any) -> Any:
        """Run the source chain. Keyword arguments follow ``__signature__``."""
        return await self._call_fn(**kwargs)

    def require(self, **kwargs: Any) -> "MultiAuth":
        """Return a new :class:`MultiAuth` with kwargs forwarded to each source.

        Sources without a ``require`` method are reused unchanged.
        """
        new_sources = tuple(
            cast(Any, source).require(**kwargs)
            if hasattr(source, "require")
            else source
            for source in self._sources
        )
        return MultiAuth(*new_sources)