fastapi-toolsets 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,166 @@
1
+ """Custom exceptions with standardized API error responses."""
2
+
3
+ from typing import Any, ClassVar
4
+
5
+ from ..schemas import ApiError, ErrorResponse, ResponseStatus
6
+
7
+
8
class ApiException(Exception):
    """Base class for API errors that carry a structured error payload.

    Subclasses declare a class-level ``api_error`` holding the HTTP status
    code, a short message, a longer description and a machine-readable error
    code. The exception handler uses ``api_error`` to build the response, so
    every API error shares one response format.

    Example:
        class CustomError(ApiException):
            api_error = ApiError(
                code=400,
                msg="Bad Request",
                desc="The request was invalid.",
                err_code="CUSTOM-400",
            )
    """

    # Must be provided by every concrete subclass.
    api_error: ClassVar[ApiError]

    def __init__(self, detail: str | None = None):
        """Create the exception.

        Args:
            detail: Optional text to use as the exception message. When it is
                ``None`` or empty, ``api_error.msg`` is used instead.
        """
        message = detail if detail else self.api_error.msg
        super().__init__(message)
33
+
34
+
35
class UnauthorizedError(ApiException):
    """HTTP 401 - the caller is not authenticated.

    Raise when authentication credentials are missing or invalid.
    """

    api_error = ApiError(
        err_code="AUTH-401",
        code=401,
        msg="Unauthorized",
        desc="Authentication credentials were missing or invalid.",
    )
44
+
45
+
46
class ForbiddenError(ApiException):
    """HTTP 403 - the caller lacks the permissions the resource requires."""

    api_error = ApiError(
        err_code="AUTH-403",
        code=403,
        msg="Forbidden",
        desc="You do not have permission to access this resource.",
    )
55
+
56
+
57
class NotFoundError(ApiException):
    """HTTP 404 - the requested resource does not exist."""

    api_error = ApiError(
        err_code="RES-404",
        code=404,
        msg="Not Found",
        desc="The requested resource was not found.",
    )
66
+
67
+
68
class ConflictError(ApiException):
    """HTTP 409 - the request clashes with the resource's current state."""

    api_error = ApiError(
        err_code="RES-409",
        code=409,
        msg="Conflict",
        desc="The request conflicts with the current state of the resource.",
    )
77
+
78
+
79
class InsufficientRolesError(ForbiddenError):
    """HTTP 403 - the user does not have the required roles.

    The exception message lists the required roles and, when known, the
    roles the user actually holds.
    """

    api_error = ApiError(
        code=403,
        msg="Insufficient Roles",
        desc="You do not have the required roles to access this resource.",
        err_code="RBAC-403",
    )

    def __init__(self, required_roles: list[str], user_roles: set[str] | None = None):
        """Create the exception.

        Args:
            required_roles: Roles needed to access the resource, listed in the
                given order.
            user_roles: Roles the user holds, or ``None`` if unknown (the
                "User has" part is then omitted from the message).
        """
        self.required_roles = required_roles
        self.user_roles = user_roles

        desc = f"Required roles: {', '.join(required_roles)}"
        if user_roles is not None:
            # Sets have no stable iteration order (str hashing is randomized
            # per process); sort so the message is deterministic.
            held = ", ".join(sorted(user_roles)) if user_roles else "no roles"
            desc += f". User has: {held}"

        super().__init__(desc)
98
+
99
+
100
class UserNotFoundError(NotFoundError):
    """HTTP 404 - no user matches the request."""

    api_error = ApiError(
        err_code="USER-404",
        code=404,
        msg="User Not Found",
        desc="The requested user was not found.",
    )
109
+
110
+
111
class RoleNotFoundError(NotFoundError):
    """HTTP 404 - no role matches the request."""

    api_error = ApiError(
        err_code="ROLE-404",
        code=404,
        msg="Role Not Found",
        desc="The requested role was not found.",
    )
120
+
121
+
122
def generate_error_responses(
    *errors: type[ApiException],
) -> dict[int | str, dict[str, Any]]:
    """Generate OpenAPI response documentation for exceptions.

    Use this to document possible error responses for an endpoint.

    Several exceptions may share one HTTP status (for example ``NotFoundError``
    and ``UserNotFoundError`` are both 404). Their examples are then merged
    into a single response entry under named ``examples``, keyed by error
    code, rather than the last one silently replacing the others.

    Args:
        *errors: Exception classes that inherit from ApiException. Repeated
            (status, error code) pairs are documented only once.

    Returns:
        Dict suitable for FastAPI's responses parameter, keyed by status code.

    Example:
        from fastapi_toolsets.exceptions import generate_error_responses, UnauthorizedError, ForbiddenError

        @app.get(
            "/admin",
            responses=generate_error_responses(UnauthorizedError, ForbiddenError)
        )
        async def admin_endpoint():
            ...
    """
    responses: dict[int | str, dict[str, Any]] = {}
    seen: set[tuple[Any, Any]] = set()

    for error in errors:
        api_error = error.api_error

        # Skip exact repeats (same class twice, or a subclass inheriting its
        # parent's api_error) so they don't produce duplicate examples.
        key = (api_error.code, api_error.err_code)
        if key in seen:
            continue
        seen.add(key)

        example = {
            "data": None,
            "status": ResponseStatus.FAIL.value,
            "message": api_error.msg,
            "description": api_error.desc,
            "error_code": api_error.err_code,
        }

        entry = responses.get(api_error.code)
        if entry is None:
            # First error for this status: plain single-example entry.
            responses[api_error.code] = {
                "model": ErrorResponse,
                "description": api_error.msg,
                "content": {"application/json": {"example": example}},
            }
            continue

        # Status already documented: convert to (or extend) named examples so
        # every error sharing the status stays visible.
        media = entry["content"]["application/json"]
        examples = media.setdefault("examples", {})
        if "example" in media:
            first = media.pop("example")
            examples[first["error_code"]] = {
                "summary": first["message"],
                "value": first,
            }
        examples[api_error.err_code] = {"summary": api_error.msg, "value": example}
        entry["description"] = f"{entry['description']} / {api_error.msg}"

    return responses
@@ -0,0 +1,169 @@
1
+ """Exception handlers for FastAPI applications."""
2
+
3
+ from typing import Any
4
+
5
+ from fastapi import FastAPI, Request, Response, status
6
+ from fastapi.exceptions import RequestValidationError, ResponseValidationError
7
+ from fastapi.openapi.utils import get_openapi
8
+ from fastapi.responses import JSONResponse
9
+
10
+ from ..schemas import ResponseStatus
11
+ from .exceptions import ApiException
12
+
13
+
14
def init_exceptions_handlers(app: FastAPI) -> FastAPI:
    """Install the standard exception handlers and OpenAPI override on ``app``.

    Registers the structured-error exception handlers and replaces
    ``app.openapi`` with a generator that rewrites the 422 responses into
    the custom validation-error format.

    Args:
        app: FastAPI application instance (modified in place).

    Returns:
        The same ``app``, so the call can be chained.
    """
    _register_exception_handlers(app)

    def openapi_with_custom_errors() -> dict[str, Any]:
        return _custom_openapi(app)

    app.openapi = openapi_with_custom_errors  # type: ignore[method-assign]
    return app
18
+
19
+
20
def _register_exception_handlers(app: FastAPI) -> None:
    """Register all exception handlers on a FastAPI application.

    Installs handlers for ``ApiException`` (and subclasses), request and
    response validation errors, and any other ``Exception``. Every handler
    answers with the same JSON envelope: ``data``, ``status``, ``message``,
    ``description`` and ``error_code``.

    Args:
        app: FastAPI application instance

    Example:
        from fastapi import FastAPI
        from fastapi_toolsets.exceptions import init_exceptions_handlers

        app = FastAPI()
        init_exceptions_handlers(app)
    """

    @app.exception_handler(ApiException)
    async def api_exception_handler(request: Request, exc: ApiException) -> Response:
        """Handle custom API exceptions with structured response.

        ``message`` and ``error_code`` come from the class-level ``api_error``.
        ``description`` is the ``detail`` the exception was raised with, if
        any, and ``api_error.desc`` otherwise.
        """
        api_error = exc.api_error
        # ApiException.__init__ passes ``detail or api_error.msg`` as the
        # exception message, so any other text is a caller-supplied detail
        # that would otherwise be silently dropped.
        detail = str(exc)
        description = detail if detail and detail != api_error.msg else api_error.desc

        return JSONResponse(
            status_code=api_error.code,
            content={
                "data": None,
                "status": ResponseStatus.FAIL.value,
                "message": api_error.msg,
                "description": description,
                "error_code": api_error.err_code,
            },
        )

    @app.exception_handler(RequestValidationError)
    async def request_validation_handler(
        request: Request, exc: RequestValidationError
    ) -> Response:
        """Handle Pydantic request validation errors (422)."""
        return _format_validation_error(exc)

    @app.exception_handler(ResponseValidationError)
    async def response_validation_handler(
        request: Request, exc: ResponseValidationError
    ) -> Response:
        """Handle Pydantic response validation errors (422).

        NOTE(review): a response that fails validation is a server-side bug;
        consider whether 500 would be more accurate than 422 here.
        """
        return _format_validation_error(exc)

    @app.exception_handler(Exception)
    async def generic_exception_handler(request: Request, exc: Exception) -> Response:
        """Handle all unhandled exceptions with a generic 500 response.

        The exception text is deliberately not included in the body.
        """
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={
                "data": None,
                "status": ResponseStatus.FAIL.value,
                "message": "Internal Server Error",
                "description": "An unexpected error occurred. Please try again later.",
                "error_code": "SERVER-500",
            },
        )
77
+
78
+
79
def _format_validation_error(
    exc: RequestValidationError | ResponseValidationError,
) -> JSONResponse:
    """Format validation errors into a structured 422 response.

    Each pydantic error is reduced to ``{"field", "message", "type"}``.
    ``field`` is the dotted ``loc`` path with its leading location segment
    (``body``, ``query``, ``path``, ``header`` or ``cookie``) removed, or
    ``"root"`` when nothing remains.

    Args:
        exc: The validation error raised by FastAPI.

    Returns:
        A 422 JSONResponse whose ``data.errors`` lists every error.
    """
    location_prefixes = ("body", "query", "path", "header", "cookie")
    formatted_errors = []

    for error in exc.errors():
        loc = tuple(error.get("loc", ()))
        # Only the first segment says where the value came from. Later
        # segments are real field names (a model may well have a field called
        # "path" or "body") and must be kept in the reported path.
        if loc and loc[0] in location_prefixes:
            loc = loc[1:]
        field_path = ".".join(str(part) for part in loc)

        formatted_errors.append(
            {
                "field": field_path or "root",
                "message": error.get("msg", ""),
                "type": error.get("type", ""),
            }
        )

    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content={
            "data": {"errors": formatted_errors},
            "status": ResponseStatus.FAIL.value,
            "message": "Validation Error",
            "description": f"{len(formatted_errors)} validation error(s) detected",
            "error_code": "VAL-422",
        },
    )
110
+
111
+
112
def _custom_openapi(app: FastAPI) -> dict[str, Any]:
    """Generate custom OpenAPI schema with standardized error format.

    Replaces default 422 validation error responses with the custom format.
    The result is cached on ``app.openapi_schema``, so the work is done once.

    Args:
        app: FastAPI application instance

    Returns:
        OpenAPI schema dict

    Example:
        from fastapi import FastAPI
        from fastapi_toolsets.exceptions import init_exceptions_handlers

        app = FastAPI()
        init_exceptions_handlers(app)  # Automatically sets custom OpenAPI
    """
    if app.openapi_schema:
        return app.openapi_schema

    # Delegate to FastAPI's own generator (the class method, since the
    # instance attribute ``app.openapi`` points back here). Unlike a bare
    # get_openapi(title, version, ...) call, it keeps every piece of app
    # metadata: servers, tags, contact, license, summary, webhooks, ...
    openapi_schema = FastAPI.openapi(app)

    def validation_error_response() -> dict[str, Any]:
        # Fresh dict per operation so no two operations share mutable state.
        return {
            "description": "Validation Error",
            "content": {
                "application/json": {
                    "example": {
                        "data": {
                            "errors": [
                                {
                                    "field": "field_name",
                                    "message": "value is not valid",
                                    "type": "value_error",
                                }
                            ]
                        },
                        "status": ResponseStatus.FAIL.value,
                        "message": "Validation Error",
                        "description": "1 validation error(s) detected",
                        "error_code": "VAL-422",
                    }
                }
            },
        }

    for path_data in openapi_schema.get("paths", {}).values():
        for operation in path_data.values():
            # Path items can also hold non-operation entries (e.g. a
            # "parameters" list); only dict operations carry responses.
            if not isinstance(operation, dict):
                continue
            responses = operation.get("responses")
            if isinstance(responses, dict) and "422" in responses:
                responses["422"] = validation_error_response()

    app.openapi_schema = openapi_schema
    return app.openapi_schema
@@ -0,0 +1,17 @@
1
+ from .fixtures import (
2
+ Context,
3
+ FixtureRegistry,
4
+ LoadStrategy,
5
+ load_fixtures,
6
+ load_fixtures_by_context,
7
+ )
8
+ from .pytest_plugin import register_fixtures
9
+
10
# Public API of the fixtures package.
__all__ = [
    # Enums and the registry
    "Context",
    "FixtureRegistry",
    "LoadStrategy",
    # Loading helpers
    "load_fixtures",
    "load_fixtures_by_context",
    # pytest integration
    "register_fixtures",
]
@@ -0,0 +1,321 @@
1
+ """Fixture system with dependency management and context support."""
2
+
3
+ import logging
4
+ from collections.abc import Callable, Sequence
5
+ from dataclasses import dataclass, field
6
+ from enum import Enum
7
+ from typing import Any, cast
8
+
9
+ from sqlalchemy.ext.asyncio import AsyncSession
10
+ from sqlalchemy.orm import DeclarativeBase
11
+
12
+ from ..db import get_transaction
13
+
14
# Module logger; fixture loading reports its progress through it.
logger = logging.getLogger(name=__name__)
15
+
16
+
17
class LoadStrategy(str, Enum):
    """Strategy for loading fixtures into the database.

    Members compare equal to their string values (``"insert"``, ``"merge"``,
    ``"skip_existing"``).
    """

    #: Insert new records. Fails if record already exists.
    INSERT = "insert"
    #: Insert or update based on primary key (SQLAlchemy merge).
    MERGE = "merge"
    #: Insert only if record doesn't exist (based on primary key).
    SKIP_EXISTING = "skip_existing"
28
+
29
+
30
class Context(str, Enum):
    """Predefined fixture contexts.

    Members compare equal to their string values, so plain strings such as
    ``"testing"`` work wherever a context is accepted.
    """

    #: Base fixtures, intended for every environment.
    BASE = "base"
    #: Production-only fixtures.
    PRODUCTION = "production"
    #: Development fixtures.
    DEVELOPMENT = "development"
    #: Test fixtures.
    TESTING = "testing"
44
+
45
+
46
@dataclass
class Fixture:
    """A fixture definition with metadata.

    Attributes:
        name: Fixture name; the registry key and the identifier used in
            other fixtures' ``depends_on``.
        func: Zero-argument callable returning the model instances to load.
        depends_on: Names of fixtures that must be loaded before this one.
        contexts: Context names (plain strings) this fixture belongs to.
    """

    name: str
    func: Callable[[], Sequence[DeclarativeBase]]
    depends_on: list[str] = field(default_factory=list)
    # Default to the plain string value, matching how FixtureRegistry.register
    # normalises contexts, so ``contexts`` really is ``list[str]``.
    contexts: list[str] = field(default_factory=lambda: [Context.BASE.value])
54
+
55
+
56
class FixtureRegistry:
    """Registry for managing fixtures with dependencies.

    Fixtures are stored by name in registration order; registering the same
    name again replaces the earlier definition.

    Example:
        from fastapi_toolsets.fixtures import FixtureRegistry, Context

        fixtures = FixtureRegistry()

        @fixtures.register
        def roles():
            return [
                Role(id=1, name="admin"),
                Role(id=2, name="user"),
            ]

        @fixtures.register(depends_on=["roles"])
        def users():
            return [
                User(id=1, username="admin", role_id=1),
            ]

        @fixtures.register(depends_on=["users"], contexts=[Context.TESTING])
        def test_data():
            return [
                Post(id=1, title="Test", user_id=1),
            ]
    """

    def __init__(self) -> None:
        self._fixtures: dict[str, Fixture] = {}

    def register(
        self,
        func: Callable[[], Sequence[DeclarativeBase]] | None = None,
        *,
        name: str | None = None,
        depends_on: list[str] | None = None,
        contexts: list[str | Context] | None = None,
    ) -> Callable[..., Any]:
        """Register a fixture function.

        Can be used as a decorator with or without arguments.

        Args:
            func: Fixture function returning list of model instances
            name: Fixture name (defaults to function name)
            depends_on: List of fixture names this depends on
            contexts: List of contexts this fixture belongs to (defaults to
                ``[Context.BASE]``); stored as plain strings

        Returns:
            The function unchanged when used bare, otherwise a decorator.

        Example:
            @fixtures.register
            def roles():
                return [Role(id=1, name="admin")]

            @fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
            def test_users():
                return [User(id=1, username="test", role_id=1)]
        """

        def decorator(
            fn: Callable[[], Sequence[DeclarativeBase]],
        ) -> Callable[[], Sequence[DeclarativeBase]]:
            fixture_name = name or cast(Any, fn).__name__
            # Normalise enum members to their string values.
            fixture_contexts = [
                c.value if isinstance(c, Context) else c
                for c in (contexts or [Context.BASE])
            ]

            self._fixtures[fixture_name] = Fixture(
                name=fixture_name,
                func=fn,
                depends_on=depends_on or [],
                contexts=fixture_contexts,
            )
            return fn

        if func is not None:
            return decorator(func)
        return decorator

    def get(self, name: str) -> Fixture:
        """Get a fixture by name.

        Raises:
            KeyError: If no fixture with that name is registered.
        """
        try:
            return self._fixtures[name]
        except KeyError:
            raise KeyError(f"Fixture '{name}' not found") from None

    def get_all(self) -> list[Fixture]:
        """Get all registered fixtures, in registration order."""
        return list(self._fixtures.values())

    def get_by_context(self, *contexts: str | Context) -> list[Fixture]:
        """Get fixtures belonging to at least one of the given contexts."""
        context_values = {c.value if isinstance(c, Context) else c for c in contexts}
        return [
            f
            for f in self._fixtures.values()
            if not context_values.isdisjoint(f.contexts)
        ]

    def resolve_dependencies(self, *names: str) -> list[str]:
        """Resolve fixture dependencies in topological order.

        Args:
            *names: Fixture names to resolve

        Returns:
            List of fixture names in load order (dependencies first), each
            name appearing once

        Raises:
            KeyError: If a fixture is not found
            ValueError: If circular dependency detected (the message shows
                the cycle, e.g. ``a -> b -> a``)
        """
        resolved: list[str] = []
        done: set[str] = set()
        # Current depth-first path: the set gives O(1) cycle checks, the list
        # keeps the order needed to report the cycle.
        path: list[str] = []
        on_path: set[str] = set()

        def visit(name: str) -> None:
            if name in done:
                return
            if name in on_path:
                cycle = " -> ".join([*path[path.index(name):], name])
                raise ValueError(f"Circular dependency detected: {cycle}")

            path.append(name)
            on_path.add(name)
            for dep in self.get(name).depends_on:
                visit(dep)
            path.pop()
            on_path.discard(name)

            done.add(name)
            resolved.append(name)

        for name in names:
            visit(name)

        return resolved

    def resolve_context_dependencies(self, *contexts: str | Context) -> list[str]:
        """Resolve all fixtures for contexts with dependencies.

        Dependencies are included even if they belong to other contexts.

        Args:
            *contexts: Contexts to load

        Returns:
            List of fixture names in load order
        """
        names = [f.name for f in self.get_by_context(*contexts)]
        # resolve_dependencies already pulls in every transitive dependency
        # and de-duplicates. Calling it once on the registration-ordered names
        # gives a deterministic order, unlike re-resolving from a set.
        return self.resolve_dependencies(*names)
207
+
208
+
209
async def load_fixtures(
    session: AsyncSession,
    registry: FixtureRegistry,
    *names: str,
    strategy: LoadStrategy = LoadStrategy.MERGE,
) -> dict[str, list[DeclarativeBase]]:
    """Load the named fixtures, loading their dependencies first.

    Args:
        session: Database session to load into.
        registry: Registry the fixtures are looked up in.
        *names: Fixture names to load; dependencies are resolved automatically.
        strategy: How records that may already exist are handled.

    Returns:
        Mapping of every loaded fixture name (requested or dependency) to the
        instances loaded for it.

    Raises:
        KeyError: If a fixture or dependency is not registered.
        ValueError: If the dependencies form a cycle.

    Example:
        # Loads 'roles' first (dependency), then 'users'
        result = await load_fixtures(session, fixtures, "users")
        print(result["users"])  # [User(...), ...]
    """
    return await _load_ordered(
        session,
        registry,
        registry.resolve_dependencies(*names),
        strategy,
    )
233
+
234
+
235
async def load_fixtures_by_context(
    session: AsyncSession,
    registry: FixtureRegistry,
    *contexts: str | Context,
    strategy: LoadStrategy = LoadStrategy.MERGE,
) -> dict[str, list[DeclarativeBase]]:
    """Load every fixture registered for the given contexts.

    Dependencies of those fixtures are loaded too, and before them.

    Args:
        session: Database session to load into.
        registry: Registry the fixtures are looked up in.
        *contexts: Contexts to load (e.g. ``Context.BASE``, ``Context.TESTING``).
        strategy: How records that may already exist are handled.

    Returns:
        Mapping of each loaded fixture name to the instances loaded for it.

    Example:
        # Load base + testing fixtures
        await load_fixtures_by_context(
            session, fixtures,
            Context.BASE, Context.TESTING
        )
    """
    return await _load_ordered(
        session,
        registry,
        registry.resolve_context_dependencies(*contexts),
        strategy,
    )
261
+
262
+
263
async def _load_ordered(
    session: AsyncSession,
    registry: FixtureRegistry,
    ordered_names: list[str],
    strategy: LoadStrategy,
) -> dict[str, list[DeclarativeBase]]:
    """Load fixtures in the given order, one ``get_transaction`` per fixture.

    Args:
        session: Database session used for every fixture.
        registry: Registry the fixture names are looked up in.
        ordered_names: Fixture names, already ordered with dependencies first.
        strategy: How records that may already exist are handled; a plain
            string value such as ``"merge"`` is also accepted.

    Returns:
        Dict mapping each fixture name to the instances that were added or
        merged. Instances skipped by ``SKIP_EXISTING`` are not included.

    Raises:
        ValueError: If ``strategy`` is not a valid ``LoadStrategy`` value.
    """
    # Normalise plain strings and fail fast on unknown values; previously an
    # invalid strategy matched no branch and silently loaded nothing.
    strategy = LoadStrategy(strategy)
    results: dict[str, list[DeclarativeBase]] = {}

    for name in ordered_names:
        fixture = registry.get(name)
        instances = list(fixture.func())

        if not instances:
            results[name] = []
            continue

        loaded: list[DeclarativeBase] = []
        async with get_transaction(session):
            for instance in instances:
                stored = await _load_instance(session, instance, strategy)
                if stored is not None:
                    loaded.append(stored)

        results[name] = loaded
        # Model name is taken from the first instance only.
        logger.info(
            "Loaded fixture '%s': %d %s(s)",
            name,
            len(loaded),
            type(instances[0]).__name__,
        )

    return results


async def _load_instance(
    session: AsyncSession,
    instance: DeclarativeBase,
    strategy: LoadStrategy,
) -> DeclarativeBase | None:
    """Persist one instance per ``strategy``; return what was stored, or None if skipped."""
    if strategy == LoadStrategy.MERGE:
        # merge() returns the session-bound copy, which is what gets stored.
        return await session.merge(instance)

    if strategy == LoadStrategy.SKIP_EXISTING:
        pk = _get_primary_key(instance)
        # Without a complete primary key there is nothing to look up, so the
        # instance is added just like a plain insert.
        if pk is not None and await session.get(type(instance), pk) is not None:
            return None

    # INSERT, or SKIP_EXISTING for a record not yet in the database.
    session.add(instance)
    return instance
308
+
309
+
310
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
    """Get the primary key value of a model instance.

    Args:
        instance: A mapped model instance.

    Returns:
        For a single-column key, the attribute value (possibly ``None``).
        For a composite key, a tuple of values, or ``None`` unless every part
        is set.
    """
    mapper = instance.__class__.__mapper__

    # Read through the mapped *attribute* key rather than the column name:
    # with e.g. ``id_ = mapped_column("id", primary_key=True)`` the column is
    # "id" but the attribute is "id_". getattr(instance, "id", None) would
    # then yield None, and SKIP_EXISTING would re-insert an existing row.
    pk_values = tuple(
        getattr(instance, mapper.get_property_by_column(col).key, None)
        for col in mapper.primary_key
    )

    if len(pk_values) == 1:
        return pk_values[0]
    if all(v is not None for v in pk_values):
        return pk_values
    return None