fastapi-toolsets 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastapi_toolsets/__init__.py +24 -0
- fastapi_toolsets/cli/__init__.py +5 -0
- fastapi_toolsets/cli/app.py +97 -0
- fastapi_toolsets/cli/commands/__init__.py +1 -0
- fastapi_toolsets/cli/commands/fixtures.py +225 -0
- fastapi_toolsets/crud.py +378 -0
- fastapi_toolsets/db.py +175 -0
- fastapi_toolsets/exceptions/__init__.py +19 -0
- fastapi_toolsets/exceptions/exceptions.py +166 -0
- fastapi_toolsets/exceptions/handler.py +169 -0
- fastapi_toolsets/fixtures/__init__.py +17 -0
- fastapi_toolsets/fixtures/fixtures.py +321 -0
- fastapi_toolsets/fixtures/pytest_plugin.py +204 -0
- fastapi_toolsets/py.typed +0 -0
- fastapi_toolsets/schemas.py +116 -0
- fastapi_toolsets-0.1.0.dist-info/METADATA +89 -0
- fastapi_toolsets-0.1.0.dist-info/RECORD +20 -0
- fastapi_toolsets-0.1.0.dist-info/WHEEL +4 -0
- fastapi_toolsets-0.1.0.dist-info/entry_points.txt +3 -0
- fastapi_toolsets-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
"""Custom exceptions with standardized API error responses."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, ClassVar
|
|
4
|
+
|
|
5
|
+
from ..schemas import ApiError, ErrorResponse, ResponseStatus
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ApiException(Exception):
    """Root of the API exception hierarchy.

    Each subclass declares a class-level ``api_error`` holding the structured
    error payload (status code, message, description, error code). The
    exception handler uses ``api_error`` to generate the response.

    Example:
        class CustomError(ApiException):
            api_error = ApiError(
                code=400,
                msg="Bad Request",
                desc="The request was invalid.",
                err_code="CUSTOM-400",
            )
    """

    api_error: ClassVar[ApiError]

    def __init__(self, detail: str | None = None):
        """Create the exception.

        Args:
            detail: Text used as the exception message. When omitted or
                empty, ``api_error.msg`` is used instead.
        """
        message = detail if detail else self.api_error.msg
        super().__init__(message)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class UnauthorizedError(ApiException):
    """HTTP 401 error: authentication credentials are missing or invalid."""

    api_error = ApiError(
        code=401,
        msg="Unauthorized",
        desc="Authentication credentials were missing or invalid.",
        err_code="AUTH-401",
    )
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class ForbiddenError(ApiException):
    """HTTP 403 error: the caller lacks permission for the resource."""

    api_error = ApiError(
        code=403,
        msg="Forbidden",
        desc="You do not have permission to access this resource.",
        err_code="AUTH-403",
    )
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class NotFoundError(ApiException):
    """HTTP 404 error: the requested resource does not exist."""

    api_error = ApiError(
        code=404,
        msg="Not Found",
        desc="The requested resource was not found.",
        err_code="RES-404",
    )
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class ConflictError(ApiException):
    """HTTP 409 error: the request conflicts with the resource's state."""

    api_error = ApiError(
        code=409,
        msg="Conflict",
        desc="The request conflicts with the current state of the resource.",
        err_code="RES-409",
    )
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class InsufficientRolesError(ForbiddenError):
    """HTTP 403 error: the user does not hold the required roles.

    The exception message lists the required roles and, when supplied, the
    roles the user currently holds.

    Attributes:
        required_roles: Roles needed to access the resource.
        user_roles: Roles held by the user, or None when not provided.
    """

    api_error = ApiError(
        code=403,
        msg="Insufficient Roles",
        desc="You do not have the required roles to access this resource.",
        err_code="RBAC-403",
    )

    def __init__(self, required_roles: list[str], user_roles: set[str] | None = None):
        """Create the exception and build its message.

        Args:
            required_roles: Roles needed to access the resource.
            user_roles: Roles held by the user. ``None`` omits the
                "User has" part; an empty set is reported as "no roles".
        """
        self.required_roles = required_roles
        self.user_roles = user_roles

        desc = f"Required roles: {', '.join(required_roles)}"
        if user_roles is not None:
            # Sets have no stable iteration order; sorting keeps the message
            # reproducible across runs (logs, tests, snapshots).
            held = ", ".join(sorted(user_roles)) if user_roles else "no roles"
            desc += f". User has: {held}"

        super().__init__(desc)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class UserNotFoundError(NotFoundError):
    """HTTP 404 error: the requested user does not exist."""

    api_error = ApiError(
        code=404,
        msg="User Not Found",
        desc="The requested user was not found.",
        err_code="USER-404",
    )
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class RoleNotFoundError(NotFoundError):
    """HTTP 404 error: the requested role does not exist."""

    api_error = ApiError(
        code=404,
        msg="Role Not Found",
        desc="The requested role was not found.",
        err_code="ROLE-404",
    )
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def generate_error_responses(
    *errors: type[ApiException],
) -> dict[int | str, dict[str, Any]]:
    """Build OpenAPI response documentation from exception classes.

    Each exception class contributes one entry keyed by its
    ``api_error.code``. When several classes share a status code, the entry
    of the last one passed replaces the earlier ones.

    Args:
        *errors: Exception classes that inherit from ApiException

    Returns:
        Dict suitable for FastAPI's responses parameter

    Example:
        from fastapi_toolsets.exceptions import generate_error_responses, UnauthorizedError, ForbiddenError

        @app.get(
            "/admin",
            responses=generate_error_responses(UnauthorizedError, ForbiddenError)
        )
        async def admin_endpoint():
            ...
    """

    def _entry(spec: Any) -> dict[str, Any]:
        # Example payload mirroring the structured error response body.
        sample = {
            "data": None,
            "status": ResponseStatus.FAIL.value,
            "message": spec.msg,
            "description": spec.desc,
            "error_code": spec.err_code,
        }
        return {
            "model": ErrorResponse,
            "description": spec.msg,
            "content": {"application/json": {"example": sample}},
        }

    return {
        exc_type.api_error.code: _entry(exc_type.api_error) for exc_type in errors
    }
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
"""Exception handlers for FastAPI applications."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from fastapi import FastAPI, Request, Response, status
|
|
6
|
+
from fastapi.exceptions import RequestValidationError, ResponseValidationError
|
|
7
|
+
from fastapi.openapi.utils import get_openapi
|
|
8
|
+
from fastapi.responses import JSONResponse
|
|
9
|
+
|
|
10
|
+
from ..schemas import ResponseStatus
|
|
11
|
+
from .exceptions import ApiException
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def init_exceptions_handlers(app: FastAPI) -> FastAPI:
    """Install the exception handlers and the custom OpenAPI generator.

    Args:
        app: FastAPI application instance to configure

    Returns:
        The same application instance that was passed in
    """
    _register_exception_handlers(app)

    def _openapi() -> dict[str, Any]:
        # Bound to this app so FastAPI can call it without arguments.
        return _custom_openapi(app)

    app.openapi = _openapi  # type: ignore[method-assign]
    return app
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _register_exception_handlers(app: FastAPI) -> None:
    """Attach the standardized error handlers to a FastAPI application.

    Handlers are registered for ``ApiException``, ``RequestValidationError``,
    ``ResponseValidationError`` and, as a catch-all, ``Exception``.

    Args:
        app: FastAPI application instance

    Example:
        from fastapi import FastAPI
        from fastapi_toolsets.exceptions import init_exceptions_handlers

        app = FastAPI()
        init_exceptions_handlers(app)
    """

    @app.exception_handler(ApiException)
    async def api_exception_handler(request: Request, exc: ApiException) -> Response:
        """Render an ApiException from its class-level ``api_error``."""
        spec = exc.api_error
        body = {
            "data": None,
            "status": ResponseStatus.FAIL.value,
            "message": spec.msg,
            "description": spec.desc,
            "error_code": spec.err_code,
        }
        return JSONResponse(status_code=spec.code, content=body)

    @app.exception_handler(RequestValidationError)
    async def request_validation_handler(
        request: Request, exc: RequestValidationError
    ) -> Response:
        """Render request validation errors as a 422 response."""
        return _format_validation_error(exc)

    @app.exception_handler(ResponseValidationError)
    async def response_validation_handler(
        request: Request, exc: ResponseValidationError
    ) -> Response:
        """Render response validation errors as a 422 response."""
        return _format_validation_error(exc)

    @app.exception_handler(Exception)
    async def generic_exception_handler(request: Request, exc: Exception) -> Response:
        """Render any other exception as a generic 500 response."""
        body = {
            "data": None,
            "status": ResponseStatus.FAIL.value,
            "message": "Internal Server Error",
            "description": "An unexpected error occurred. Please try again later.",
            "error_code": "SERVER-500",
        }
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content=body,
        )
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _format_validation_error(
    exc: RequestValidationError | ResponseValidationError,
) -> JSONResponse:
    """Format validation errors into a structured 422 response.

    Each error reported by ``exc.errors()`` becomes a
    ``{"field", "message", "type"}`` entry. ``field`` is the dotted error
    location with a leading request-location marker ("body", "query",
    "path", "header", "cookie") removed, or "root" when nothing remains.

    Args:
        exc: The validation exception to format

    Returns:
        JSONResponse with status 422 and the structured error payload
    """
    location_markers = ("body", "query", "path", "header", "cookie")
    formatted_errors = []

    for error in exc.errors():
        loc = tuple(error.get("loc", ()))
        # Only the first element names where the value came from. Deeper
        # elements are real field names and must be kept even when they
        # happen to be called "query", "path", "body", etc.
        if loc and loc[0] in location_markers:
            loc = loc[1:]
        field_path = ".".join(str(part) for part in loc)

        formatted_errors.append(
            {
                "field": field_path or "root",
                "message": error.get("msg", ""),
                "type": error.get("type", ""),
            }
        )

    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content={
            "data": {"errors": formatted_errors},
            "status": ResponseStatus.FAIL.value,
            "message": "Validation Error",
            "description": f"{len(formatted_errors)} validation error(s) detected",
            "error_code": "VAL-422",
        },
    )
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _custom_openapi(app: FastAPI) -> dict[str, Any]:
    """Generate custom OpenAPI schema with standardized error format.

    Replaces default 422 validation error responses with the custom format.
    The schema is generated once and cached on ``app.openapi_schema``.

    Args:
        app: FastAPI application instance

    Returns:
        OpenAPI schema dict

    Example:
        from fastapi import FastAPI
        from fastapi_toolsets.exceptions import init_exceptions_handlers

        app = FastAPI()
        init_exceptions_handlers(app)  # Automatically sets custom OpenAPI
    """
    if app.openapi_schema:
        return app.openapi_schema

    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        openapi_version=app.openapi_version,
        description=app.description,
        routes=app.routes,
        # Forward the app-level metadata as well; otherwise tags, servers
        # and contact/license details configured on the app are missing
        # from the schema once this override is installed.
        terms_of_service=app.terms_of_service,
        contact=app.contact,
        license_info=app.license_info,
        tags=app.openapi_tags,
        servers=app.servers,
    )

    def _validation_error_response() -> dict[str, Any]:
        # Built fresh per operation so operations never share a mutable dict.
        return {
            "description": "Validation Error",
            "content": {
                "application/json": {
                    "example": {
                        "data": {
                            "errors": [
                                {
                                    "field": "field_name",
                                    "message": "value is not valid",
                                    "type": "value_error",
                                }
                            ]
                        },
                        "status": ResponseStatus.FAIL.value,
                        "message": "Validation Error",
                        "description": "1 validation error(s) detected",
                        "error_code": "VAL-422",
                    }
                }
            },
        }

    for path_data in openapi_schema.get("paths", {}).values():
        for operation in path_data.values():
            # Path items may hold non-operation entries; skip anything that
            # is not an operation dict with a responses section.
            if not isinstance(operation, dict) or "responses" not in operation:
                continue
            if "422" in operation["responses"]:
                operation["responses"]["422"] = _validation_error_response()

    app.openapi_schema = openapi_schema
    return app.openapi_schema
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from .fixtures import (
|
|
2
|
+
Context,
|
|
3
|
+
FixtureRegistry,
|
|
4
|
+
LoadStrategy,
|
|
5
|
+
load_fixtures,
|
|
6
|
+
load_fixtures_by_context,
|
|
7
|
+
)
|
|
8
|
+
from .pytest_plugin import register_fixtures
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"Context",
|
|
12
|
+
"FixtureRegistry",
|
|
13
|
+
"LoadStrategy",
|
|
14
|
+
"load_fixtures",
|
|
15
|
+
"load_fixtures_by_context",
|
|
16
|
+
"register_fixtures",
|
|
17
|
+
]
|
|
@@ -0,0 +1,321 @@
|
|
|
1
|
+
"""Fixture system with dependency management and context support."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from collections.abc import Callable, Sequence
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from typing import Any, cast
|
|
8
|
+
|
|
9
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
10
|
+
from sqlalchemy.orm import DeclarativeBase
|
|
11
|
+
|
|
12
|
+
from ..db import get_transaction
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class LoadStrategy(str, Enum):
    """How fixture records are written to the database."""

    #: Insert new records. Fails if record already exists.
    INSERT = "insert"

    #: Insert or update based on primary key (SQLAlchemy merge).
    MERGE = "merge"

    #: Insert only if record doesn't exist (based on primary key).
    SKIP_EXISTING = "skip_existing"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class Context(str, Enum):
    """Built-in names for grouping fixtures by environment."""

    #: Base fixtures loaded in all environments.
    BASE = "base"

    #: Production-only fixtures.
    PRODUCTION = "production"

    #: Development fixtures.
    DEVELOPMENT = "development"

    #: Test fixtures.
    TESTING = "testing"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class Fixture:
    """Metadata record for one registered fixture."""

    # Name the fixture is registered under.
    name: str
    # Zero-argument callable returning the model instances to load.
    func: Callable[[], Sequence[DeclarativeBase]]
    # Names of fixtures that must be loaded before this one.
    depends_on: list[str] = field(default_factory=list)
    # Contexts the fixture belongs to; defaults to the base context only.
    contexts: list[str] = field(default_factory=lambda: [Context.BASE])
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class FixtureRegistry:
    """Registry for managing fixtures with dependencies.

    Fixtures are stored by name in registration order. Dependencies are
    declared by name and resolved into a load order on demand.

    Example:
        from fastapi_toolsets.fixtures import FixtureRegistry, Context

        fixtures = FixtureRegistry()

        @fixtures.register
        def roles():
            return [
                Role(id=1, name="admin"),
                Role(id=2, name="user"),
            ]

        @fixtures.register(depends_on=["roles"])
        def users():
            return [
                User(id=1, username="admin", role_id=1),
            ]

        @fixtures.register(depends_on=["users"], contexts=[Context.TESTING])
        def test_data():
            return [
                Post(id=1, title="Test", user_id=1),
            ]
    """

    def __init__(self) -> None:
        """Create an empty registry."""
        self._fixtures: dict[str, Fixture] = {}

    def register(
        self,
        func: Callable[[], Sequence[DeclarativeBase]] | None = None,
        *,
        name: str | None = None,
        depends_on: list[str] | None = None,
        contexts: list[str | Context] | None = None,
    ) -> Callable[..., Any]:
        """Register a fixture function.

        Can be used as a decorator with or without arguments. Registering a
        second fixture under an existing name replaces the first.

        Args:
            func: Fixture function returning list of model instances
            name: Fixture name (defaults to function name)
            depends_on: List of fixture names this depends on
            contexts: List of contexts this fixture belongs to
                (defaults to the base context)

        Returns:
            The fixture function itself when used without arguments,
            otherwise a decorator that registers and returns the function.

        Example:
            @fixtures.register
            def roles():
                return [Role(id=1, name="admin")]

            @fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
            def test_users():
                return [User(id=1, username="test", role_id=1)]
        """

        def decorator(
            fn: Callable[[], Sequence[DeclarativeBase]],
        ) -> Callable[[], Sequence[DeclarativeBase]]:
            fixture_name = name or cast(Any, fn).__name__
            # Contexts are stored as plain strings so they compare equal to
            # both Context members and raw string names.
            fixture_contexts = [
                c.value if isinstance(c, Context) else c
                for c in (contexts or [Context.BASE])
            ]

            self._fixtures[fixture_name] = Fixture(
                name=fixture_name,
                func=fn,
                depends_on=depends_on or [],
                contexts=fixture_contexts,
            )
            return fn

        if func is not None:
            return decorator(func)
        return decorator

    def get(self, name: str) -> Fixture:
        """Get a fixture by name.

        Raises:
            KeyError: If no fixture is registered under ``name``
        """
        if name not in self._fixtures:
            raise KeyError(f"Fixture '{name}' not found")
        return self._fixtures[name]

    def get_all(self) -> list[Fixture]:
        """Get all registered fixtures in registration order."""
        return list(self._fixtures.values())

    def get_by_context(self, *contexts: str | Context) -> list[Fixture]:
        """Get fixtures belonging to at least one of the given contexts."""
        context_values = {c.value if isinstance(c, Context) else c for c in contexts}
        return [f for f in self._fixtures.values() if set(f.contexts) & context_values]

    def resolve_dependencies(self, *names: str) -> list[str]:
        """Resolve fixture dependencies in topological order.

        Args:
            *names: Fixture names to resolve

        Returns:
            List of fixture names in load order (dependencies first)

        Raises:
            KeyError: If a fixture is not found
            ValueError: If circular dependency detected
        """
        resolved: list[str] = []
        # Mirror of ``resolved`` as a set, for constant-time membership tests.
        done: set[str] = set()
        # Names on the current depth-first path; re-entering one is a cycle.
        visiting: set[str] = set()

        def visit(name: str) -> None:
            if name in done:
                return
            if name in visiting:
                raise ValueError(f"Circular dependency detected: {name}")

            visiting.add(name)
            fixture = self.get(name)

            for dep in fixture.depends_on:
                visit(dep)

            visiting.remove(name)
            resolved.append(name)
            done.add(name)

        for name in names:
            visit(name)

        return resolved

    def resolve_context_dependencies(self, *contexts: str | Context) -> list[str]:
        """Resolve all fixtures for contexts with dependencies.

        Args:
            *contexts: Contexts to load

        Returns:
            List of fixture names in load order

        Raises:
            KeyError: If a dependency is not registered
            ValueError: If circular dependency detected
        """
        # Resolve straight from the registration-ordered names.
        # resolve_dependencies already pulls in transitive dependencies, and
        # starting from a list (not a set) keeps the order deterministic.
        names = [f.name for f in self.get_by_context(*contexts)]
        return self.resolve_dependencies(*names)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
async def load_fixtures(
    session: AsyncSession,
    registry: FixtureRegistry,
    *names: str,
    strategy: LoadStrategy = LoadStrategy.MERGE,
) -> dict[str, list[DeclarativeBase]]:
    """Load the named fixtures together with their dependencies.

    Args:
        session: Database session
        registry: Fixture registry
        *names: Fixture names to load (dependencies auto-resolved)
        strategy: How to handle existing records

    Returns:
        Dict mapping fixture names to loaded instances

    Example:
        # Loads 'roles' first (dependency), then 'users'
        result = await load_fixtures(session, fixtures, "users")
        print(result["users"])  # [User(...), ...]
    """
    load_order = registry.resolve_dependencies(*names)
    loaded = await _load_ordered(session, registry, load_order, strategy)
    return loaded
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
async def load_fixtures_by_context(
    session: AsyncSession,
    registry: FixtureRegistry,
    *contexts: str | Context,
    strategy: LoadStrategy = LoadStrategy.MERGE,
) -> dict[str, list[DeclarativeBase]]:
    """Load every fixture belonging to the given contexts.

    Args:
        session: Database session
        registry: Fixture registry
        *contexts: Contexts to load (e.g., Context.BASE, Context.TESTING)
        strategy: How to handle existing records

    Returns:
        Dict mapping fixture names to loaded instances

    Example:
        # Load base + testing fixtures
        await load_fixtures_by_context(
            session, fixtures,
            Context.BASE, Context.TESTING
        )
    """
    load_order = registry.resolve_context_dependencies(*contexts)
    loaded = await _load_ordered(session, registry, load_order, strategy)
    return loaded
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
async def _load_ordered(
    session: AsyncSession,
    registry: FixtureRegistry,
    ordered_names: list[str],
    strategy: LoadStrategy,
) -> dict[str, list[DeclarativeBase]]:
    """Load fixtures one by one in the given order.

    Each non-empty fixture is written inside its own ``get_transaction``
    block. A fixture that produces no instances is recorded with an empty
    list and no transaction is opened for it.

    Args:
        session: Database session
        registry: Registry the fixture names are looked up in
        ordered_names: Fixture names, already in load order
        strategy: How to handle existing records

    Returns:
        Dict mapping fixture names to the instances that were added or merged
    """
    loaded_by_fixture: dict[str, list[DeclarativeBase]] = {}

    for fixture_name in ordered_names:
        objects = list(registry.get(fixture_name).func())

        if not objects:
            loaded_by_fixture[fixture_name] = []
            continue

        # The first instance's class name labels the log line.
        model_label = type(objects[0]).__name__
        persisted: list[DeclarativeBase] = []

        async with get_transaction(session):
            for obj in objects:
                if strategy == LoadStrategy.MERGE:
                    persisted.append(await session.merge(obj))
                    continue

                if strategy == LoadStrategy.SKIP_EXISTING:
                    pk = _get_primary_key(obj)
                    # Without a usable primary key there is nothing to look
                    # up, so the instance is added like a plain insert.
                    already_stored = (
                        pk is not None
                        and await session.get(type(obj), pk) is not None
                    )
                    if already_stored:
                        continue
                elif strategy != LoadStrategy.INSERT:
                    # Unrecognized strategy: nothing is written.
                    continue

                session.add(obj)
                persisted.append(obj)

        loaded_by_fixture[fixture_name] = persisted
        logger.info(
            f"Loaded fixture '{fixture_name}': {len(persisted)} {model_label}(s)"
        )

    return loaded_by_fixture
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
    """Get the primary key value of a model instance.

    Args:
        instance: Mapped model instance to inspect

    Returns:
        The scalar value for a single-column primary key, a tuple of values
        for a composite key, or None when any key component is unset.
    """
    mapper = instance.__class__.__mapper__
    pk_cols = mapper.primary_key

    # Read through the mapped attribute key, not the column name: the two
    # differ when a column is mapped under another attribute name, and
    # reading by column name would then always yield None.
    pk_values = tuple(
        getattr(instance, mapper.get_property_by_column(col).key, None)
        for col in pk_cols
    )

    if any(value is None for value in pk_values):
        return None
    if len(pk_values) == 1:
        return pk_values[0]
    return pk_values
|