fastapi-sqla 2.10.0__py3-none-any.whl → 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fastapi-sqla might be problematic. Click here for more details.
- fastapi_sqla/__init__.py +16 -3
- fastapi_sqla/_pytest_plugin.py +46 -31
- fastapi_sqla/async_pagination.py +16 -12
- fastapi_sqla/async_sqla.py +99 -50
- fastapi_sqla/aws_rds_iam_support.py +5 -3
- fastapi_sqla/base.py +37 -13
- fastapi_sqla/pagination.py +12 -11
- fastapi_sqla/sqla.py +66 -44
- {fastapi_sqla-2.10.0.dist-info → fastapi_sqla-3.0.0.dist-info}/METADATA +165 -28
- fastapi_sqla-3.0.0.dist-info/RECORD +16 -0
- fastapi_sqla-2.10.0.dist-info/RECORD +0 -16
- {fastapi_sqla-2.10.0.dist-info → fastapi_sqla-3.0.0.dist-info}/LICENSE +0 -0
- {fastapi_sqla-2.10.0.dist-info → fastapi_sqla-3.0.0.dist-info}/WHEEL +0 -0
- {fastapi_sqla-2.10.0.dist-info → fastapi_sqla-3.0.0.dist-info}/entry_points.txt +0 -0
fastapi_sqla/__init__.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from fastapi_sqla.base import setup
|
|
2
2
|
from fastapi_sqla.models import Base, Collection, Item, Page
|
|
3
3
|
from fastapi_sqla.pagination import Paginate, PaginateSignature, Pagination
|
|
4
|
-
from fastapi_sqla.sqla import Session, open_session
|
|
4
|
+
from fastapi_sqla.sqla import Session, SessionDependency, SqlaSession, open_session
|
|
5
5
|
|
|
6
6
|
__all__ = [
|
|
7
7
|
"Base",
|
|
@@ -12,20 +12,33 @@ __all__ = [
|
|
|
12
12
|
"PaginateSignature",
|
|
13
13
|
"Pagination",
|
|
14
14
|
"Session",
|
|
15
|
+
"SessionDependency",
|
|
16
|
+
"SqlaSession",
|
|
15
17
|
"open_session",
|
|
16
18
|
"setup",
|
|
17
19
|
]
|
|
18
20
|
|
|
19
21
|
|
|
20
22
|
try:
|
|
21
|
-
from fastapi_sqla.async_pagination import
|
|
22
|
-
|
|
23
|
+
from fastapi_sqla.async_pagination import (
|
|
24
|
+
AsyncPaginate,
|
|
25
|
+
AsyncPaginateSignature,
|
|
26
|
+
AsyncPagination,
|
|
27
|
+
)
|
|
28
|
+
from fastapi_sqla.async_sqla import (
|
|
29
|
+
AsyncSession,
|
|
30
|
+
AsyncSessionDependency,
|
|
31
|
+
SqlaAsyncSession,
|
|
32
|
+
)
|
|
23
33
|
from fastapi_sqla.async_sqla import open_session as open_async_session
|
|
24
34
|
|
|
25
35
|
__all__ += [
|
|
26
36
|
"AsyncPaginate",
|
|
37
|
+
"AsyncPaginateSignature",
|
|
27
38
|
"AsyncPagination",
|
|
28
39
|
"AsyncSession",
|
|
40
|
+
"AsyncSessionDependency",
|
|
41
|
+
"SqlaAsyncSession",
|
|
29
42
|
"open_async_session",
|
|
30
43
|
]
|
|
31
44
|
has_asyncio_support = True
|
fastapi_sqla/_pytest_plugin.py
CHANGED
|
@@ -6,6 +6,7 @@ from alembic import command
|
|
|
6
6
|
from alembic.config import Config
|
|
7
7
|
from pytest import fixture
|
|
8
8
|
from sqlalchemy import create_engine, text
|
|
9
|
+
from sqlalchemy.orm.session import sessionmaker
|
|
9
10
|
|
|
10
11
|
try:
|
|
11
12
|
import asyncpg # noqa
|
|
@@ -52,8 +53,12 @@ def db_url(db_host, db_user):
|
|
|
52
53
|
|
|
53
54
|
|
|
54
55
|
@fixture(scope="session")
|
|
55
|
-
def
|
|
56
|
-
|
|
56
|
+
def engine(db_url):
|
|
57
|
+
return create_engine(db_url)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@fixture(scope="session")
|
|
61
|
+
def sqla_connection(engine):
|
|
57
62
|
with engine.connect() as connection:
|
|
58
63
|
yield connection
|
|
59
64
|
|
|
@@ -92,7 +97,7 @@ def sqla_modules():
|
|
|
92
97
|
|
|
93
98
|
|
|
94
99
|
@fixture
|
|
95
|
-
def sqla_reflection(sqla_modules, sqla_connection
|
|
100
|
+
def sqla_reflection(sqla_modules, sqla_connection):
|
|
96
101
|
import fastapi_sqla
|
|
97
102
|
|
|
98
103
|
fastapi_sqla.Base.metadata.bind = sqla_connection
|
|
@@ -100,17 +105,15 @@ def sqla_reflection(sqla_modules, sqla_connection, db_url):
|
|
|
100
105
|
|
|
101
106
|
|
|
102
107
|
@fixture
|
|
103
|
-
def patch_engine_from_config(request,
|
|
108
|
+
def patch_engine_from_config(request, sqla_connection, sqla_transaction):
|
|
104
109
|
"""So that all DB operations are never written to db for real."""
|
|
105
|
-
from fastapi_sqla.sqla import _Session
|
|
106
110
|
|
|
107
|
-
if "dont_patch_engines" in request.keywords:
|
|
111
|
+
if "dont_patch_engines" in request.keywords: # pragma: no cover
|
|
108
112
|
yield
|
|
109
113
|
|
|
110
114
|
else:
|
|
111
115
|
with patch("fastapi_sqla.sqla.engine_from_config") as engine_from_config:
|
|
112
116
|
engine_from_config.return_value = sqla_connection
|
|
113
|
-
_Session.configure(bind=sqla_connection)
|
|
114
117
|
yield engine_from_config
|
|
115
118
|
|
|
116
119
|
|
|
@@ -121,18 +124,27 @@ def sqla_transaction(sqla_connection):
|
|
|
121
124
|
transaction.rollback()
|
|
122
125
|
|
|
123
126
|
|
|
127
|
+
@fixture
|
|
128
|
+
def session_factory():
|
|
129
|
+
return sessionmaker()
|
|
130
|
+
|
|
131
|
+
|
|
124
132
|
@fixture
|
|
125
133
|
def session(
|
|
126
|
-
|
|
134
|
+
session_factory,
|
|
135
|
+
sqla_connection,
|
|
136
|
+
sqla_transaction,
|
|
137
|
+
sqla_reflection,
|
|
138
|
+
patch_engine_from_config,
|
|
127
139
|
):
|
|
128
140
|
"""Sqla session to use when creating db fixtures.
|
|
129
141
|
|
|
130
142
|
While it does not write any record in DB, the application will still be able to
|
|
131
143
|
access any record committed with that session.
|
|
132
144
|
"""
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
145
|
+
session = session_factory(bind=sqla_connection)
|
|
146
|
+
yield session
|
|
147
|
+
session.close()
|
|
136
148
|
|
|
137
149
|
|
|
138
150
|
def format_async_async_sqlalchemy_url(url):
|
|
@@ -149,32 +161,33 @@ def async_sqlalchemy_url(db_url):
|
|
|
149
161
|
return format_async_async_sqlalchemy_url(db_url)
|
|
150
162
|
|
|
151
163
|
|
|
152
|
-
if asyncio_support:
|
|
164
|
+
if asyncio_support: # noqa: C901
|
|
153
165
|
|
|
154
166
|
@fixture
|
|
155
|
-
|
|
167
|
+
def async_engine(async_sqlalchemy_url):
|
|
156
168
|
return create_async_engine(async_sqlalchemy_url)
|
|
157
169
|
|
|
158
170
|
@fixture
|
|
159
171
|
async def async_sqla_connection(async_engine, event_loop):
|
|
160
|
-
async with async_engine.
|
|
172
|
+
async with async_engine.connect() as connection:
|
|
161
173
|
yield connection
|
|
162
|
-
await connection.rollback()
|
|
163
174
|
|
|
164
175
|
@fixture
|
|
165
|
-
async def
|
|
176
|
+
async def async_sqla_transaction(async_sqla_connection):
|
|
177
|
+
async with async_sqla_connection.begin() as transaction:
|
|
178
|
+
yield transaction
|
|
179
|
+
await transaction.rollback()
|
|
180
|
+
|
|
181
|
+
@fixture
|
|
182
|
+
async def patch_new_engine(request, async_sqla_connection, async_sqla_transaction):
|
|
166
183
|
"""So that all async DB operations are never written to db for real."""
|
|
167
|
-
from fastapi_sqla.async_sqla import _AsyncSession
|
|
168
184
|
|
|
169
|
-
if "dont_patch_engines" in request.keywords:
|
|
185
|
+
if "dont_patch_engines" in request.keywords: # pragma: no cover
|
|
170
186
|
yield
|
|
171
187
|
|
|
172
188
|
else:
|
|
173
189
|
with patch("fastapi_sqla.async_sqla.new_engine") as new_engine:
|
|
174
190
|
new_engine.return_value = async_sqla_connection
|
|
175
|
-
_AsyncSession.configure(
|
|
176
|
-
bind=async_sqla_connection, expire_on_commit=False
|
|
177
|
-
)
|
|
178
191
|
yield new_engine
|
|
179
192
|
|
|
180
193
|
@fixture
|
|
@@ -183,18 +196,20 @@ if asyncio_support:
|
|
|
183
196
|
|
|
184
197
|
await async_sqla_connection.run_sync(lambda conn: Base.prepare(conn.engine))
|
|
185
198
|
|
|
199
|
+
@fixture
|
|
200
|
+
def async_session_factory():
|
|
201
|
+
from fastapi_sqla.async_sqla import SqlaAsyncSession
|
|
202
|
+
|
|
203
|
+
return sessionmaker(class_=SqlaAsyncSession)
|
|
204
|
+
|
|
186
205
|
@fixture
|
|
187
206
|
async def async_session(
|
|
188
|
-
|
|
207
|
+
async_session_factory,
|
|
208
|
+
async_sqla_connection,
|
|
209
|
+
async_sqla_transaction,
|
|
210
|
+
async_sqla_reflection,
|
|
211
|
+
patch_new_engine,
|
|
189
212
|
):
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
session = _AsyncSession(bind=async_sqla_connection)
|
|
213
|
+
session = async_session_factory(bind=async_sqla_connection)
|
|
193
214
|
yield session
|
|
194
215
|
await session.close()
|
|
195
|
-
|
|
196
|
-
else:
|
|
197
|
-
|
|
198
|
-
@fixture
|
|
199
|
-
async def patch_new_engine():
|
|
200
|
-
pass
|
fastapi_sqla/async_pagination.py
CHANGED
|
@@ -1,28 +1,31 @@
|
|
|
1
1
|
import math
|
|
2
2
|
from collections.abc import Awaitable, Callable
|
|
3
|
-
from typing import Iterator, Optional, Union, cast
|
|
3
|
+
from typing import Annotated, Iterator, Optional, Union, cast
|
|
4
4
|
|
|
5
5
|
from fastapi import Depends, Query
|
|
6
6
|
from sqlalchemy.sql import Select, func, select
|
|
7
7
|
|
|
8
|
-
from fastapi_sqla.async_sqla import
|
|
8
|
+
from fastapi_sqla.async_sqla import AsyncSessionDependency, SqlaAsyncSession
|
|
9
9
|
from fastapi_sqla.models import Page
|
|
10
|
+
from fastapi_sqla.sqla import _DEFAULT_SESSION_KEY
|
|
10
11
|
|
|
11
12
|
QueryCountDependency = Callable[..., Awaitable[int]]
|
|
12
|
-
|
|
13
|
-
DefaultDependency = Callable[[
|
|
14
|
-
WithQueryCountDependency = Callable[
|
|
13
|
+
AsyncPaginateSignature = Callable[[Select, Optional[bool]], Awaitable[Page]]
|
|
14
|
+
DefaultDependency = Callable[[SqlaAsyncSession, int, int], AsyncPaginateSignature]
|
|
15
|
+
WithQueryCountDependency = Callable[
|
|
16
|
+
[SqlaAsyncSession, int, int, int], AsyncPaginateSignature
|
|
17
|
+
]
|
|
15
18
|
PaginateDependency = Union[DefaultDependency, WithQueryCountDependency]
|
|
16
19
|
|
|
17
20
|
|
|
18
|
-
async def default_query_count(session:
|
|
21
|
+
async def default_query_count(session: SqlaAsyncSession, query: Select) -> int:
|
|
19
22
|
result = await session.execute(select(func.count()).select_from(query.subquery()))
|
|
20
23
|
return cast(int, result.scalar())
|
|
21
24
|
|
|
22
25
|
|
|
23
26
|
async def paginate_query(
|
|
24
27
|
query: Select,
|
|
25
|
-
session:
|
|
28
|
+
session: SqlaAsyncSession,
|
|
26
29
|
total_items: int,
|
|
27
30
|
offset: int,
|
|
28
31
|
limit: int,
|
|
@@ -48,15 +51,16 @@ async def paginate_query(
|
|
|
48
51
|
|
|
49
52
|
|
|
50
53
|
def AsyncPagination(
|
|
54
|
+
session_key: str = _DEFAULT_SESSION_KEY,
|
|
51
55
|
min_page_size: int = 10,
|
|
52
56
|
max_page_size: int = 100,
|
|
53
57
|
query_count: Union[QueryCountDependency, None] = None,
|
|
54
58
|
) -> PaginateDependency:
|
|
55
59
|
def default_dependency(
|
|
56
|
-
session:
|
|
60
|
+
session: SqlaAsyncSession = Depends(AsyncSessionDependency(key=session_key)),
|
|
57
61
|
offset: int = Query(0, ge=0),
|
|
58
62
|
limit: int = Query(min_page_size, ge=1, le=max_page_size),
|
|
59
|
-
) ->
|
|
63
|
+
) -> AsyncPaginateSignature:
|
|
60
64
|
async def paginate(query: Select, scalars=True) -> Page:
|
|
61
65
|
total_items = await default_query_count(session, query)
|
|
62
66
|
return await paginate_query(
|
|
@@ -66,11 +70,11 @@ def AsyncPagination(
|
|
|
66
70
|
return paginate
|
|
67
71
|
|
|
68
72
|
def with_query_count_dependency(
|
|
69
|
-
session:
|
|
73
|
+
session: SqlaAsyncSession = Depends(AsyncSessionDependency(key=session_key)),
|
|
70
74
|
offset: int = Query(0, ge=0),
|
|
71
75
|
limit: int = Query(min_page_size, ge=1, le=max_page_size),
|
|
72
76
|
total_items: int = Depends(query_count),
|
|
73
|
-
):
|
|
77
|
+
) -> AsyncPaginateSignature:
|
|
74
78
|
async def paginate(query: Select, scalars=True) -> Page:
|
|
75
79
|
return await paginate_query(
|
|
76
80
|
query, session, total_items, offset, limit, scalars=scalars
|
|
@@ -84,4 +88,4 @@ def AsyncPagination(
|
|
|
84
88
|
return default_dependency
|
|
85
89
|
|
|
86
90
|
|
|
87
|
-
AsyncPaginate
|
|
91
|
+
AsyncPaginate = Annotated[AsyncPaginateSignature, Depends(AsyncPagination())]
|
fastapi_sqla/async_sqla.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
import os
|
|
2
1
|
from collections.abc import AsyncGenerator
|
|
3
2
|
from contextlib import asynccontextmanager
|
|
4
|
-
from typing import
|
|
3
|
+
from typing import Annotated
|
|
5
4
|
|
|
6
5
|
import structlog
|
|
7
|
-
from fastapi import Request
|
|
6
|
+
from fastapi import Depends, Request
|
|
7
|
+
from fastapi.responses import PlainTextResponse
|
|
8
8
|
from sqlalchemy import text
|
|
9
9
|
from sqlalchemy.ext.asyncio import AsyncEngine
|
|
10
10
|
from sqlalchemy.ext.asyncio import AsyncSession as SqlaAsyncSession
|
|
@@ -12,93 +12,86 @@ from sqlalchemy.orm.session import sessionmaker
|
|
|
12
12
|
|
|
13
13
|
from fastapi_sqla import aws_aurora_support, aws_rds_iam_support
|
|
14
14
|
from fastapi_sqla.models import Base
|
|
15
|
-
from fastapi_sqla.sqla import new_engine
|
|
15
|
+
from fastapi_sqla.sqla import _DEFAULT_SESSION_KEY, new_engine
|
|
16
16
|
|
|
17
17
|
logger = structlog.get_logger(__name__)
|
|
18
|
-
_ASYNC_SESSION_KEY = "fastapi_sqla_async_session"
|
|
19
|
-
_AsyncSession = sessionmaker(class_=SqlaAsyncSession)
|
|
20
18
|
|
|
19
|
+
_ASYNC_REQUEST_SESSION_KEY = "fastapi_sqla_async_session"
|
|
20
|
+
_async_session_factories: dict[str, sessionmaker] = {}
|
|
21
21
|
|
|
22
|
-
def new_async_engine():
|
|
23
|
-
envvar_prefix = None
|
|
24
|
-
if "async_sqlalchemy_url" in os.environ:
|
|
25
|
-
envvar_prefix = "async_sqlalchemy_"
|
|
26
22
|
|
|
27
|
-
|
|
23
|
+
def new_async_engine(key: str = _DEFAULT_SESSION_KEY):
|
|
24
|
+
engine = new_engine(key)
|
|
28
25
|
return AsyncEngine(engine)
|
|
29
26
|
|
|
30
27
|
|
|
31
|
-
async def startup():
|
|
32
|
-
engine = new_async_engine()
|
|
28
|
+
async def startup(key: str = _DEFAULT_SESSION_KEY):
|
|
29
|
+
engine = new_async_engine(key)
|
|
33
30
|
aws_rds_iam_support.setup(engine.sync_engine)
|
|
34
31
|
aws_aurora_support.setup(engine.sync_engine)
|
|
35
32
|
|
|
36
|
-
# Fail early
|
|
33
|
+
# Fail early
|
|
37
34
|
try:
|
|
38
35
|
async with engine.connect() as connection:
|
|
39
36
|
await connection.execute(text("select 'ok'"))
|
|
40
37
|
except Exception:
|
|
41
38
|
logger.critical(
|
|
42
|
-
"Failed querying db
|
|
43
|
-
"correctly configured?"
|
|
39
|
+
f"Failed querying db for key '{key}': "
|
|
40
|
+
"are the the environment variables correctly configured for this key?"
|
|
44
41
|
)
|
|
45
42
|
raise
|
|
46
43
|
|
|
47
44
|
async with engine.connect() as connection:
|
|
48
45
|
await connection.run_sync(lambda conn: Base.prepare(conn.engine))
|
|
49
46
|
|
|
50
|
-
|
|
51
|
-
|
|
47
|
+
_async_session_factories[key] = sessionmaker(
|
|
48
|
+
class_=SqlaAsyncSession, bind=engine, expire_on_commit=False
|
|
49
|
+
)
|
|
52
50
|
|
|
53
|
-
|
|
54
|
-
class AsyncSession(SqlaAsyncSession):
|
|
55
|
-
def __new__(cls, request: Request):
|
|
56
|
-
"""Yield the sqlalchmey async session for that request.
|
|
57
|
-
|
|
58
|
-
It is meant to be used as a FastAPI dependency::
|
|
59
|
-
|
|
60
|
-
from fastapi import APIRouter, Depends
|
|
61
|
-
from fastapi_sqla import AsyncSession
|
|
62
|
-
|
|
63
|
-
router = APIRouter()
|
|
64
|
-
|
|
65
|
-
@router.get("/users")
|
|
66
|
-
async def get_users(session: AsyncSession = Depends()):
|
|
67
|
-
pass
|
|
68
|
-
"""
|
|
69
|
-
try:
|
|
70
|
-
return request.scope[_ASYNC_SESSION_KEY]
|
|
71
|
-
except KeyError: # pragma: no cover
|
|
72
|
-
raise Exception(
|
|
73
|
-
"No async session found in request, please ensure you've setup "
|
|
74
|
-
"fastapi_sqla."
|
|
75
|
-
)
|
|
51
|
+
logger.info("engine startup", engine_key=key, async_engine=engine)
|
|
76
52
|
|
|
77
53
|
|
|
78
54
|
@asynccontextmanager
|
|
79
|
-
async def open_session(
|
|
55
|
+
async def open_session(
|
|
56
|
+
key: str = _DEFAULT_SESSION_KEY,
|
|
57
|
+
) -> AsyncGenerator[SqlaAsyncSession, None]:
|
|
80
58
|
"""Context manager to open an async session and properly close it when exiting.
|
|
81
59
|
|
|
82
60
|
If no exception is raised before exiting context, session is committed when exiting
|
|
83
61
|
context. If an exception is raised, session is rollbacked.
|
|
84
62
|
"""
|
|
85
|
-
|
|
63
|
+
try:
|
|
64
|
+
session: SqlaAsyncSession = _async_session_factories[key]()
|
|
65
|
+
except KeyError as exc:
|
|
66
|
+
raise KeyError(
|
|
67
|
+
f"No async session with key '{key}' found, "
|
|
68
|
+
"please ensure you've configured the environment variables for this key."
|
|
69
|
+
) from exc
|
|
70
|
+
|
|
86
71
|
logger.bind(db_async_session=session)
|
|
87
72
|
|
|
88
73
|
try:
|
|
89
74
|
yield session
|
|
90
|
-
await session.commit()
|
|
91
|
-
|
|
92
75
|
except Exception:
|
|
93
|
-
logger.
|
|
76
|
+
logger.warning("context failed, rolling back", exc_info=True)
|
|
94
77
|
await session.rollback()
|
|
95
78
|
raise
|
|
96
79
|
|
|
80
|
+
else:
|
|
81
|
+
try:
|
|
82
|
+
await session.commit()
|
|
83
|
+
except Exception:
|
|
84
|
+
logger.exception("commit failed, rolling back")
|
|
85
|
+
await session.rollback()
|
|
86
|
+
raise
|
|
87
|
+
|
|
97
88
|
finally:
|
|
98
89
|
await session.close()
|
|
99
90
|
|
|
100
91
|
|
|
101
|
-
async def add_session_to_request(
|
|
92
|
+
async def add_session_to_request(
|
|
93
|
+
request: Request, call_next, key: str = _DEFAULT_SESSION_KEY
|
|
94
|
+
):
|
|
102
95
|
"""Middleware which injects a new sqla async session into every request.
|
|
103
96
|
|
|
104
97
|
Handles creation of session, as well as commit, rollback, and closing of session.
|
|
@@ -113,15 +106,71 @@ async def add_session_to_request(request: Request, call_next):
|
|
|
113
106
|
fastapi_sqla.setup(app) # includes middleware
|
|
114
107
|
|
|
115
108
|
@app.get("/users")
|
|
116
|
-
async def get_users(session: fastapi_sqla.AsyncSession
|
|
109
|
+
async def get_users(session: fastapi_sqla.AsyncSession):
|
|
117
110
|
return await session.execute(...) # use your session here
|
|
118
111
|
"""
|
|
119
|
-
async with open_session() as session:
|
|
120
|
-
request.
|
|
112
|
+
async with open_session(key) as session:
|
|
113
|
+
setattr(request.state, f"{_ASYNC_REQUEST_SESSION_KEY}_{key}", session)
|
|
121
114
|
response = await call_next(request)
|
|
115
|
+
|
|
116
|
+
is_dirty = bool(session.dirty or session.deleted or session.new)
|
|
117
|
+
|
|
118
|
+
# try to commit after response, so that we can return a proper 500 response
|
|
119
|
+
# and not raise a true internal server error
|
|
120
|
+
if response.status_code < 400:
|
|
121
|
+
try:
|
|
122
|
+
await session.commit()
|
|
123
|
+
except Exception:
|
|
124
|
+
logger.exception("commit failed, returning http error")
|
|
125
|
+
response = PlainTextResponse(
|
|
126
|
+
content="Internal Server Error", status_code=500
|
|
127
|
+
)
|
|
128
|
+
|
|
122
129
|
if response.status_code >= 400:
|
|
123
130
|
# If ever a route handler returns an http exception, we do not want the
|
|
124
131
|
# session opened by current context manager to commit anything in db.
|
|
132
|
+
if is_dirty:
|
|
133
|
+
# optimistically only log if there were uncommitted changes
|
|
134
|
+
logger.warning(
|
|
135
|
+
"http error, rolling back possibly uncommitted changes",
|
|
136
|
+
status_code=response.status_code,
|
|
137
|
+
)
|
|
138
|
+
# since this is no-op if session is not dirty, we can always call it
|
|
125
139
|
await session.rollback()
|
|
126
140
|
|
|
127
141
|
return response
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class AsyncSessionDependency:
|
|
145
|
+
def __init__(self, key: str = _DEFAULT_SESSION_KEY) -> None:
|
|
146
|
+
self.key = key
|
|
147
|
+
|
|
148
|
+
def __call__(self, request: Request) -> SqlaAsyncSession:
|
|
149
|
+
"""Yield the sqlalchemy async session for that request.
|
|
150
|
+
|
|
151
|
+
It is meant to be used as a FastAPI dependency::
|
|
152
|
+
|
|
153
|
+
from fastapi import APIRouter, Depends
|
|
154
|
+
from fastapi_sqla import SqlaAsyncSession, AsyncSessionDependency
|
|
155
|
+
|
|
156
|
+
router = APIRouter()
|
|
157
|
+
|
|
158
|
+
@router.get("/users")
|
|
159
|
+
async def get_users(
|
|
160
|
+
session: SqlaAsyncSession = Depends(AsyncSessionDependency())
|
|
161
|
+
):
|
|
162
|
+
pass
|
|
163
|
+
"""
|
|
164
|
+
try:
|
|
165
|
+
return getattr(request.state, f"{_ASYNC_REQUEST_SESSION_KEY}_{self.key}")
|
|
166
|
+
except AttributeError:
|
|
167
|
+
logger.exception(
|
|
168
|
+
f"No async session with key '{self.key}' found in request, "
|
|
169
|
+
"please ensure you've setup fastapi_sqla.",
|
|
170
|
+
session_key=self.key,
|
|
171
|
+
)
|
|
172
|
+
raise
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
default_async_session_dep = AsyncSessionDependency()
|
|
176
|
+
AsyncSession = Annotated[SqlaAsyncSession, Depends(default_async_session_dep)]
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from os import environ
|
|
2
|
+
from typing import Any
|
|
2
3
|
|
|
3
4
|
try:
|
|
4
5
|
import boto3
|
|
@@ -11,9 +12,10 @@ except ImportError as err:
|
|
|
11
12
|
from functools import lru_cache
|
|
12
13
|
|
|
13
14
|
from sqlalchemy import event
|
|
15
|
+
from sqlalchemy.engine import Engine
|
|
14
16
|
|
|
15
17
|
|
|
16
|
-
def setup(engine):
|
|
18
|
+
def setup(engine: Engine):
|
|
17
19
|
lc_environ = {k.lower(): v for k, v in environ.items()}
|
|
18
20
|
aws_rds_iam_enabled = lc_environ.get("fastapi_sqla_aws_rds_iam_enabled") == "true"
|
|
19
21
|
|
|
@@ -30,13 +32,13 @@ def get_rds_client():
|
|
|
30
32
|
return session.client("rds")
|
|
31
33
|
|
|
32
34
|
|
|
33
|
-
def get_authentication_token(host, port, user):
|
|
35
|
+
def get_authentication_token(host: str, port: int, user: str):
|
|
34
36
|
client = get_rds_client()
|
|
35
37
|
token = client.generate_db_auth_token(DBHostname=host, Port=port, DBUsername=user)
|
|
36
38
|
return token
|
|
37
39
|
|
|
38
40
|
|
|
39
|
-
def set_connection_token(dialect, conn_rec, cargs, cparams):
|
|
41
|
+
def set_connection_token(dialect, conn_rec, cargs, cparams: dict[str, Any]):
|
|
40
42
|
cparams["password"] = get_authentication_token(
|
|
41
43
|
host=cparams["host"], port=cparams.get("port", 5432), user=cparams["user"]
|
|
42
44
|
)
|
fastapi_sqla/base.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
1
|
+
import functools
|
|
1
2
|
import os
|
|
3
|
+
import re
|
|
2
4
|
|
|
3
5
|
from fastapi import FastAPI
|
|
4
6
|
from sqlalchemy.engine import Engine
|
|
@@ -15,19 +17,41 @@ except ImportError as err: # pragma: no cover
|
|
|
15
17
|
asyncio_support_err = str(err)
|
|
16
18
|
|
|
17
19
|
|
|
18
|
-
|
|
19
|
-
engine = sqla.new_engine()
|
|
20
|
-
|
|
21
|
-
if not is_async_dialect(engine):
|
|
22
|
-
app.add_event_handler("startup", sqla.startup)
|
|
23
|
-
app.middleware("http")(sqla.add_session_to_request)
|
|
20
|
+
_ENGINE_KEYS_REGEX = re.compile(r"fastapi_sqla__(?!_)(.+)(?<!_)__(?!_).+")
|
|
24
21
|
|
|
25
|
-
has_async_config = "async_sqlalchemy_url" in os.environ or is_async_dialect(engine)
|
|
26
|
-
if has_async_config:
|
|
27
|
-
assert has_asyncio_support, asyncio_support_err
|
|
28
|
-
app.add_event_handler("startup", async_sqla.startup)
|
|
29
|
-
app.middleware("http")(async_sqla.add_session_to_request)
|
|
30
22
|
|
|
31
|
-
|
|
32
|
-
|
|
23
|
+
def setup(app: FastAPI):
|
|
24
|
+
engine_keys = _get_engine_keys()
|
|
25
|
+
engines = {key: sqla.new_engine(key) for key in engine_keys}
|
|
26
|
+
for key, engine in engines.items():
|
|
27
|
+
if not _is_async_dialect(engine):
|
|
28
|
+
app.add_event_handler("startup", functools.partial(sqla.startup, key=key))
|
|
29
|
+
app.middleware("http")(
|
|
30
|
+
functools.partial(sqla.add_session_to_request, key=key)
|
|
31
|
+
)
|
|
32
|
+
else:
|
|
33
|
+
assert has_asyncio_support, asyncio_support_err
|
|
34
|
+
app.add_event_handler(
|
|
35
|
+
"startup", functools.partial(async_sqla.startup, key=key)
|
|
36
|
+
)
|
|
37
|
+
app.middleware("http")(
|
|
38
|
+
functools.partial(async_sqla.add_session_to_request, key=key)
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _get_engine_keys() -> set[str]:
|
|
43
|
+
keys = {sqla._DEFAULT_SESSION_KEY}
|
|
44
|
+
|
|
45
|
+
lowercase_environ = {k.lower(): v for k, v in os.environ.items()}
|
|
46
|
+
for env_var in lowercase_environ:
|
|
47
|
+
match = _ENGINE_KEYS_REGEX.search(env_var)
|
|
48
|
+
if not match:
|
|
49
|
+
continue
|
|
50
|
+
|
|
51
|
+
keys.add(match.group(1))
|
|
52
|
+
|
|
53
|
+
return keys
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _is_async_dialect(engine: Engine):
|
|
33
57
|
return engine.dialect.is_async if hasattr(engine.dialect, "is_async") else False
|
fastapi_sqla/pagination.py
CHANGED
|
@@ -1,24 +1,24 @@
|
|
|
1
1
|
import math
|
|
2
2
|
from collections.abc import Callable
|
|
3
3
|
from functools import singledispatch
|
|
4
|
-
from typing import Iterator, Optional, Union, cast
|
|
4
|
+
from typing import Annotated, Iterator, Optional, Union, cast
|
|
5
5
|
|
|
6
6
|
from fastapi import Depends, Query
|
|
7
7
|
from sqlalchemy.orm import Query as LegacyQuery
|
|
8
8
|
from sqlalchemy.sql import Select, func, select
|
|
9
9
|
|
|
10
10
|
from fastapi_sqla.models import Page
|
|
11
|
-
from fastapi_sqla.sqla import
|
|
11
|
+
from fastapi_sqla.sqla import _DEFAULT_SESSION_KEY, SessionDependency, SqlaSession
|
|
12
12
|
|
|
13
13
|
DbQuery = Union[LegacyQuery, Select]
|
|
14
14
|
QueryCountDependency = Callable[..., int]
|
|
15
15
|
PaginateSignature = Callable[[DbQuery, Optional[bool]], Page]
|
|
16
|
-
DefaultDependency = Callable[[
|
|
17
|
-
WithQueryCountDependency = Callable[[
|
|
16
|
+
DefaultDependency = Callable[[SqlaSession, int, int], PaginateSignature]
|
|
17
|
+
WithQueryCountDependency = Callable[[SqlaSession, int, int, int], PaginateSignature]
|
|
18
18
|
PaginateDependency = Union[DefaultDependency, WithQueryCountDependency]
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
def default_query_count(session:
|
|
21
|
+
def default_query_count(session: SqlaSession, query: DbQuery) -> int:
|
|
22
22
|
"""Default function used to count items returned by a query.
|
|
23
23
|
|
|
24
24
|
It is slower than a manually written query could be: It runs the query in a
|
|
@@ -46,7 +46,7 @@ def default_query_count(session: Session, query: DbQuery) -> int:
|
|
|
46
46
|
@singledispatch
|
|
47
47
|
def paginate_query(
|
|
48
48
|
query: DbQuery,
|
|
49
|
-
session:
|
|
49
|
+
session: SqlaSession,
|
|
50
50
|
total_items: int,
|
|
51
51
|
offset: int,
|
|
52
52
|
limit: int,
|
|
@@ -59,7 +59,7 @@ def paginate_query(
|
|
|
59
59
|
@paginate_query.register
|
|
60
60
|
def _paginate_legacy(
|
|
61
61
|
query: LegacyQuery,
|
|
62
|
-
session:
|
|
62
|
+
session: SqlaSession,
|
|
63
63
|
total_items: int,
|
|
64
64
|
offset: int,
|
|
65
65
|
limit: int,
|
|
@@ -81,7 +81,7 @@ def _paginate_legacy(
|
|
|
81
81
|
@paginate_query.register
|
|
82
82
|
def _paginate(
|
|
83
83
|
query: Select,
|
|
84
|
-
session:
|
|
84
|
+
session: SqlaSession,
|
|
85
85
|
total_items: int,
|
|
86
86
|
offset: int,
|
|
87
87
|
limit: int,
|
|
@@ -107,12 +107,13 @@ def _paginate(
|
|
|
107
107
|
|
|
108
108
|
|
|
109
109
|
def Pagination(
|
|
110
|
+
session_key: str = _DEFAULT_SESSION_KEY,
|
|
110
111
|
min_page_size: int = 10,
|
|
111
112
|
max_page_size: int = 100,
|
|
112
113
|
query_count: Union[QueryCountDependency, None] = None,
|
|
113
114
|
) -> PaginateDependency:
|
|
114
115
|
def default_dependency(
|
|
115
|
-
session:
|
|
116
|
+
session: SqlaSession = Depends(SessionDependency(key=session_key)),
|
|
116
117
|
offset: int = Query(0, ge=0),
|
|
117
118
|
limit: int = Query(min_page_size, ge=1, le=max_page_size),
|
|
118
119
|
) -> PaginateSignature:
|
|
@@ -125,7 +126,7 @@ def Pagination(
|
|
|
125
126
|
return paginate
|
|
126
127
|
|
|
127
128
|
def with_query_count_dependency(
|
|
128
|
-
session:
|
|
129
|
+
session: SqlaSession = Depends(SessionDependency(key=session_key)),
|
|
129
130
|
offset: int = Query(0, ge=0),
|
|
130
131
|
limit: int = Query(min_page_size, ge=1, le=max_page_size),
|
|
131
132
|
total_items: int = Depends(query_count),
|
|
@@ -143,4 +144,4 @@ def Pagination(
|
|
|
143
144
|
return default_dependency
|
|
144
145
|
|
|
145
146
|
|
|
146
|
-
Paginate
|
|
147
|
+
Paginate = Annotated[PaginateSignature, Depends(Pagination())]
|
fastapi_sqla/sqla.py
CHANGED
|
@@ -2,10 +2,10 @@ import asyncio
|
|
|
2
2
|
import os
|
|
3
3
|
from collections.abc import Generator
|
|
4
4
|
from contextlib import contextmanager
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import Annotated
|
|
6
6
|
|
|
7
7
|
import structlog
|
|
8
|
-
from fastapi import Request
|
|
8
|
+
from fastapi import Depends, Request
|
|
9
9
|
from fastapi.concurrency import contextmanager_in_threadpool
|
|
10
10
|
from fastapi.responses import PlainTextResponse
|
|
11
11
|
from sqlalchemy import engine_from_config, text
|
|
@@ -18,75 +18,62 @@ from fastapi_sqla.models import Base
|
|
|
18
18
|
|
|
19
19
|
logger = structlog.get_logger(__name__)
|
|
20
20
|
|
|
21
|
-
|
|
21
|
+
_DEFAULT_SESSION_KEY = "default"
|
|
22
|
+
_REQUEST_SESSION_KEY = "fastapi_sqla_session"
|
|
23
|
+
_session_factories: dict[str, sessionmaker] = {}
|
|
22
24
|
|
|
23
|
-
_Session = sessionmaker()
|
|
24
25
|
|
|
26
|
+
def new_engine(key: str = _DEFAULT_SESSION_KEY) -> Engine:
|
|
27
|
+
envvar_prefix = "sqlalchemy_"
|
|
28
|
+
if key != _DEFAULT_SESSION_KEY:
|
|
29
|
+
envvar_prefix = f"fastapi_sqla__{key}__{envvar_prefix}"
|
|
25
30
|
|
|
26
|
-
|
|
27
|
-
envvar_prefix
|
|
28
|
-
lowercase_environ = {
|
|
29
|
-
k.lower(): v for k, v in os.environ.items() if k.lower() != "sqlalchemy_warn_20"
|
|
30
|
-
}
|
|
31
|
+
lowercase_environ = {k.lower(): v for k, v in os.environ.items()}
|
|
32
|
+
lowercase_environ.pop(f"{envvar_prefix}warn_20", None)
|
|
31
33
|
return engine_from_config(lowercase_environ, prefix=envvar_prefix)
|
|
32
34
|
|
|
33
35
|
|
|
34
|
-
def startup():
|
|
35
|
-
engine = new_engine()
|
|
36
|
+
def startup(key: str = _DEFAULT_SESSION_KEY):
|
|
37
|
+
engine = new_engine(key)
|
|
36
38
|
aws_rds_iam_support.setup(engine.engine)
|
|
37
39
|
aws_aurora_support.setup(engine.engine)
|
|
38
40
|
|
|
39
|
-
# Fail early
|
|
41
|
+
# Fail early
|
|
40
42
|
try:
|
|
41
43
|
with engine.connect() as connection:
|
|
42
44
|
connection.execute(text("select 'OK'"))
|
|
43
45
|
except Exception:
|
|
44
46
|
logger.critical(
|
|
45
|
-
"
|
|
47
|
+
f"Failed querying db for key '{key}': "
|
|
48
|
+
"are the the environment variables correctly configured for this key?"
|
|
46
49
|
)
|
|
47
50
|
raise
|
|
48
51
|
|
|
49
52
|
Base.prepare(engine)
|
|
50
|
-
|
|
51
|
-
logger.info("startup", engine=engine)
|
|
53
|
+
_session_factories[key] = sessionmaker(bind=engine)
|
|
52
54
|
|
|
53
|
-
|
|
54
|
-
class Session(SqlaSession):
|
|
55
|
-
def __new__(cls, request: Request):
|
|
56
|
-
"""Yield the sqlalchmey session for that request.
|
|
57
|
-
|
|
58
|
-
It is meant to be used as a FastAPI dependency::
|
|
59
|
-
|
|
60
|
-
from fastapi import APIRouter, Depends
|
|
61
|
-
from fastapi_sqla import Session
|
|
62
|
-
|
|
63
|
-
router = APIRouter()
|
|
64
|
-
|
|
65
|
-
@router.get("/users")
|
|
66
|
-
def get_users(session: Session = Depends()):
|
|
67
|
-
pass
|
|
68
|
-
"""
|
|
69
|
-
try:
|
|
70
|
-
return request.scope[_SESSION_KEY]
|
|
71
|
-
except KeyError: # pragma: no cover
|
|
72
|
-
raise Exception(
|
|
73
|
-
"No session found in request, please ensure you've setup fastapi_sqla."
|
|
74
|
-
)
|
|
55
|
+
logger.info("engine startup", engine_key=key, engine=engine)
|
|
75
56
|
|
|
76
57
|
|
|
77
58
|
@contextmanager
|
|
78
|
-
def open_session() -> Generator[SqlaSession, None, None]:
|
|
59
|
+
def open_session(key: str = _DEFAULT_SESSION_KEY) -> Generator[SqlaSession, None, None]:
|
|
79
60
|
"""Context manager that opens a session and properly closes session when exiting.
|
|
80
61
|
|
|
81
62
|
If no exception is raised before exiting context, session is committed when exiting
|
|
82
63
|
context. If an exception is raised, session is rollbacked.
|
|
83
64
|
"""
|
|
84
|
-
|
|
65
|
+
try:
|
|
66
|
+
session: SqlaSession = _session_factories[key]()
|
|
67
|
+
except KeyError as exc:
|
|
68
|
+
raise KeyError(
|
|
69
|
+
f"No session with key '{key}' found, "
|
|
70
|
+
"please ensure you've configured the environment variables for this key."
|
|
71
|
+
) from exc
|
|
72
|
+
|
|
85
73
|
logger.bind(db_session=session)
|
|
86
74
|
|
|
87
75
|
try:
|
|
88
76
|
yield session
|
|
89
|
-
|
|
90
77
|
except Exception:
|
|
91
78
|
logger.warning("context failed, rolling back", exc_info=True)
|
|
92
79
|
session.rollback()
|
|
@@ -104,7 +91,9 @@ def open_session() -> Generator[SqlaSession, None, None]:
|
|
|
104
91
|
session.close()
|
|
105
92
|
|
|
106
93
|
|
|
107
|
-
async def add_session_to_request(
|
|
94
|
+
async def add_session_to_request(
|
|
95
|
+
request: Request, call_next, key: str = _DEFAULT_SESSION_KEY
|
|
96
|
+
):
|
|
108
97
|
"""Middleware which injects a new sqla session into every request.
|
|
109
98
|
|
|
110
99
|
Handles creation of session, as well as commit, rollback, and closing of session.
|
|
@@ -119,11 +108,11 @@ async def add_session_to_request(request: Request, call_next):
|
|
|
119
108
|
fastapi_sqla.setup(app) # includes middleware
|
|
120
109
|
|
|
121
110
|
@app.get("/users")
|
|
122
|
-
def get_users(session: fastapi_sqla.Session
|
|
111
|
+
def get_users(session: fastapi_sqla.Session):
|
|
123
112
|
return session.execute(...) # use your session here
|
|
124
113
|
"""
|
|
125
|
-
async with contextmanager_in_threadpool(open_session()) as session:
|
|
126
|
-
request.
|
|
114
|
+
async with contextmanager_in_threadpool(open_session(key)) as session:
|
|
115
|
+
setattr(request.state, f"{_REQUEST_SESSION_KEY}_{key}", session)
|
|
127
116
|
|
|
128
117
|
response = await call_next(request)
|
|
129
118
|
|
|
@@ -155,3 +144,36 @@ async def add_session_to_request(request: Request, call_next):
|
|
|
155
144
|
await loop.run_in_executor(None, session.rollback)
|
|
156
145
|
|
|
157
146
|
return response
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class SessionDependency:
|
|
150
|
+
def __init__(self, key: str = _DEFAULT_SESSION_KEY) -> None:
|
|
151
|
+
self.key = key
|
|
152
|
+
|
|
153
|
+
def __call__(self, request: Request) -> SqlaSession:
|
|
154
|
+
"""Yield the sqlalchemy session for that request.
|
|
155
|
+
|
|
156
|
+
It is meant to be used as a FastAPI dependency::
|
|
157
|
+
|
|
158
|
+
from fastapi import APIRouter, Depends
|
|
159
|
+
from fastapi_sqla import SqlaSession, SessionDependency
|
|
160
|
+
|
|
161
|
+
router = APIRouter()
|
|
162
|
+
|
|
163
|
+
@router.get("/users")
|
|
164
|
+
def get_users(session: SqlaSession = Depends(SessionDependency())):
|
|
165
|
+
pass
|
|
166
|
+
"""
|
|
167
|
+
try:
|
|
168
|
+
return getattr(request.state, f"{_REQUEST_SESSION_KEY}_{self.key}")
|
|
169
|
+
except AttributeError:
|
|
170
|
+
logger.exception(
|
|
171
|
+
f"No session with key '{self.key}' found in request, "
|
|
172
|
+
"please ensure you've setup fastapi_sqla.",
|
|
173
|
+
session_key=self.key,
|
|
174
|
+
)
|
|
175
|
+
raise
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
default_session_dep = SessionDependency()
|
|
179
|
+
Session = Annotated[SqlaSession, Depends(default_session_dep)]
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: fastapi-sqla
|
|
3
|
-
Version:
|
|
3
|
+
Version: 3.0.0
|
|
4
4
|
Summary: SQLAlchemy extension for FastAPI with support for pagination, asyncio, and pytest, ready for production.
|
|
5
5
|
Home-page: https://github.com/dialoguemd/fastapi-sqla
|
|
6
6
|
License: MIT
|
|
7
7
|
Keywords: FastAPI,SQLAlchemy,asyncio,pytest,alembic
|
|
8
8
|
Author: Hadrien David
|
|
9
9
|
Author-email: hadrien.david@dialogue.co
|
|
10
|
-
Requires-Python: >=3.
|
|
10
|
+
Requires-Python: >=3.9,<4.0
|
|
11
11
|
Classifier: Development Status :: 5 - Production/Stable
|
|
12
12
|
Classifier: Environment :: Web Environment
|
|
13
13
|
Classifier: Framework :: AsyncIO
|
|
@@ -19,8 +19,6 @@ Classifier: License :: OSI Approved :: MIT License
|
|
|
19
19
|
Classifier: Operating System :: OS Independent
|
|
20
20
|
Classifier: Programming Language :: Python
|
|
21
21
|
Classifier: Programming Language :: Python :: 3
|
|
22
|
-
Classifier: Programming Language :: Python :: 3.7
|
|
23
|
-
Classifier: Programming Language :: Python :: 3.8
|
|
24
22
|
Classifier: Programming Language :: Python :: 3.9
|
|
25
23
|
Classifier: Programming Language :: Python :: 3.10
|
|
26
24
|
Classifier: Programming Language :: Python :: 3.11
|
|
@@ -43,7 +41,7 @@ Requires-Dist: asgi_lifespan (>=1.0.1,<2.0.0) ; extra == "tests"
|
|
|
43
41
|
Requires-Dist: asyncpg (>=0.28.0,<0.29.0) ; extra == "asyncpg"
|
|
44
42
|
Requires-Dist: black (>=22.8.0,<23.0.0) ; extra == "tests"
|
|
45
43
|
Requires-Dist: boto3 (>=1.24.74,<2.0.0) ; extra == "aws-rds-iam"
|
|
46
|
-
Requires-Dist: fastapi (>=0.
|
|
44
|
+
Requires-Dist: fastapi (>=0.95.1)
|
|
47
45
|
Requires-Dist: greenlet (>=1.1.3,<2.0.0) ; extra == "tests"
|
|
48
46
|
Requires-Dist: httpx (>=0.23.0,<0.24.0) ; extra == "tests"
|
|
49
47
|
Requires-Dist: isort (>=5.5.3,<6.0.0) ; extra == "tests"
|
|
@@ -90,7 +88,7 @@ unique `email`:
|
|
|
90
88
|
|
|
91
89
|
```python
|
|
92
90
|
# main.py
|
|
93
|
-
from fastapi import
|
|
91
|
+
from fastapi import FastAPI, HTTPException
|
|
94
92
|
from fastapi_sqla import Base, Item, Page, Paginate, Session, setup
|
|
95
93
|
from pydantic import BaseModel, EmailStr
|
|
96
94
|
from sqlalchemy import select
|
|
@@ -118,12 +116,12 @@ class UserModel(UserIn):
|
|
|
118
116
|
|
|
119
117
|
|
|
120
118
|
@app.get("/users", response_model=Page[UserModel])
|
|
121
|
-
def list_users(paginate: Paginate
|
|
119
|
+
def list_users(paginate: Paginate):
|
|
122
120
|
return paginate(select(User))
|
|
123
121
|
|
|
124
122
|
|
|
125
123
|
@app.get("/users/{user_id}", response_model=Item[UserModel])
|
|
126
|
-
def get_user(user_id: int, session: Session
|
|
124
|
+
def get_user(user_id: int, session: Session):
|
|
127
125
|
user = session.get(User, user_id)
|
|
128
126
|
if user is None:
|
|
129
127
|
raise HTTPException(404)
|
|
@@ -131,7 +129,7 @@ def get_user(user_id: int, session: Session = Depends()):
|
|
|
131
129
|
|
|
132
130
|
|
|
133
131
|
@app.post("/users", response_model=Item[UserModel])
|
|
134
|
-
def create_user(new_user: UserIn, session: Session
|
|
132
|
+
def create_user(new_user: UserIn, session: Session):
|
|
135
133
|
user = User(**new_user.model_dump())
|
|
136
134
|
session.add(user)
|
|
137
135
|
try:
|
|
@@ -173,6 +171,21 @@ The only required key is `sqlalchemy_url`, which provides the database URL, exam
|
|
|
173
171
|
export sqlalchemy_url=postgresql://postgres@localhost
|
|
174
172
|
```
|
|
175
173
|
|
|
174
|
+
### Multi-session support
|
|
175
|
+
|
|
176
|
+
In order to configure multiple sessions for the application,
|
|
177
|
+
set the environment variables with this prefix format: `fastapi_sqla__MY_KEY__`.
|
|
178
|
+
|
|
179
|
+
As with the default session, each matching key (after the prefix is stripped)
|
|
180
|
+
is treated as though it were the corresponding keyword argument to [`sqlalchemy.create_engine`]
|
|
181
|
+
call.
|
|
182
|
+
|
|
183
|
+
For example, to configure a session with the `read_only` key:
|
|
184
|
+
|
|
185
|
+
```bash
|
|
186
|
+
export fastapi_sqla__read_only__sqlalchemy_url=postgresql://postgres@localhost
|
|
187
|
+
```
|
|
188
|
+
|
|
176
189
|
### `asyncio` support using [`asyncpg`]
|
|
177
190
|
|
|
178
191
|
SQLAlchemy `>= 1.4` supports `asyncio`.
|
|
@@ -182,7 +195,7 @@ To enable `asyncio` support against a Postgres DB, install `asyncpg`:
|
|
|
182
195
|
pip install asyncpg
|
|
183
196
|
```
|
|
184
197
|
|
|
185
|
-
And define environment variable `sqlalchemy_url` with `postgres+asyncpg` scheme:
|
|
198
|
+
And define the environment variable `sqlalchemy_url` with `postgres+asyncpg` scheme:
|
|
186
199
|
|
|
187
200
|
```bash
|
|
188
201
|
export sqlalchemy_url=postgresql+asyncpg://postgres@localhost
|
|
@@ -216,23 +229,71 @@ class Entity(Base):
|
|
|
216
229
|
|
|
217
230
|
Use [FastAPI dependency injection] to get a session as a parameter of a path operation
|
|
218
231
|
function.
|
|
219
|
-
|
|
232
|
+
|
|
233
|
+
The SQLAlchemy session is committed before the response is returned or rollbacked if any
|
|
220
234
|
exception occurred:
|
|
221
235
|
|
|
222
236
|
```python
|
|
237
|
+
from fastapi import APIRouter
|
|
238
|
+
from fastapi_sqla import AsyncSession, Session
|
|
239
|
+
|
|
240
|
+
router = APIRouter()
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
@router.get("/example")
|
|
244
|
+
def example(session: Session):
|
|
245
|
+
return session.execute("SELECT now()").scalar()
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@router.get("/async_example")
|
|
249
|
+
async def async_example(session: AsyncSession):
|
|
250
|
+
return await session.scalar("SELECT now()")
|
|
251
|
+
```
|
|
252
|
+
|
|
253
|
+
In order to get a session configured with a custom key:
|
|
254
|
+
|
|
255
|
+
```python
|
|
256
|
+
from typing import Annotated
|
|
257
|
+
|
|
223
258
|
from fastapi import APIRouter, Depends
|
|
224
|
-
from fastapi_sqla import
|
|
259
|
+
from fastapi_sqla import (
|
|
260
|
+
AsyncSessionDependency,
|
|
261
|
+
SessionDependency,
|
|
262
|
+
SqlaAsyncSession,
|
|
263
|
+
SqlaSession,
|
|
264
|
+
)
|
|
225
265
|
|
|
226
266
|
router = APIRouter()
|
|
227
267
|
|
|
228
268
|
|
|
269
|
+
# Preferred
|
|
270
|
+
|
|
271
|
+
ReadOnlySession = Annotated[SqlaSession, Depends(SessionDependency(key="read_only"))]
|
|
272
|
+
AsyncReadOnlySession = Annotated[
|
|
273
|
+
SqlaAsyncSession, Depends(AsyncSessionDependency(key="read_only"))
|
|
274
|
+
]
|
|
275
|
+
|
|
229
276
|
@router.get("/example")
|
|
230
|
-
def example(session:
|
|
277
|
+
def example(session: ReadOnlySession):
|
|
231
278
|
return session.execute("SELECT now()").scalar()
|
|
232
279
|
|
|
233
280
|
|
|
234
281
|
@router.get("/async_example")
|
|
235
|
-
async def async_example(session:
|
|
282
|
+
async def async_example(session: AsyncReadOnlySession):
|
|
283
|
+
return await session.scalar("SELECT now()")
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
# Alternative
|
|
287
|
+
|
|
288
|
+
@router.get("/example/alt")
|
|
289
|
+
def example_alt(session: SqlaSession = Depends(SessionDependency(key="read_only"))):
|
|
290
|
+
return session.execute("SELECT now()").scalar()
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
@router.get("/async_example/alt")
|
|
294
|
+
async def async_example_alt(
|
|
295
|
+
session: SqlaAsyncSession = Depends(AsyncSessionDependency(key="read_only")),
|
|
296
|
+
):
|
|
236
297
|
return await session.scalar("SELECT now()")
|
|
237
298
|
```
|
|
238
299
|
|
|
@@ -240,12 +301,12 @@ async def async_example(session: AsyncSession = Depends()):
|
|
|
240
301
|
|
|
241
302
|
When needing a session outside of a path operation, like when using
|
|
242
303
|
[FastAPI background tasks], use `fastapi_sqla.open_session` context manager.
|
|
243
|
-
SQLAlchemy session is committed when exiting context or rollbacked if any exception
|
|
304
|
+
The SQLAlchemy session is committed when exiting context or rollbacked if any exception
|
|
244
305
|
occurred:
|
|
245
306
|
|
|
246
307
|
```python
|
|
247
308
|
from fastapi import APIRouter, BackgroundTasks
|
|
248
|
-
from fastapi_sqla import
|
|
309
|
+
from fastapi_sqla import open_async_session, open_session
|
|
249
310
|
|
|
250
311
|
router = APIRouter()
|
|
251
312
|
|
|
@@ -260,16 +321,23 @@ def run_bg():
|
|
|
260
321
|
with open_session() as session:
|
|
261
322
|
session.execute("SELECT now()").scalar()
|
|
262
323
|
|
|
324
|
+
def run_bg_with_key():
|
|
325
|
+
with open_session(key="read_only") as session:
|
|
326
|
+
session.execute("SELECT now()").scalar()
|
|
263
327
|
|
|
264
328
|
async def run_async_bg():
|
|
265
329
|
async with open_async_session() as session:
|
|
266
330
|
await session.scalar("SELECT now()")
|
|
331
|
+
|
|
332
|
+
async def run_async_bg_with_key():
|
|
333
|
+
async with open_async_session(key="read_only") as session:
|
|
334
|
+
await session.scalar("SELECT now()")
|
|
267
335
|
```
|
|
268
336
|
|
|
269
337
|
## Pagination
|
|
270
338
|
|
|
271
339
|
```python
|
|
272
|
-
from fastapi import APIRouter
|
|
340
|
+
from fastapi import APIRouter
|
|
273
341
|
from fastapi_sqla import Base, Page, Paginate
|
|
274
342
|
from pydantic import BaseModel
|
|
275
343
|
from sqlalchemy import select
|
|
@@ -290,7 +358,7 @@ class UserModel(BaseModel):
|
|
|
290
358
|
|
|
291
359
|
|
|
292
360
|
@router.get("/users", response_model=Page[UserModel])
|
|
293
|
-
def all_users(paginate: Paginate
|
|
361
|
+
def all_users(paginate: Paginate):
|
|
294
362
|
return paginate(select(User))
|
|
295
363
|
```
|
|
296
364
|
|
|
@@ -327,7 +395,7 @@ To paginate a query which doesn't return [scalars], specify `scalars=False` when
|
|
|
327
395
|
`paginate`:
|
|
328
396
|
|
|
329
397
|
```python
|
|
330
|
-
from fastapi import APIRouter
|
|
398
|
+
from fastapi import APIRouter
|
|
331
399
|
from fastapi_sqla import Base, Page, Paginate
|
|
332
400
|
from pydantic import BaseModel
|
|
333
401
|
from sqlalchemy import func, select
|
|
@@ -352,7 +420,7 @@ class UserModel(BaseModel):
|
|
|
352
420
|
|
|
353
421
|
|
|
354
422
|
@router.get("/users", response_model=Page[UserModel])
|
|
355
|
-
def all_users(paginate: Paginate
|
|
423
|
+
def all_users(paginate: Paginate):
|
|
356
424
|
query = (
|
|
357
425
|
select(User.id, User.name, func.count(Note.id).label("notes_count"))
|
|
358
426
|
.join(Note)
|
|
@@ -388,15 +456,15 @@ class UserModel(BaseModel):
|
|
|
388
456
|
name: str
|
|
389
457
|
|
|
390
458
|
|
|
391
|
-
def query_count(session: Session
|
|
459
|
+
def query_count(session: Session) -> int:
|
|
392
460
|
return session.execute(select(func.count()).select_from(User)).scalar()
|
|
393
461
|
|
|
394
462
|
|
|
395
|
-
|
|
463
|
+
CustomPaginate = Pagination(min_page_size=5, max_page_size=500, query_count=query_count)
|
|
396
464
|
|
|
397
465
|
|
|
398
466
|
@router.get("/users", response_model=Page[UserModel])
|
|
399
|
-
def all_users(paginate:
|
|
467
|
+
def all_users(paginate: CustomPaginate = Depends()):
|
|
400
468
|
return paginate(select(User))
|
|
401
469
|
```
|
|
402
470
|
|
|
@@ -405,7 +473,7 @@ def all_users(paginate: Paginate = Depends()):
|
|
|
405
473
|
When using the asyncio support, use the `AsyncPaginate` dependency:
|
|
406
474
|
|
|
407
475
|
```python
|
|
408
|
-
from fastapi import APIRouter
|
|
476
|
+
from fastapi import APIRouter
|
|
409
477
|
from fastapi_sqla import Base, Page, AsyncPaginate
|
|
410
478
|
from pydantic import BaseModel
|
|
411
479
|
from sqlalchemy import select
|
|
@@ -426,7 +494,7 @@ class UserModel(BaseModel):
|
|
|
426
494
|
|
|
427
495
|
|
|
428
496
|
@router.get("/users", response_model=Page[UserModel])
|
|
429
|
-
async def all_users(paginate: AsyncPaginate
|
|
497
|
+
async def all_users(paginate: AsyncPaginate):
|
|
430
498
|
return await paginate(select(User))
|
|
431
499
|
```
|
|
432
500
|
|
|
@@ -450,16 +518,85 @@ class UserModel(BaseModel):
|
|
|
450
518
|
name: str
|
|
451
519
|
|
|
452
520
|
|
|
453
|
-
async def query_count(session: AsyncSession
|
|
521
|
+
async def query_count(session: AsyncSession) -> int:
|
|
454
522
|
result = await session.execute(select(func.count()).select_from(User))
|
|
455
523
|
return result.scalar()
|
|
456
524
|
|
|
457
525
|
|
|
458
|
-
|
|
526
|
+
CustomPaginate = AsyncPagination(min_page_size=5, max_page_size=500, query_count=query_count)
|
|
459
527
|
|
|
460
528
|
|
|
461
529
|
@router.get("/users", response_model=Page[UserModel])
|
|
462
|
-
def all_users(paginate: CustomPaginate = Depends()):
|
|
530
|
+
async def all_users(paginate: CustomPaginate = Depends()):
|
|
531
|
+
return await paginate(select(User))
|
|
532
|
+
```
|
|
533
|
+
|
|
534
|
+
### Multi-session support
|
|
535
|
+
|
|
536
|
+
Pagination supports multiple sessions as well. To paginate using a session
|
|
537
|
+
configured with a custom key:
|
|
538
|
+
|
|
539
|
+
```python
|
|
540
|
+
from typing import Annotated
|
|
541
|
+
|
|
542
|
+
from fastapi import APIRouter, Depends
|
|
543
|
+
from fastapi_sqla import (
|
|
544
|
+
AsyncPaginateSignature,
|
|
545
|
+
AsyncPagination,
|
|
546
|
+
Base,
|
|
547
|
+
Page,
|
|
548
|
+
PaginateSignature,
|
|
549
|
+
Pagination,
|
|
550
|
+
)
|
|
551
|
+
from pydantic import BaseModel
|
|
552
|
+
from sqlalchemy import func, select
|
|
553
|
+
|
|
554
|
+
router = APIRouter()
|
|
555
|
+
|
|
556
|
+
|
|
557
|
+
class User(Base):
|
|
558
|
+
__tablename__ = "user"
|
|
559
|
+
|
|
560
|
+
|
|
561
|
+
class UserModel(BaseModel):
|
|
562
|
+
id: int
|
|
563
|
+
name: str
|
|
564
|
+
|
|
565
|
+
|
|
566
|
+
# Preferred
|
|
567
|
+
|
|
568
|
+
ReadOnlyPaginate = Annotated[
|
|
569
|
+
PaginateSignature, Depends(Pagination(session_key="read_only"))
|
|
570
|
+
]
|
|
571
|
+
AsyncReadOnlyPaginate = Annotated[
|
|
572
|
+
AsyncPaginateSignature, Depends(AsyncPagination(session_key="read_only"))
|
|
573
|
+
]
|
|
574
|
+
|
|
575
|
+
@router.get("/users", response_model=Page[UserModel])
|
|
576
|
+
def all_users(paginate: ReadOnlyPaginate):
|
|
577
|
+
return paginate(select(User))
|
|
578
|
+
|
|
579
|
+
@router.get("/async_users", response_model=Page[UserModel])
|
|
580
|
+
async def async_all_users(paginate: AsyncReadOnlyPaginate):
|
|
581
|
+
return await paginate(select(User))
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
# Alternative
|
|
585
|
+
|
|
586
|
+
@router.get("/users/alt", response_model=Page[UserModel])
|
|
587
|
+
def all_users_alt(
|
|
588
|
+
paginate: PaginateSignature = Depends(
|
|
589
|
+
Pagination(session_key="read_only")
|
|
590
|
+
),
|
|
591
|
+
):
|
|
592
|
+
return paginate(select(User))
|
|
593
|
+
|
|
594
|
+
@router.get("/async_users/alt", response_model=Page[UserModel])
|
|
595
|
+
async def async_all_users_alt(
|
|
596
|
+
paginate: AsyncPaginateSignature = Depends(
|
|
597
|
+
AsyncPagination(session_key="read_only")
|
|
598
|
+
),
|
|
599
|
+
):
|
|
463
600
|
return await paginate(select(User))
|
|
464
601
|
```
|
|
465
602
|
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
fastapi_sqla/__init__.py,sha256=4-MHiYF6TclnQpgs2Z1KDwYe2OPHLHGEu5oJzV3494s,1128
|
|
2
|
+
fastapi_sqla/_pytest_plugin.py,sha256=I7wmRmJHaws-wOHDBo2N50RUkmKd-zutMMpymI9Tg1w,5769
|
|
3
|
+
fastapi_sqla/async_pagination.py,sha256=UA3KxkTa48utBFDHeAkOFxTGUIvtNEWKm4uH-bqzQH4,3114
|
|
4
|
+
fastapi_sqla/async_sqla.py,sha256=BERggjeIjBW9hPqVxE7ry3iQyjBFc9KLvDGwV_LOfK8,5919
|
|
5
|
+
fastapi_sqla/aws_aurora_support.py,sha256=4dxLKOqDccgLwFqlz81L6f4HzrOXMZkY7Zuf4t_310U,838
|
|
6
|
+
fastapi_sqla/aws_rds_iam_support.py,sha256=YSJNhrxmhGN-GVk9PLMTmQSWTKZBvuorKkhc_XaoL44,1189
|
|
7
|
+
fastapi_sqla/base.py,sha256=0X7Gbt49rBHPiSFmNy5S2PT0dA4UBNnwrAesYSkaHBc,1606
|
|
8
|
+
fastapi_sqla/models.py,sha256=QhnPCX-Gz5exAZfWyCRyYSaZ7SM_8QY0Eir6b4_4oI8,1432
|
|
9
|
+
fastapi_sqla/pagination.py,sha256=NsI4ZeOkgbiNNDBjtqZL1rF6j1ya--jjXiyf0GlLaXU,4459
|
|
10
|
+
fastapi_sqla/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
|
+
fastapi_sqla/sqla.py,sha256=ih4JBAQzn-lT86jK5ZOX3tewT1UOrw2NUJauOso0mNg,5939
|
|
12
|
+
fastapi_sqla-3.0.0.dist-info/LICENSE,sha256=8G0-nWLqi3xRYRrtRlTE8n1mkYJcnCRoZGUhv6ZE29c,1064
|
|
13
|
+
fastapi_sqla-3.0.0.dist-info/METADATA,sha256=8l-HvdC7q5crdm7NJrSMg_bhf43jg73gx_gO0RM1eEc,19809
|
|
14
|
+
fastapi_sqla-3.0.0.dist-info/WHEEL,sha256=Zb28QaM1gQi8f4VCBhsUklF61CTlNYfs9YAZn-TOGFk,88
|
|
15
|
+
fastapi_sqla-3.0.0.dist-info/entry_points.txt,sha256=haa0EueKcRo8-AlJTpHBMn08wMBiULNGA53nkvaDWj0,53
|
|
16
|
+
fastapi_sqla-3.0.0.dist-info/RECORD,,
|
|
@@ -1,16 +0,0 @@
|
|
|
1
|
-
fastapi_sqla/__init__.py,sha256=QKvBNGd3irKqWp9PQMZ_QpqoqpTzl8jI7sa-zbmOdn8,824
|
|
2
|
-
fastapi_sqla/_pytest_plugin.py,sha256=K_kJRped_7-DhOoQgZVG7znA7Bb5zaDgbdKgDe4rthc,5534
|
|
3
|
-
fastapi_sqla/async_pagination.py,sha256=YUNF6ORaFTiUisLNCKz996ptisr355UlEB6ZPJ06W48,2801
|
|
4
|
-
fastapi_sqla/async_sqla.py,sha256=vfWeLJcZj8L3Z8uKuIQOF1yOEsCqo8rUYdZjfuTSu4M,3967
|
|
5
|
-
fastapi_sqla/aws_aurora_support.py,sha256=4dxLKOqDccgLwFqlz81L6f4HzrOXMZkY7Zuf4t_310U,838
|
|
6
|
-
fastapi_sqla/aws_rds_iam_support.py,sha256=hGs8erX4e7zYCqwandiokw9LHx_jlW83_fJBIYAb7c4,1090
|
|
7
|
-
fastapi_sqla/base.py,sha256=jiL_4n0r_HhyAxkPhu9rwMfh110y2GLrqWDNgDmJk7E,933
|
|
8
|
-
fastapi_sqla/models.py,sha256=QhnPCX-Gz5exAZfWyCRyYSaZ7SM_8QY0Eir6b4_4oI8,1432
|
|
9
|
-
fastapi_sqla/pagination.py,sha256=wIDVTUYogE7LD1Gzt7AmR2uQU8FM2t2iMm7FP4A7Fe0,4239
|
|
10
|
-
fastapi_sqla/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
|
-
fastapi_sqla/sqla.py,sha256=RgR-4VPKxfJw6QFC8oISL4htBrCyQHYc2EhjL8m0H_c,4955
|
|
12
|
-
fastapi_sqla-2.10.0.dist-info/LICENSE,sha256=8G0-nWLqi3xRYRrtRlTE8n1mkYJcnCRoZGUhv6ZE29c,1064
|
|
13
|
-
fastapi_sqla-2.10.0.dist-info/METADATA,sha256=XE56F8YbFEkOHF00SOC9CE8UFqb0h5jC4hIpmMY-Ukc,16597
|
|
14
|
-
fastapi_sqla-2.10.0.dist-info/WHEEL,sha256=Zb28QaM1gQi8f4VCBhsUklF61CTlNYfs9YAZn-TOGFk,88
|
|
15
|
-
fastapi_sqla-2.10.0.dist-info/entry_points.txt,sha256=haa0EueKcRo8-AlJTpHBMn08wMBiULNGA53nkvaDWj0,53
|
|
16
|
-
fastapi_sqla-2.10.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|