fastapi-restly 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastapi_restly/__init__.py +106 -0
- fastapi_restly/_exception_handlers.py +194 -0
- fastapi_restly/_pytest_fixtures.py +256 -0
- fastapi_restly/db/__init__.py +31 -0
- fastapi_restly/db/_globals.py +76 -0
- fastapi_restly/db/_proxy.py +42 -0
- fastapi_restly/db/_session.py +275 -0
- fastapi_restly/exceptions.py +12 -0
- fastapi_restly/models/__init__.py +18 -0
- fastapi_restly/models/_base.py +84 -0
- fastapi_restly/objects.py +144 -0
- fastapi_restly/py.typed +0 -0
- fastapi_restly/pytest_fixtures.py +24 -0
- fastapi_restly/query/__init__.py +13 -0
- fastapi_restly/query/_impl.py +594 -0
- fastapi_restly/query/_shared.py +22 -0
- fastapi_restly/schemas/__init__.py +29 -0
- fastapi_restly/schemas/_base.py +518 -0
- fastapi_restly/schemas/_generator.py +382 -0
- fastapi_restly/testing/__init__.py +20 -0
- fastapi_restly/testing/_client.py +98 -0
- fastapi_restly/testing/_fixtures.py +20 -0
- fastapi_restly/views/__init__.py +40 -0
- fastapi_restly/views/_async.py +216 -0
- fastapi_restly/views/_base.py +1294 -0
- fastapi_restly/views/_openapi.py +206 -0
- fastapi_restly/views/_react_admin.py +393 -0
- fastapi_restly/views/_sync.py +213 -0
- fastapi_restly-0.5.0.dist-info/METADATA +407 -0
- fastapi_restly-0.5.0.dist-info/RECORD +34 -0
- fastapi_restly-0.5.0.dist-info/WHEEL +5 -0
- fastapi_restly-0.5.0.dist-info/entry_points.txt +2 -0
- fastapi_restly-0.5.0.dist-info/licenses/LICENSE +21 -0
- fastapi_restly-0.5.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
from importlib.metadata import PackageNotFoundError as _PackageNotFoundError
|
|
2
|
+
from importlib.metadata import version as _version
|
|
3
|
+
|
|
4
|
+
# Database layer
|
|
5
|
+
from .db import (
|
|
6
|
+
AsyncSessionDep,
|
|
7
|
+
SessionDep,
|
|
8
|
+
configure,
|
|
9
|
+
get_async_engine,
|
|
10
|
+
get_engine,
|
|
11
|
+
open_async_session,
|
|
12
|
+
open_session,
|
|
13
|
+
)
|
|
14
|
+
from .exceptions import RestlyConfigurationError, RestlyError
|
|
15
|
+
|
|
16
|
+
# Model base classes
|
|
17
|
+
from .models import DataclassBase, IDBase, IDMixin, TimestampsMixin
|
|
18
|
+
|
|
19
|
+
# List endpoint query parameters
|
|
20
|
+
from .query import apply_list_params, create_list_params_schema
|
|
21
|
+
|
|
22
|
+
# Schema utilities
|
|
23
|
+
from .schemas import (
|
|
24
|
+
BaseSchema,
|
|
25
|
+
IDRef,
|
|
26
|
+
IDSchema,
|
|
27
|
+
ReadOnly,
|
|
28
|
+
TimestampsSchemaMixin,
|
|
29
|
+
WriteOnly,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
# Views
|
|
33
|
+
from .views import (
|
|
34
|
+
AsyncReactAdminView,
|
|
35
|
+
AsyncRestView,
|
|
36
|
+
ListingResult,
|
|
37
|
+
ReactAdminView,
|
|
38
|
+
RestView,
|
|
39
|
+
View,
|
|
40
|
+
ViewRoute,
|
|
41
|
+
delete,
|
|
42
|
+
get,
|
|
43
|
+
include_view,
|
|
44
|
+
patch,
|
|
45
|
+
post,
|
|
46
|
+
put,
|
|
47
|
+
route,
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
try:
    __version__ = _version("fastapi-restly")
except _PackageNotFoundError:  # pragma: no cover - only possible from an unpackaged tree
    # No installed distribution metadata (e.g. importing straight from a
    # source checkout): expose a sentinel version instead of failing import.
    __version__ = "0+unknown"

# Public API surface for fastapi-restly.
#
# This top-level namespace is the primary public API. Submodule ``__all__``
# lists may expose additional supported advanced symbols for users working in
# that subsystem, such as ``from fastapi_restly.views import BaseRestView``.
__all__ = [
    "__version__",
    # Database — session context managers
    "open_async_session",
    "open_session",
    # Database — FastAPI dependencies
    "AsyncSessionDep",
    "SessionDep",
    # Database — engine access
    "get_async_engine",
    "get_engine",
    # Database — setup & utilities
    "configure",
    # Exceptions
    "RestlyError",
    "RestlyConfigurationError",
    # Models
    "DataclassBase",
    "IDBase",
    "IDMixin",
    "TimestampsMixin",
    # List endpoint query parameters
    "apply_list_params",
    "create_list_params_schema",
    # Schemas
    "BaseSchema",
    "IDRef",
    "IDSchema",
    "ReadOnly",
    "WriteOnly",
    "TimestampsSchemaMixin",
    # Views
    "RestView",
    "AsyncRestView",
    "ListingResult",
    "AsyncReactAdminView",
    "ReactAdminView",
    "View",
    "ViewRoute",
    "delete",
    "get",
    "include_view",
    "patch",
    "post",
    "put",
    "route",
]
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
"""Default FastAPI exception handlers installed by fastapi-restly.
|
|
2
|
+
|
|
3
|
+
Currently this module provides a translation layer from SQLAlchemy
|
|
4
|
+
:class:`~sqlalchemy.exc.IntegrityError` (unique-constraint, foreign-key,
|
|
5
|
+
not-null, and check-constraint violations) into a clean HTTP 409 Conflict
|
|
6
|
+
response. Without this handler, an ``IntegrityError`` bubbles up to FastAPI
|
|
7
|
+
and turns into a 500 Internal Server Error, which is misleading for clients
|
|
8
|
+
(the server is fine; the request conflicts with the current state of the
|
|
9
|
+
resource).
|
|
10
|
+
|
|
11
|
+
The handler is installed automatically by :func:`fastapi_restly.configure`
|
|
12
|
+
and as a fallback by :func:`fastapi_restly.include_view`. Users can opt out
|
|
13
|
+
by calling ``fr.configure(install_default_exception_handlers=False)`` or by
|
|
14
|
+
registering their own handler for ``IntegrityError`` *before* the framework
|
|
15
|
+
gets a chance to install one.
|
|
16
|
+
|
|
17
|
+
The detail-message extraction is best-effort: it understands the most common
|
|
18
|
+
PostgreSQL SQLSTATE codes (via psycopg's ``orig.pgcode``) and the SQLite
|
|
19
|
+
error-message conventions. For unrecognised dialects/messages we fall back
|
|
20
|
+
to a generic conflict message that includes a truncated version of the
|
|
21
|
+
underlying error text.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
from typing import Any
|
|
27
|
+
|
|
28
|
+
from fastapi import FastAPI, Request
|
|
29
|
+
from fastapi.responses import JSONResponse
|
|
30
|
+
from sqlalchemy.exc import IntegrityError
|
|
31
|
+
|
|
32
|
+
# Maximum length of original-error text we are willing to echo back. Keeps
# response bodies sane and avoids accidentally leaking long SQL strings.
_MAX_ORIG_TEXT_LENGTH = 500

# Name of the attribute set on ``app.state`` once we have processed this
# FastAPI instance, so repeated registration calls become no-ops. The constant
# is module-private, but the flag itself can be read from tests or user code
# with ``getattr(app.state, "_fr_default_exception_handlers_installed", False)``.
_HANDLERS_INSTALLED_FLAG = "_fr_default_exception_handlers_installed"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# ---------------------------------------------------------------------------
|
|
43
|
+
# Detail extraction
|
|
44
|
+
# ---------------------------------------------------------------------------
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# PostgreSQL SQLSTATE codes — see
|
|
48
|
+
# https://www.postgresql.org/docs/current/errcodes-appendix.html (class 23
|
|
49
|
+
# "Integrity Constraint Violation").
|
|
50
|
+
_PG_SQLSTATE_DETAILS: dict[str, str] = {
|
|
51
|
+
"23505": "Unique constraint violated",
|
|
52
|
+
"23503": "Foreign key constraint violated",
|
|
53
|
+
"23502": "Not-null constraint violated",
|
|
54
|
+
"23514": "Check constraint violated",
|
|
55
|
+
"23000": "Integrity constraint violated",
|
|
56
|
+
"23001": "Restrict violation",
|
|
57
|
+
"23P01": "Exclusion constraint violated",
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _extract_postgres_detail(orig: Any) -> str | None:
|
|
62
|
+
"""Return a user-facing detail message for a Postgres-driver error.
|
|
63
|
+
|
|
64
|
+
Looks at ``orig.pgcode`` (set by psycopg / psycopg2 / asyncpg-via-psycopg)
|
|
65
|
+
and, when available, ``orig.diag.constraint_name`` /
|
|
66
|
+
``orig.diag.column_name`` to enrich the message.
|
|
67
|
+
"""
|
|
68
|
+
pgcode = getattr(orig, "pgcode", None)
|
|
69
|
+
if not pgcode:
|
|
70
|
+
return None
|
|
71
|
+
|
|
72
|
+
base = _PG_SQLSTATE_DETAILS.get(pgcode)
|
|
73
|
+
if base is None:
|
|
74
|
+
return None
|
|
75
|
+
|
|
76
|
+
# ``diag`` is a psycopg-specific attribute holding fielded error info.
|
|
77
|
+
diag = getattr(orig, "diag", None)
|
|
78
|
+
constraint_name = getattr(diag, "constraint_name", None) if diag else None
|
|
79
|
+
column_name = getattr(diag, "column_name", None) if diag else None
|
|
80
|
+
|
|
81
|
+
if pgcode == "23505" and constraint_name:
|
|
82
|
+
return f"{base}: {constraint_name}"
|
|
83
|
+
if pgcode == "23503" and constraint_name:
|
|
84
|
+
return f"{base}: {constraint_name}"
|
|
85
|
+
if pgcode == "23502" and column_name:
|
|
86
|
+
return f"{base} on column {column_name!r}"
|
|
87
|
+
if pgcode == "23514" and constraint_name:
|
|
88
|
+
return f"{base}: {constraint_name}"
|
|
89
|
+
return base
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
# Mapping from SQLite error-message prefixes to a clean detail message.
|
|
93
|
+
# SQLite's IntegrityError.args[0] (and ``str(orig)``) follow predictable
|
|
94
|
+
# patterns, e.g. ``"UNIQUE constraint failed: user.username"``.
|
|
95
|
+
_SQLITE_PREFIX_DETAILS: tuple[tuple[str, str], ...] = (
|
|
96
|
+
("UNIQUE constraint failed:", "Unique constraint violated"),
|
|
97
|
+
("FOREIGN KEY constraint failed", "Foreign key constraint violated"),
|
|
98
|
+
("NOT NULL constraint failed:", "Not-null constraint violated"),
|
|
99
|
+
("CHECK constraint failed:", "Check constraint violated"),
|
|
100
|
+
("PRIMARY KEY must be unique", "Unique constraint violated (primary key)"),
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _extract_sqlite_detail(orig: Any) -> str | None:
|
|
105
|
+
"""Return a user-facing detail message for a SQLite-driver error."""
|
|
106
|
+
text = str(orig).strip()
|
|
107
|
+
if not text:
|
|
108
|
+
return None
|
|
109
|
+
|
|
110
|
+
for prefix, base in _SQLITE_PREFIX_DETAILS:
|
|
111
|
+
if not text.startswith(prefix):
|
|
112
|
+
continue
|
|
113
|
+
# Try to surface the column / constraint info that SQLite tacks on
|
|
114
|
+
# after the colon. ``UNIQUE constraint failed: user.username`` →
|
|
115
|
+
# ``"Unique constraint violated on user.username"``.
|
|
116
|
+
remainder = text[len(prefix) :].strip().lstrip(":").strip()
|
|
117
|
+
if remainder:
|
|
118
|
+
return f"{base} on {remainder}"
|
|
119
|
+
return base
|
|
120
|
+
return None
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _build_integrity_detail(exc: IntegrityError) -> str:
    """Turn a SQLAlchemy ``IntegrityError`` into a concise 409 detail string.

    The underlying driver error (``exc.orig``) is offered first to the
    PostgreSQL extractor and then to the SQLite extractor; the first non-None
    result is returned. If neither recognises it, a generic conflict message
    is returned with the driver's error text appended (or the SQLAlchemy
    error's own text when there is no driver error). That text is capped at
    ``_MAX_ORIG_TEXT_LENGTH`` characters so the body stays small.
    """
    orig = getattr(exc, "orig", None)
    if orig is not None:
        for extractor in (_extract_postgres_detail, _extract_sqlite_detail):
            detail = extractor(orig)
            if detail is not None:
                return detail

    # Nothing dialect-specific matched: echo (a bounded slice of) the raw text.
    source = exc if orig is None else orig
    text = str(source).strip()
    if len(text) > _MAX_ORIG_TEXT_LENGTH:
        text = text[:_MAX_ORIG_TEXT_LENGTH] + "...(truncated)"

    generic = "Conflict with current state of the resource"
    return f"{generic}: {text}" if text else generic
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
# ---------------------------------------------------------------------------
|
|
158
|
+
# The handler & registration helper
|
|
159
|
+
# ---------------------------------------------------------------------------
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def integrity_error_handler(request: Request, exc: Exception) -> JSONResponse:
    """Translate a SQLAlchemy IntegrityError into HTTP 409 Conflict.

    Signature uses ``Exception`` rather than ``IntegrityError`` to satisfy
    Starlette's exception-handler typing; we narrow at runtime.

    Args:
        request: The request being handled (unused; required by the
            Starlette handler signature).
        exc: The raised exception, expected to be an ``IntegrityError``.

    Returns:
        A JSON response with status 409 and body ``{"detail": <message>}``.

    Raises:
        Exception: ``exc`` is re-raised unchanged if the handler was somehow
            registered for something other than ``IntegrityError``.
    """
    # Explicit check instead of ``assert``: assertions are stripped under
    # ``python -O``, which would let a mis-registered handler turn unrelated
    # errors into 409 responses.
    if not isinstance(exc, IntegrityError):
        raise exc
    detail = _build_integrity_detail(exc)
    return JSONResponse(status_code=409, content={"detail": detail})
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def register_default_exception_handlers(app: FastAPI) -> None:
    """Install fastapi-restly's default exception handlers on ``app`` once.

    * A handler the user already registered for :class:`IntegrityError`
      always takes precedence; ours is then not installed.
    * Repeated calls on the same ``app`` are no-ops, so both
      :func:`fastapi_restly.configure` and :func:`fastapi_restly.include_view`
      may call this safely.

    In either case the ``app.state`` marker is set after the first call.
    """
    if getattr(app.state, _HANDLERS_INSTALLED_FLAG, False):
        return

    if IntegrityError not in app.exception_handlers:
        app.add_exception_handler(IntegrityError, integrity_error_handler)
    setattr(app.state, _HANDLERS_INSTALLED_FLAG, True)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
# Only the handler and its registration helper are exported; the
# detail-extraction helpers above stay module-private.
__all__ = ["integrity_error_handler", "register_default_exception_handlers"]
|
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import traceback
|
|
4
|
+
from contextlib import asynccontextmanager
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import TYPE_CHECKING, AsyncIterator, Iterator
|
|
7
|
+
from unittest.mock import MagicMock, patch
|
|
8
|
+
|
|
9
|
+
import alembic
|
|
10
|
+
import alembic.command
|
|
11
|
+
import alembic.config
|
|
12
|
+
import pytest
|
|
13
|
+
from fastapi import FastAPI
|
|
14
|
+
from sqlalchemy.ext.asyncio import AsyncConnection
|
|
15
|
+
from sqlalchemy.ext.asyncio import AsyncSession as SA_AsyncSession
|
|
16
|
+
from sqlalchemy.orm import Session as SA_Session
|
|
17
|
+
|
|
18
|
+
from .db import activate_savepoint_only_mode
|
|
19
|
+
from .db._globals import _fr_globals, _get_restly_context
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from .testing._client import RestlyTestClient
|
|
23
|
+
|
|
24
|
+
try:
    import pytest_asyncio
except ModuleNotFoundError as exc:
    # Tolerate only pytest_asyncio itself being absent (it is an optional
    # extra); a ModuleNotFoundError for any other module is re-raised.
    if exc.name != "pytest_asyncio":
        raise
    pytest_asyncio = None

# Error text used when a fixture needs an optional testing dependency
# (``pytest_asyncio`` or ``httpx``) that is not installed.
_TESTING_EXTRA_MESSAGE = (
    "fastapi_restly.pytest_fixtures requires optional testing dependencies. "
    'Install them with: pip install "fastapi-restly[testing]"'
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@pytest.fixture(scope="session")
def restly_project_root() -> Path:
    """Return the project root directory.

    Starting from the current working directory and walking up through every
    ancestor (including the filesystem root), returns the first directory that
    contains a ``pyproject.toml``.

    Raises:
        RuntimeError: If no such directory exists. This is a subclass of
            ``Exception``, so existing ``except Exception`` handlers still apply.
    """
    cwd = Path.cwd()
    # ``cwd.parents`` ends with the filesystem root, which the previous
    # ``while current != current.parent`` loop never checked.
    for candidate in (cwd, *cwd.parents):
        if (candidate / "pyproject.toml").exists():
            return candidate
    raise RuntimeError("Could not find a pyproject.toml to establish project root")
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _run_alembic_upgrade(project_root: Path) -> None:
    """Upgrade the test database to Alembic ``head`` for ``project_root``.

    Does nothing when the project has no ``alembic/`` directory. If the
    upgrade raises, the whole pytest run is aborted through ``pytest.exit``
    (return code 1), with the error and its formatted traceback in the message.
    """
    alembic_dir = project_root / "alembic"
    if alembic_dir.exists():
        # Locating project_root is the restly_project_root fixture's job; here
        # we only point Alembic at the ini file and the script directory.
        cfg = alembic.config.Config(project_root / "alembic.ini")
        cfg.set_main_option("script_location", str(alembic_dir))
        try:
            alembic.command.upgrade(cfg, "head")
        except Exception as err:
            details = traceback.format_exc()
            pytest.exit(
                f"Alembic migrations failed: {err}\n\nTraceback:\n{details}",
                returncode=1,
            )
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _activate_savepoint_only_mode_sessions() -> None:
    """Put every configured Restly sessionmaker into savepoint-only mode.

    The async sessionmaker is handled first, then the sync one. A
    sessionmaker that has not been configured is skipped, so calling this
    with no database configured is a no-op.
    """
    async_maker = _fr_globals.async_make_session
    if async_maker:
        activate_savepoint_only_mode(async_maker)

    sync_maker = _fr_globals.make_session
    if sync_maker:
        activate_savepoint_only_mode(sync_maker)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@pytest.fixture
def _shared_connection():
    """Yield one sync connection for the session fixtures of a single test.

    The connection is opened on the engine bound to the configured sync
    sessionmaker. Async-only projects (no sync sessionmaker) receive ``None``,
    so ``restly_async_session`` can still be used and opens its own async
    connection instead.
    """
    maker = _fr_globals.make_session
    if maker:
        with maker.kw["bind"].connect() as conn:
            yield conn
    else:
        yield None
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
if pytest_asyncio is None:

    @pytest.fixture
    def restly_async_session(_shared_connection) -> None:  # pyright: ignore[reportRedeclaration]
        # The else-branch defines the real async fixture; this stub only
        # runs when the optional ``pytest_asyncio`` extra isn't installed.
        # Pyright cannot model mutually exclusive module-level branches.
        # Requesting the fixture then fails with an install hint instead of
        # the fixture being silently unknown to pytest.
        raise ModuleNotFoundError(_TESTING_EXTRA_MESSAGE, name="pytest_asyncio")

else:

    @pytest_asyncio.fixture
    async def restly_async_session(_shared_connection) -> AsyncIterator[SA_AsyncSession]:
        """
        Pytest fixture providing a database session with savepoint-based isolation.

        Each test runs inside a savepoint. At the end of the test, the savepoint is
        rolled back, leaving the database clean for the next test.

        NOTE: Calling session.rollback() inside a test rolls back to the last savepoint
        (created by each patched commit()), NOT to the start of the test. This differs
        from production behavior. To undo all changes in a test, use session.rollback()
        after each commit(), but be aware that data added before the last commit() is
        still visible.
        """
        # Only run if database connections are set up
        if not _fr_globals.async_make_session:
            pytest.skip("Database connection not set up")

        # Requires the async sessionmaker to have been configured with an
        # explicit ``bind`` (otherwise this lookup raises KeyError).
        async_engine = _fr_globals.async_make_session.kw["bind"]

        @asynccontextmanager
        async def get_bound_async_connection():
            # Async-only project: no sync connection exists, so open a
            # dedicated async connection for this test.
            if _shared_connection is None:
                async with async_engine.connect() as async_conn:
                    yield async_conn
                return

            # Otherwise wrap the sync fixture's connection in an
            # AsyncConnection, so sync and async sessions of the same test
            # operate on that one underlying connection.
            async_conn = AsyncConnection(
                async_engine, sync_connection=_shared_connection
            )
            async with async_conn:
                yield async_conn

        async with get_bound_async_connection() as async_conn:
            async with _fr_globals.async_make_session(bind=async_conn) as sess:
                # Stand-in for what the sessionmaker (or its ``.begin()``)
                # returns: entering opens a new SAVEPOINT on ``sess`` and hands
                # back ``sess`` itself; exiting never closes ``sess``.
                class AsyncSessionContext:
                    def __init__(self, *, flush_on_success: bool) -> None:
                        self.flush_on_success = flush_on_success

                    async def __aenter__(self):
                        await sess.begin_nested()
                        return sess

                    async def __aexit__(self, exc_type, exc_value, tb):
                        # Only the ``.begin()`` flavour flushes, and only when
                        # the block finished without an exception.
                        if self.flush_on_success and exc_type is None:
                            await sess.flush()
                        return False  # re-raise any exception

                mock_sessionmaker = MagicMock()
                mock_sessionmaker.side_effect = lambda *args, **kwargs: (
                    AsyncSessionContext(flush_on_success=False)
                )
                # session.begin() is used as a context manager (async with
                # session.begin():). Return the same isolated session and flush
                # pending changes after successful explicit transaction blocks.
                mock_sessionmaker.begin.side_effect = lambda *args, **kwargs: (
                    AsyncSessionContext(flush_on_success=True)
                )

                # Replacement for AsyncSession.__aexit__ on every AsyncSession
                # while patched: flush ``sess`` instead of closing the session.
                async def passthrough_exit(self, exc_type, exc_value, traceback):
                    await sess.flush()
                    return False  # re-raise any exception

                # Replacement for AsyncSession.commit(): flush and open a fresh
                # SAVEPOINT, so nothing is committed at the outer level.
                async def patched_commit(self):
                    await sess.flush()
                    await sess.begin_nested()

                # Swap the active context's async sessionmaker for the mock for
                # the duration of the test; always restore it afterwards.
                globals_obj = _get_restly_context()
                original_async_make_session = globals_obj.async_make_session
                globals_obj.async_make_session = mock_sessionmaker
                try:
                    with (
                        patch.object(SA_AsyncSession, "__aexit__", passthrough_exit),
                        patch.object(SA_AsyncSession, "commit", patched_commit),
                    ):
                        yield sess
                finally:
                    globals_obj.async_make_session = original_async_make_session
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
@pytest.fixture
def restly_session(_shared_connection) -> Iterator[SA_Session]:
    """
    Pytest fixture providing a database session with savepoint-based isolation.

    Each test runs inside a savepoint. At the end of the test, the savepoint is
    rolled back, leaving the database clean for the next test.

    NOTE: Calling session.rollback() inside a test rolls back to the last savepoint
    (created by each patched commit()), NOT to the start of the test. This differs
    from production behavior. To undo all changes in a test, use session.rollback()
    after each commit(), but be aware that data added before the last commit() is
    still visible.
    """
    # Only run if database connections are set up
    if not _fr_globals.make_session:
        pytest.skip("Database connection not set up")

    with _fr_globals.make_session(bind=_shared_connection) as sess:

        # Open a new SAVEPOINT on the fixture's session and return that same
        # session; used both when the sessionmaker is called and when its
        # ``.begin()`` block is entered.
        def begin_nested():
            sess.begin_nested()
            return sess

        mock_sessionmaker = MagicMock()
        mock_sessionmaker.side_effect = begin_nested
        # session.begin() is used as a context manager (with session.begin():)
        # We need it to also return our savepoint session so explicit transaction
        # blocks work correctly with our isolation mechanism
        mock_sessionmaker.begin.return_value.__enter__.side_effect = begin_nested

        # Exit of a ``with make_session.begin():`` block: flush on success,
        # never close ``sess``.
        def exit_nested(exc_type, exc_value, tb):
            if exc_type is None:
                sess.flush()
            return False  # re-raise any exception

        # Replacement for Session.__exit__ on every Session while patched:
        # flush ``sess`` instead of closing the session.
        def passthrough_exit(self, exc_type, exc_value, traceback):
            sess.flush()
            return False  # re-raise any exception

        # Replacement for Session.commit(): flush and open a fresh SAVEPOINT,
        # so nothing is committed at the outer level.
        def patched_commit(self):
            sess.flush()
            sess.begin_nested()

        # Swap the active context's sync sessionmaker for the mock for the
        # duration of the test; always restore it afterwards.
        globals_obj = _get_restly_context()
        original_make_session = globals_obj.make_session
        globals_obj.make_session = mock_sessionmaker
        try:
            with (
                patch.object(SA_Session, "__exit__", passthrough_exit),
                patch.object(SA_Session, "commit", patched_commit),
            ):
                mock_sessionmaker.begin.return_value.__exit__.side_effect = exit_nested
                yield sess
        finally:
            globals_obj.make_session = original_make_session
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
@pytest.fixture
def restly_app() -> FastAPI:
    """Provide a fresh, empty FastAPI application for each test."""
    app = FastAPI()
    return app
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
@pytest.fixture
def restly_client(restly_app) -> RestlyTestClient:
    """Wrap ``restly_app`` in a RestlyTestClient for the current test.

    The client class is imported lazily. If that import fails because the
    optional ``httpx`` dependency is missing, the error is re-raised with an
    install hint; any other missing module propagates unchanged.
    """
    try:
        from .testing._client import RestlyTestClient
    except ModuleNotFoundError as err:
        if err.name != "httpx":
            raise
        raise ModuleNotFoundError(_TESTING_EXTRA_MESSAGE, name="httpx") from err

    return RestlyTestClient(restly_app)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from ._proxy import open_async_session, open_session
|
|
2
|
+
from ._session import (
|
|
3
|
+
AsyncSessionDep,
|
|
4
|
+
SessionDep,
|
|
5
|
+
activate_savepoint_only_mode,
|
|
6
|
+
configure,
|
|
7
|
+
deactivate_savepoint_only_mode,
|
|
8
|
+
get_async_engine,
|
|
9
|
+
get_engine,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
# Public API for ``fastapi_restly.db``.
#
# Session generator internals live in private modules; use ``configure`` to
# configure the process-wide Restly runtime state.
#
# Unlike the top-level package, this list also exports the savepoint-mode
# helpers used by the pytest fixtures.
__all__ = [
    # Session context managers
    "open_async_session",
    "open_session",
    # FastAPI dependencies
    "AsyncSessionDep",
    "SessionDep",
    # Engine access
    "get_async_engine",
    "get_engine",
    # Setup
    "configure",
    # Savepoint mode
    "activate_savepoint_only_mode",
    "deactivate_savepoint_only_mode",
]
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
from collections.abc import AsyncIterator, Callable, Iterator
|
|
2
|
+
from contextvars import ContextVar, Token
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from sqlalchemy.ext.asyncio import AsyncSession as SA_AsyncSession
|
|
6
|
+
from sqlalchemy.ext.asyncio import async_sessionmaker
|
|
7
|
+
from sqlalchemy.orm import Session as SA_Session
|
|
8
|
+
from sqlalchemy.orm import sessionmaker
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class RestlyContext:
    """Holder for Restly runtime configuration (engines' URLs, sessionmakers, ...).

    A module-level default instance is used unless another instance has been
    entered with ``with``. Entering makes the instance the active context for
    the current ``contextvars`` context until the matching exit. Nested enters
    are supported via a stack of reset tokens.
    """

    __slots__ = (
        "async_database_url",
        "async_make_session",
        "commit_session_on_response",
        "database_url",
        "make_session",
        "session_generator",
        "sync_session_generator",
    )

    async_database_url: str | None
    async_make_session: async_sessionmaker[Any] | None
    commit_session_on_response: bool
    database_url: str | None
    make_session: sessionmaker[Any] | None
    session_generator: Callable[[], AsyncIterator[SA_AsyncSession]] | None
    sync_session_generator: Callable[[], Iterator[SA_Session]] | None

    def __init__(self) -> None:
        # Everything starts unconfigured, except that sessions are committed
        # on response by default.
        self.commit_session_on_response = True
        self.async_database_url = None
        self.database_url = None
        self.async_make_session = None
        self.make_session = None
        self.session_generator = None
        self.sync_session_generator = None

    def __enter__(self) -> "RestlyContext":
        token = _restly_context_ctx.set(self)
        pushed = (*_restly_context_token_stack.get(), token)
        _restly_context_token_stack.set(pushed)
        return self

    def __exit__(self, *exc_info: object) -> None:
        stack = _restly_context_token_stack.get()
        if not stack:
            raise RuntimeError("RestlyContext was exited without being entered.")
        *remaining, token = stack
        _restly_context_token_stack.set(tuple(remaining))
        _restly_context_ctx.reset(token)


_default_context = RestlyContext()
_restly_context_ctx: ContextVar[RestlyContext | None] = ContextVar(
    "fastapi_restly_context", default=None
)
_restly_context_token_stack: ContextVar[tuple[Token[RestlyContext | None], ...]] = (
    ContextVar("fastapi_restly_context_token_stack", default=())
)


def _get_restly_context() -> RestlyContext:
    """Return the entered RestlyContext, or the module default if none is active."""
    active = _restly_context_ctx.get()
    return _default_context if active is None else active
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class _FRGlobalsProxy:
    """Forward attribute reads and writes to the currently active RestlyContext.

    The target is looked up on every access, so code holding ``_fr_globals``
    always sees whichever context is active at that moment.
    """

    def __getattr__(self, name: str):
        context = _get_restly_context()
        return getattr(context, name)

    def __setattr__(self, name: str, value):
        context = _get_restly_context()
        setattr(context, name, value)


_fr_globals = _FRGlobalsProxy()
|