fastapi-async-sqlalchemy 0.6.0__tar.gz → 0.7.0.dev1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fastapi-async-sqlalchemy-0.6.0 → fastapi_async_sqlalchemy-0.7.0.dev1}/PKG-INFO +29 -7
- {fastapi-async-sqlalchemy-0.6.0 → fastapi_async_sqlalchemy-0.7.0.dev1}/README.md +25 -6
- {fastapi-async-sqlalchemy-0.6.0 → fastapi_async_sqlalchemy-0.7.0.dev1}/fastapi_async_sqlalchemy/__init__.py +1 -1
- fastapi_async_sqlalchemy-0.7.0.dev1/fastapi_async_sqlalchemy/middleware.py +163 -0
- {fastapi-async-sqlalchemy-0.6.0 → fastapi_async_sqlalchemy-0.7.0.dev1}/fastapi_async_sqlalchemy.egg-info/PKG-INFO +29 -7
- {fastapi-async-sqlalchemy-0.6.0 → fastapi_async_sqlalchemy-0.7.0.dev1}/fastapi_async_sqlalchemy.egg-info/SOURCES.txt +2 -1
- {fastapi-async-sqlalchemy-0.6.0 → fastapi_async_sqlalchemy-0.7.0.dev1}/setup.py +2 -1
- fastapi_async_sqlalchemy-0.7.0.dev1/tests/test_session.py +175 -0
- fastapi-async-sqlalchemy-0.6.0/fastapi_async_sqlalchemy/middleware.py +0 -100
- {fastapi-async-sqlalchemy-0.6.0 → fastapi_async_sqlalchemy-0.7.0.dev1}/LICENSE +0 -0
- {fastapi-async-sqlalchemy-0.6.0 → fastapi_async_sqlalchemy-0.7.0.dev1}/fastapi_async_sqlalchemy/exceptions.py +0 -0
- {fastapi-async-sqlalchemy-0.6.0 → fastapi_async_sqlalchemy-0.7.0.dev1}/fastapi_async_sqlalchemy/py.typed +0 -0
- {fastapi-async-sqlalchemy-0.6.0 → fastapi_async_sqlalchemy-0.7.0.dev1}/fastapi_async_sqlalchemy.egg-info/dependency_links.txt +0 -0
- {fastapi-async-sqlalchemy-0.6.0 → fastapi_async_sqlalchemy-0.7.0.dev1}/fastapi_async_sqlalchemy.egg-info/not-zip-safe +0 -0
- {fastapi-async-sqlalchemy-0.6.0 → fastapi_async_sqlalchemy-0.7.0.dev1}/fastapi_async_sqlalchemy.egg-info/requires.txt +0 -0
- {fastapi-async-sqlalchemy-0.6.0 → fastapi_async_sqlalchemy-0.7.0.dev1}/fastapi_async_sqlalchemy.egg-info/top_level.txt +0 -0
- {fastapi-async-sqlalchemy-0.6.0 → fastapi_async_sqlalchemy-0.7.0.dev1}/pyproject.toml +0 -0
- {fastapi-async-sqlalchemy-0.6.0 → fastapi_async_sqlalchemy-0.7.0.dev1}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: fastapi-async-sqlalchemy
|
|
3
|
-
Version: 0.6.0
|
|
3
|
+
Version: 0.7.0.dev1
|
|
4
4
|
Summary: SQLAlchemy middleware for FastAPI
|
|
5
5
|
Home-page: https://github.com/h0rn3t/fastapi-async-sqlalchemy.git
|
|
6
6
|
Author: Eugene Shershen
|
|
@@ -8,7 +8,7 @@ Author-email: h0rn3t.null@gmail.com
|
|
|
8
8
|
License: MIT
|
|
9
9
|
Project-URL: Code, https://github.com/h0rn3t/fastapi-async-sqlalchemy
|
|
10
10
|
Project-URL: Issue tracker, https://github.com/h0rn3t/fastapi-async-sqlalchemy/issues
|
|
11
|
-
Classifier: Development Status ::
|
|
11
|
+
Classifier: Development Status :: 5 - Production/Stable
|
|
12
12
|
Classifier: Environment :: Web Environment
|
|
13
13
|
Classifier: Framework :: AsyncIO
|
|
14
14
|
Classifier: Intended Audience :: Developers
|
|
@@ -19,6 +19,7 @@ Classifier: Programming Language :: Python :: 3.8
|
|
|
19
19
|
Classifier: Programming Language :: Python :: 3.9
|
|
20
20
|
Classifier: Programming Language :: Python :: 3.10
|
|
21
21
|
Classifier: Programming Language :: Python :: 3.11
|
|
22
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
22
23
|
Classifier: Programming Language :: Python :: 3 :: Only
|
|
23
24
|
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
24
25
|
Classifier: Topic :: Internet :: WWW/HTTP :: HTTP Servers
|
|
@@ -27,6 +28,8 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
|
27
28
|
Requires-Python: >=3.7
|
|
28
29
|
Description-Content-Type: text/markdown
|
|
29
30
|
License-File: LICENSE
|
|
31
|
+
Requires-Dist: starlette>=0.13.6
|
|
32
|
+
Requires-Dist: SQLAlchemy>=1.4.19
|
|
30
33
|
|
|
31
34
|
# SQLAlchemy FastAPI middleware
|
|
32
35
|
|
|
@@ -35,7 +38,7 @@ License-File: LICENSE
|
|
|
35
38
|
[](https://codecov.io/gh/h0rn3t/fastapi-async-sqlalchemy)
|
|
36
39
|
[](https://opensource.org/licenses/MIT)
|
|
37
40
|
[](https://pypi.org/project/fastapi-async-sqlalchemy/)
|
|
38
|
-
[](https://pepy.tech/project/fastapi-async-sqlalchemy)
|
|
41
|
+
[](https://pepy.tech/project/fastapi-async-sqlalchemy)
|
|
39
42
|
[](https://pyup.io/repos/github/h0rn3t/fastapi-async-sqlalchemy/)
|
|
40
43
|
|
|
41
44
|
### Description
|
|
@@ -48,8 +51,8 @@ Provides SQLAlchemy middleware for FastAPI using AsyncSession and async engine.
|
|
|
48
51
|
pip install fastapi-async-sqlalchemy
|
|
49
52
|
```
|
|
50
53
|
|
|
51
|
-
|
|
52
|
-
|
|
54
|
+
|
|
55
|
+
It also works with `sqlmodel`.
|
|
53
56
|
|
|
54
57
|
|
|
55
58
|
### Examples
|
|
@@ -157,9 +160,10 @@ app.add_middleware(
|
|
|
157
160
|
routes.py
|
|
158
161
|
|
|
159
162
|
```python
|
|
163
|
+
import asyncio
|
|
164
|
+
|
|
160
165
|
from fastapi import APIRouter
|
|
161
|
-
from sqlalchemy import column
|
|
162
|
-
from sqlalchemy import table
|
|
166
|
+
from sqlalchemy import column, table, text
|
|
163
167
|
|
|
164
168
|
from databases import first_db, second_db
|
|
165
169
|
|
|
@@ -177,4 +181,22 @@ async def get_files_from_first_db():
|
|
|
177
181
|
async def get_files_from_second_db():
|
|
178
182
|
result = await second_db.session.execute(foo.select())
|
|
179
183
|
return result.fetchall()
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
@router.get("/concurrent-queries")
async def parallel_select():
    # multi_sessions=True gives every asyncio task that reads first_db.session
    # its own session, so the queries below can run concurrently.
    async with first_db(multi_sessions=True):

        async def execute_query(query):
            return await first_db.session.execute(text(query))

        tasks = [asyncio.create_task(execute_query(f"SELECT {i}")) for i in range(1, 7)]
        results = await asyncio.gather(*tasks)

        # Return the value of each query instead of discarding the results.
        return [result.scalar() for result in results]
|
|
180
202
|
```
|
|
@@ -5,7 +5,7 @@
|
|
|
5
5
|
[](https://codecov.io/gh/h0rn3t/fastapi-async-sqlalchemy)
|
|
6
6
|
[](https://opensource.org/licenses/MIT)
|
|
7
7
|
[](https://pypi.org/project/fastapi-async-sqlalchemy/)
|
|
8
|
-
[](https://pepy.tech/project/fastapi-async-sqlalchemy)
|
|
8
|
+
[](https://pepy.tech/project/fastapi-async-sqlalchemy)
|
|
9
9
|
[](https://pyup.io/repos/github/h0rn3t/fastapi-async-sqlalchemy/)
|
|
10
10
|
|
|
11
11
|
### Description
|
|
@@ -18,8 +18,8 @@ Provides SQLAlchemy middleware for FastAPI using AsyncSession and async engine.
|
|
|
18
18
|
pip install fastapi-async-sqlalchemy
|
|
19
19
|
```
|
|
20
20
|
|
|
21
|
-
|
|
22
|
-
|
|
21
|
+
|
|
22
|
+
It also works with `sqlmodel`.
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
### Examples
|
|
@@ -127,9 +127,10 @@ app.add_middleware(
|
|
|
127
127
|
routes.py
|
|
128
128
|
|
|
129
129
|
```python
|
|
130
|
+
import asyncio
|
|
131
|
+
|
|
130
132
|
from fastapi import APIRouter
|
|
131
|
-
from sqlalchemy import column
|
|
132
|
-
from sqlalchemy import table
|
|
133
|
+
from sqlalchemy import column, table, text
|
|
133
134
|
|
|
134
135
|
from databases import first_db, second_db
|
|
135
136
|
|
|
@@ -147,4 +148,22 @@ async def get_files_from_first_db():
|
|
|
147
148
|
async def get_files_from_second_db():
|
|
148
149
|
result = await second_db.session.execute(foo.select())
|
|
149
150
|
return result.fetchall()
|
|
150
|
-
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@router.get("/concurrent-queries")
async def parallel_select():
    # multi_sessions=True gives every asyncio task that reads first_db.session
    # its own session, so the queries below can run concurrently.
    async with first_db(multi_sessions=True):

        async def execute_query(query):
            return await first_db.session.execute(text(query))

        tasks = [asyncio.create_task(execute_query(f"SELECT {i}")) for i in range(1, 7)]
        results = await asyncio.gather(*tasks)

        # Return the value of each query instead of discarding the results.
        return [result.scalar() for result in results]
|
|
169
|
+
```
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from asyncio import Task
|
|
3
|
+
from contextvars import ContextVar
|
|
4
|
+
from typing import Dict, Optional, Union
|
|
5
|
+
|
|
6
|
+
from sqlalchemy.engine import Engine
|
|
7
|
+
from sqlalchemy.engine.url import URL
|
|
8
|
+
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
9
|
+
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
10
|
+
from starlette.requests import Request
|
|
11
|
+
from starlette.types import ASGIApp
|
|
12
|
+
|
|
13
|
+
from fastapi_async_sqlalchemy.exceptions import MissingSessionError, SessionNotInitialisedError
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
from sqlalchemy.ext.asyncio import async_sessionmaker # noqa: F811
|
|
17
|
+
except ImportError:
|
|
18
|
+
from sqlalchemy.orm import sessionmaker as async_sessionmaker
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def create_middleware_and_session_proxy():
    """Create a linked ``(SQLAlchemyMiddleware, DBSession)`` pair.

    Every call builds fresh closure state (session factory and context
    variables), so calling it more than once yields independent pairs, e.g.
    one per database.

    Returns:
        ``(SQLAlchemyMiddleware, DBSession)``: the Starlette middleware class that
        opens a session around each request, and the session proxy class whose
        ``session`` attribute returns the session for the current context.
    """
    # NOTE(review): this module also imports ``async_sessionmaker`` unconditionally
    # from ``sqlalchemy.ext.asyncio`` at the top, which makes the ImportError
    # fallback to ``sessionmaker`` ineffective on SQLAlchemy releases that lack it
    # (setup.py allows SQLAlchemy>=1.4.19) - confirm the supported version range.
    _Session: Optional[async_sessionmaker] = None
    _session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)
    _multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False)
    _commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False)
    # Usage of context vars inside closures is not recommended, since they are not
    # properly garbage collected, but in our use case the context vars are created on
    # program startup and are used throughout the program's whole lifecycle.

    # Strong references to pending per-task cleanup tasks: the event loop keeps only
    # weak references to tasks, so an otherwise unreferenced cleanup task could be
    # garbage-collected before it has closed its session.
    _cleanup_tasks: set = set()

    async def _finalise_session(session, failed: bool, commit_on_exit: bool) -> None:
        """Roll back ``session`` on failure or commit it if requested, then close it."""
        try:
            if failed:
                await session.rollback()
            elif commit_on_exit:
                await session.commit()
        finally:
            await session.close()

    def _make_task_cleanup(task: Task, commit_on_exit: bool):
        """Return a done-callback that finalises the private session of ``task``."""

        def cleanup(future) -> None:
            session = getattr(task, "_db_session", None)
            if not session:
                return
            # ``future.exception()`` raises CancelledError for a cancelled task, so
            # check ``cancelled()`` first and treat cancellation as a failure.
            failed = future.cancelled() or future.exception() is not None
            cleanup_task = asyncio.create_task(
                _finalise_session(session, failed, commit_on_exit)
            )
            _cleanup_tasks.add(cleanup_task)
            cleanup_task.add_done_callback(_cleanup_tasks.discard)

        return cleanup

    class SQLAlchemyMiddleware(BaseHTTPMiddleware):
        """Starlette middleware that runs every request inside a ``DBSession`` context."""

        def __init__(
            self,
            app: ASGIApp,
            db_url: Optional[Union[str, URL]] = None,
            custom_engine: Optional[Engine] = None,
            engine_args: Optional[Dict] = None,
            session_args: Optional[Dict] = None,
            commit_on_exit: bool = False,
        ):
            """Configure the session factory shared by this middleware/proxy pair.

            Args:
                app: The ASGI application to wrap.
                db_url: Database URL for ``create_async_engine``; ignored when
                    ``custom_engine`` is given.
                custom_engine: A pre-built engine used instead of ``db_url``.
                engine_args: Extra keyword arguments for ``create_async_engine``.
                session_args: Extra keyword arguments for the session factory.
                commit_on_exit: Commit the request's session when the request
                    finishes without raising.

            Raises:
                ValueError: If neither ``db_url`` nor ``custom_engine`` is given.
            """
            super().__init__(app)
            self.commit_on_exit = commit_on_exit
            engine_args = engine_args or {}
            session_args = session_args or {}

            if not custom_engine and not db_url:
                raise ValueError("You need to pass a db_url or a custom_engine parameter.")
            if not custom_engine:
                engine = create_async_engine(db_url, **engine_args)
            else:
                engine = custom_engine

            nonlocal _Session
            _Session = async_sessionmaker(
                engine, class_=AsyncSession, expire_on_commit=False, **session_args
            )

        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
            async with DBSession(commit_on_exit=self.commit_on_exit):
                return await call_next(request)

    class DBSessionMeta(type):
        @property
        def session(self) -> AsyncSession:
            """Return an instance of Session local to the current async context.

            Normally this is the session opened by the innermost active
            ``DBSession`` context. Inside ``DBSession(multi_sessions=True)`` every
            asyncio task instead lazily gets its own session, so several tasks can
            run queries concurrently, e.g.::

                async with db(multi_sessions=True):
                    async def execute_query(query):
                        return await db.session.execute(text(query))

                    tasks = [
                        asyncio.create_task(execute_query(f"SELECT {i}"))
                        for i in range(1, 7)
                    ]
                    await asyncio.gather(*tasks)

            A per-task session is finalised once its task is done: rolled back if
            the task raised or was cancelled, otherwise committed when the
            enclosing context was opened with ``commit_on_exit=True``; it is then
            closed.

            Raises:
                SessionNotInitialisedError: If no middleware has configured the
                    session factory yet.
                MissingSessionError: Outside multi-session mode, if no
                    ``DBSession`` context is active.
            """
            if _Session is None:
                raise SessionNotInitialisedError

            if not _multi_sessions_ctx.get():
                session = _session.get()
                if session is None:
                    raise MissingSessionError
                return session

            task: Task = asyncio.current_task()  # type: ignore
            if not hasattr(task, "_db_session"):
                task._db_session = _Session()  # type: ignore
                task.add_done_callback(_make_task_cleanup(task, _commit_on_exit_ctx.get()))
            return task._db_session  # type: ignore

    class DBSession(metaclass=DBSessionMeta):
        """Async context manager that opens, and on exit closes, a database session.

        Args:
            session_args: Extra keyword arguments for the session factory.
            commit_on_exit: Commit the session when the context exits without an
                exception (and, in multi-session mode, the per-task sessions).
            multi_sessions: Give every asyncio task that reads ``session`` inside
                the context its own session.
        """

        def __init__(
            self,
            session_args: Optional[Dict] = None,
            commit_on_exit: bool = False,
            multi_sessions: bool = False,
        ):
            self.token = None
            self.multi_sessions_token = None
            self.commit_on_exit_token = None
            self.session_args = session_args or {}
            self.commit_on_exit = commit_on_exit
            self.multi_sessions = multi_sessions

        async def __aenter__(self):
            if not isinstance(_Session, async_sessionmaker):
                raise SessionNotInitialisedError

            # Build the session before touching any context variable so that a
            # failure here (e.g. bad session_args) leaves no context state behind.
            session = _Session(**self.session_args)

            if self.multi_sessions:
                self.multi_sessions_token = _multi_sessions_ctx.set(True)
                self.commit_on_exit_token = _commit_on_exit_ctx.set(self.commit_on_exit)

            self.token = _session.set(session)
            return type(self)

        async def __aexit__(self, exc_type, exc_value, traceback):
            session = _session.get()
            try:
                if exc_type is not None:
                    await session.rollback()
                elif self.commit_on_exit:
                    await session.commit()
            finally:
                try:
                    await session.close()
                finally:
                    # Always restore the context variables, even if close() fails,
                    # so later contexts do not inherit stale state.
                    _session.reset(self.token)
                    if self.multi_sessions_token is not None:
                        _multi_sessions_ctx.reset(self.multi_sessions_token)
                        _commit_on_exit_ctx.reset(self.commit_on_exit_token)

    return SQLAlchemyMiddleware, DBSession


SQLAlchemyMiddleware, db = create_middleware_and_session_proxy()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: fastapi-async-sqlalchemy
|
|
3
|
-
Version: 0.6.0
|
|
3
|
+
Version: 0.7.0.dev1
|
|
4
4
|
Summary: SQLAlchemy middleware for FastAPI
|
|
5
5
|
Home-page: https://github.com/h0rn3t/fastapi-async-sqlalchemy.git
|
|
6
6
|
Author: Eugene Shershen
|
|
@@ -8,7 +8,7 @@ Author-email: h0rn3t.null@gmail.com
|
|
|
8
8
|
License: MIT
|
|
9
9
|
Project-URL: Code, https://github.com/h0rn3t/fastapi-async-sqlalchemy
|
|
10
10
|
Project-URL: Issue tracker, https://github.com/h0rn3t/fastapi-async-sqlalchemy/issues
|
|
11
|
-
Classifier: Development Status ::
|
|
11
|
+
Classifier: Development Status :: 5 - Production/Stable
|
|
12
12
|
Classifier: Environment :: Web Environment
|
|
13
13
|
Classifier: Framework :: AsyncIO
|
|
14
14
|
Classifier: Intended Audience :: Developers
|
|
@@ -19,6 +19,7 @@ Classifier: Programming Language :: Python :: 3.8
|
|
|
19
19
|
Classifier: Programming Language :: Python :: 3.9
|
|
20
20
|
Classifier: Programming Language :: Python :: 3.10
|
|
21
21
|
Classifier: Programming Language :: Python :: 3.11
|
|
22
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
22
23
|
Classifier: Programming Language :: Python :: 3 :: Only
|
|
23
24
|
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
24
25
|
Classifier: Topic :: Internet :: WWW/HTTP :: HTTP Servers
|
|
@@ -27,6 +28,8 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
|
27
28
|
Requires-Python: >=3.7
|
|
28
29
|
Description-Content-Type: text/markdown
|
|
29
30
|
License-File: LICENSE
|
|
31
|
+
Requires-Dist: starlette>=0.13.6
|
|
32
|
+
Requires-Dist: SQLAlchemy>=1.4.19
|
|
30
33
|
|
|
31
34
|
# SQLAlchemy FastAPI middleware
|
|
32
35
|
|
|
@@ -35,7 +38,7 @@ License-File: LICENSE
|
|
|
35
38
|
[](https://codecov.io/gh/h0rn3t/fastapi-async-sqlalchemy)
|
|
36
39
|
[](https://opensource.org/licenses/MIT)
|
|
37
40
|
[](https://pypi.org/project/fastapi-async-sqlalchemy/)
|
|
38
|
-
[](https://pepy.tech/project/fastapi-async-sqlalchemy)
|
|
41
|
+
[](https://pepy.tech/project/fastapi-async-sqlalchemy)
|
|
39
42
|
[](https://pyup.io/repos/github/h0rn3t/fastapi-async-sqlalchemy/)
|
|
40
43
|
|
|
41
44
|
### Description
|
|
@@ -48,8 +51,8 @@ Provides SQLAlchemy middleware for FastAPI using AsyncSession and async engine.
|
|
|
48
51
|
pip install fastapi-async-sqlalchemy
|
|
49
52
|
```
|
|
50
53
|
|
|
51
|
-
|
|
52
|
-
|
|
54
|
+
|
|
55
|
+
It also works with `sqlmodel`.
|
|
53
56
|
|
|
54
57
|
|
|
55
58
|
### Examples
|
|
@@ -157,9 +160,10 @@ app.add_middleware(
|
|
|
157
160
|
routes.py
|
|
158
161
|
|
|
159
162
|
```python
|
|
163
|
+
import asyncio
|
|
164
|
+
|
|
160
165
|
from fastapi import APIRouter
|
|
161
|
-
from sqlalchemy import column
|
|
162
|
-
from sqlalchemy import table
|
|
166
|
+
from sqlalchemy import column, table, text
|
|
163
167
|
|
|
164
168
|
from databases import first_db, second_db
|
|
165
169
|
|
|
@@ -177,4 +181,22 @@ async def get_files_from_first_db():
|
|
|
177
181
|
async def get_files_from_second_db():
|
|
178
182
|
result = await second_db.session.execute(foo.select())
|
|
179
183
|
return result.fetchall()
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
@router.get("/concurrent-queries")
async def parallel_select():
    # multi_sessions=True gives every asyncio task that reads first_db.session
    # its own session, so the queries below can run concurrently.
    async with first_db(multi_sessions=True):

        async def execute_query(query):
            return await first_db.session.execute(text(query))

        tasks = [asyncio.create_task(execute_query(f"SELECT {i}")) for i in range(1, 7)]
        results = await asyncio.gather(*tasks)

        # Return the value of each query instead of discarding the results.
        return [result.scalar() for result in results]
|
|
180
202
|
```
|
|
@@ -11,4 +11,5 @@ fastapi_async_sqlalchemy.egg-info/SOURCES.txt
|
|
|
11
11
|
fastapi_async_sqlalchemy.egg-info/dependency_links.txt
|
|
12
12
|
fastapi_async_sqlalchemy.egg-info/not-zip-safe
|
|
13
13
|
fastapi_async_sqlalchemy.egg-info/requires.txt
|
|
14
|
-
fastapi_async_sqlalchemy.egg-info/top_level.txt
|
|
14
|
+
fastapi_async_sqlalchemy.egg-info/top_level.txt
|
|
15
|
+
tests/test_session.py
|
|
@@ -29,7 +29,7 @@ setup(
|
|
|
29
29
|
python_requires=">=3.7",
|
|
30
30
|
install_requires=["starlette>=0.13.6", "SQLAlchemy>=1.4.19"],
|
|
31
31
|
classifiers=[
|
|
32
|
-
"Development Status ::
|
|
32
|
+
"Development Status :: 5 - Production/Stable",
|
|
33
33
|
"Environment :: Web Environment",
|
|
34
34
|
"Framework :: AsyncIO",
|
|
35
35
|
"Intended Audience :: Developers",
|
|
@@ -40,6 +40,7 @@ setup(
|
|
|
40
40
|
"Programming Language :: Python :: 3.9",
|
|
41
41
|
"Programming Language :: Python :: 3.10",
|
|
42
42
|
"Programming Language :: Python :: 3.11",
|
|
43
|
+
"Programming Language :: Python :: 3.12",
|
|
43
44
|
"Programming Language :: Python :: 3 :: Only",
|
|
44
45
|
"Programming Language :: Python :: Implementation :: CPython",
|
|
45
46
|
"Topic :: Internet :: WWW/HTTP :: HTTP Servers",
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
from sqlalchemy import text
|
|
5
|
+
from sqlalchemy.exc import IntegrityError
|
|
6
|
+
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
|
7
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
8
|
+
|
|
9
|
+
from fastapi_async_sqlalchemy.exceptions import MissingSessionError, SessionNotInitialisedError
|
|
10
|
+
|
|
11
|
+
# Async SQLite URL with no database path (in-memory, via the aiosqlite driver),
# shared by the tests below.
db_url = "sqlite+aiosqlite://"
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@pytest.mark.asyncio
async def test_init(app, SQLAlchemyMiddleware):
    """Building the middleware from a URL yields a Starlette BaseHTTPMiddleware."""
    middleware_instance = SQLAlchemyMiddleware(app, db_url=db_url)
    assert isinstance(middleware_instance, BaseHTTPMiddleware)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@pytest.mark.asyncio
async def test_init_required_args(app, SQLAlchemyMiddleware):
    """Omitting both db_url and custom_engine is rejected with a ValueError."""
    expected_message = "You need to pass a db_url or a custom_engine parameter."

    with pytest.raises(ValueError) as raised:
        SQLAlchemyMiddleware(app)

    assert raised.value.args[0] == expected_message
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@pytest.mark.asyncio
async def test_init_required_args_custom_engine(app, db, SQLAlchemyMiddleware):
    """A pre-built async engine is accepted in place of db_url."""
    prebuilt_engine = create_async_engine(db_url)
    SQLAlchemyMiddleware(app, custom_engine=prebuilt_engine)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@pytest.mark.asyncio
async def test_init_correct_optional_args(app, db, SQLAlchemyMiddleware):
    """engine_args reach create_async_engine: the session's bind has echo enabled."""
    SQLAlchemyMiddleware(app, db_url, engine_args={"echo": True}, session_args={})

    # The session is reachable both through the proxy class and through the
    # object returned by ``async with``; check both spellings.
    async with db():
        assert db.session.bind.echo

    async with db() as db_ctx:
        assert db_ctx.session.bind.echo
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@pytest.mark.asyncio
async def test_init_incorrect_optional_args(app, SQLAlchemyMiddleware):
    """Unknown keyword arguments are rejected by the middleware constructor."""
    expected_fragment = "__init__() got an unexpected keyword argument 'invalid_args'"

    with pytest.raises(TypeError) as raised:
        SQLAlchemyMiddleware(app, db_url=db_url, invalid_args="test")

    assert expected_fragment in raised.value.args[0]
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@pytest.mark.asyncio
async def test_inside_route(app, client, db, SQLAlchemyMiddleware):
    """During a request served through the middleware, db.session is an AsyncSession."""
    app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)

    @app.get("/")
    def test_get():
        current_session = db.session
        assert isinstance(current_session, AsyncSession)

    client.get("/")
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@pytest.mark.asyncio
async def test_inside_route_without_middleware_fails(app, client, db):
    """Without the middleware, reading db.session in a route raises SessionNotInitialisedError."""

    @app.get("/")
    def test_get():
        # The attribute access itself is what must raise.
        with pytest.raises(SessionNotInitialisedError):
            db.session

    client.get("/")
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@pytest.mark.asyncio
async def test_outside_of_route(app, db, SQLAlchemyMiddleware):
    """An explicit ``async with db()`` provides a session outside any request."""
    app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)

    async with db():
        current_session = db.session
        assert isinstance(current_session, AsyncSession)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@pytest.mark.asyncio
async def test_outside_of_route_without_middleware_fails(db):
    """Before any middleware configures the factory, both access paths fail."""
    # Plain attribute access...
    with pytest.raises(SessionNotInitialisedError):
        db.session

    # ...and entering a session context.
    with pytest.raises(SessionNotInitialisedError):
        async with db():
            pass
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@pytest.mark.asyncio
async def test_outside_of_route_without_context_fails(app, db, SQLAlchemyMiddleware):
    """With the factory configured but no active context, db.session raises MissingSessionError."""
    app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)

    with pytest.raises(MissingSessionError):
        db.session
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@pytest.mark.asyncio
async def test_init_session(app, db, SQLAlchemyMiddleware):
    """Entering ``db()`` after configuring the middleware yields an AsyncSession."""
    app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)

    async with db():
        opened = db.session
        assert isinstance(opened, AsyncSession)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@pytest.mark.asyncio
async def test_db_session_commit_fail(app, db, SQLAlchemyMiddleware):
    """An error raised inside ``db()`` propagates and the context is cleaned up.

    The error must not be swallowed, the context-local session must be cleared
    afterwards, and a fresh context must still be usable.
    """
    app.add_middleware(SQLAlchemyMiddleware, db_url=db_url, commit_on_exit=True)

    with pytest.raises(IntegrityError):
        async with db():
            raise IntegrityError("test", "test", "test")

    # __aexit__ reset the context variable, so no session is active any more.
    with pytest.raises(MissingSessionError):
        db.session

    async with db():
        assert db.session
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
@pytest.mark.asyncio
async def test_rollback(app, db, SQLAlchemyMiddleware):
    """An exception raised inside ``db()`` propagates unchanged and clears the session.

    The rollback branch of ``DBSession.__aexit__`` runs here (visible in
    coverage), but the rollback call itself is not asserted directly.
    """

    class _Boom(Exception):
        """Sentinel so pytest.raises cannot be satisfied by an unrelated error."""

    app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)

    with pytest.raises(_Boom):
        async with db():
            raise _Boom

    # The context was exited and reset despite the error.
    with pytest.raises(MissingSessionError):
        db.session
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
@pytest.mark.parametrize("commit_on_exit", [True, False])
@pytest.mark.asyncio
async def test_db_context_session_args(app, db, SQLAlchemyMiddleware, commit_on_exit):
    """``db()`` accepts session_args and commit_on_exit; both exit paths are exercised."""
    app.add_middleware(SQLAlchemyMiddleware, db_url=db_url, commit_on_exit=commit_on_exit)

    # Forward the parametrized flag so both the commit and the no-commit branch
    # of DBSession.__aexit__ run (a hard-coded True only ever covered one).
    async with db(session_args={}, commit_on_exit=commit_on_exit):
        assert isinstance(db.session, AsyncSession)

    async with db(session_args={"expire_on_commit": False}):
        assert isinstance(db.session, AsyncSession)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
@pytest.mark.asyncio
async def test_multi_sessions(app, db, SQLAlchemyMiddleware):
    """In multi-session mode, concurrent tasks each get their own working session."""
    app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)

    async with db(multi_sessions=True):

        async def execute_query(query):
            session = db.session
            result = await session.execute(text(query))
            return session, result.scalar()

        tasks = [asyncio.create_task(execute_query(f"SELECT {i}")) for i in range(1, 7)]
        res = await asyncio.gather(*tasks)

    assert [value for _, value in res] == [1, 2, 3, 4, 5, 6]
    # Every task received a distinct session object (all are still referenced by
    # ``res``, so their ids cannot have been reused).
    sessions = [session for session, _ in res]
    assert len({id(session) for session in sessions}) == len(sessions)
|
|
@@ -1,100 +0,0 @@
|
|
|
1
|
-
from contextvars import ContextVar
|
|
2
|
-
from typing import Dict, Optional, Union
|
|
3
|
-
|
|
4
|
-
from sqlalchemy.engine import Engine
|
|
5
|
-
from sqlalchemy.engine.url import URL
|
|
6
|
-
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
|
7
|
-
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
8
|
-
from starlette.requests import Request
|
|
9
|
-
from starlette.types import ASGIApp
|
|
10
|
-
|
|
11
|
-
from fastapi_async_sqlalchemy.exceptions import MissingSessionError, SessionNotInitialisedError
|
|
12
|
-
|
|
13
|
-
try:
|
|
14
|
-
from sqlalchemy.ext.asyncio import async_sessionmaker
|
|
15
|
-
except ImportError:
|
|
16
|
-
from sqlalchemy.orm import sessionmaker as async_sessionmaker
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
def create_middleware_and_session_proxy():
    """Create a linked ``(SQLAlchemyMiddleware, DBSession)`` pair.

    Every call builds a fresh session factory slot and context variable, so
    calling it more than once yields independent pairs, e.g. one per database.

    Returns:
        ``(SQLAlchemyMiddleware, DBSession)``: the Starlette middleware class that
        opens a session around each request, and the session proxy class whose
        ``session`` attribute returns the session for the current context.
    """
    _Session: Optional[async_sessionmaker] = None
    # Usage of context vars inside closures is not recommended, since they are not
    # properly garbage collected, but in our use case the context var is created on
    # program startup and is used throughout the program's whole lifecycle.
    _session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)

    class SQLAlchemyMiddleware(BaseHTTPMiddleware):
        """Starlette middleware that runs every request inside a ``DBSession`` context."""

        def __init__(
            self,
            app: ASGIApp,
            db_url: Optional[Union[str, URL]] = None,
            custom_engine: Optional[Engine] = None,
            engine_args: Optional[Dict] = None,
            session_args: Optional[Dict] = None,
            commit_on_exit: bool = False,
        ):
            """Configure the session factory shared by this middleware/proxy pair.

            Raises:
                ValueError: If neither ``db_url`` nor ``custom_engine`` is given.
            """
            super().__init__(app)
            self.commit_on_exit = commit_on_exit
            engine_args = engine_args or {}
            session_args = session_args or {}

            if not custom_engine and not db_url:
                raise ValueError("You need to pass a db_url or a custom_engine parameter.")
            if not custom_engine:
                engine = create_async_engine(db_url, **engine_args)
            else:
                engine = custom_engine

            nonlocal _Session
            _Session = async_sessionmaker(
                engine, class_=AsyncSession, expire_on_commit=False, **session_args
            )

        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
            # Use this factory's own DBSession rather than the module-level ``db``
            # global; otherwise a second pair's middleware would open sessions
            # from the first pair's factory.
            async with DBSession(commit_on_exit=self.commit_on_exit):
                return await call_next(request)

    class DBSessionMeta(type):
        @property
        def session(self) -> AsyncSession:
            """Return an instance of Session local to the current async context.

            Raises:
                SessionNotInitialisedError: If no middleware has configured the
                    session factory yet.
                MissingSessionError: If no ``DBSession`` context is active.
            """
            if _Session is None:
                raise SessionNotInitialisedError

            session = _session.get()
            if session is None:
                raise MissingSessionError

            return session

    class DBSession(metaclass=DBSessionMeta):
        """Async context manager that opens, and on exit closes, a database session."""

        def __init__(self, session_args: Optional[Dict] = None, commit_on_exit: bool = False):
            self.token = None
            self.session_args = session_args or {}
            self.commit_on_exit = commit_on_exit

        async def __aenter__(self):
            if not isinstance(_Session, async_sessionmaker):
                raise SessionNotInitialisedError

            self.token = _session.set(_Session(**self.session_args))  # type: ignore
            return type(self)

        async def __aexit__(self, exc_type, exc_value, traceback):
            session = _session.get()

            try:
                if exc_type is not None:
                    await session.rollback()
                elif self.commit_on_exit:
                    # elif, so a rolled-back session is never committed afterwards.
                    await session.commit()
            finally:
                try:
                    await session.close()
                finally:
                    # Restore the context variable even if close() fails.
                    _session.reset(self.token)

    return SQLAlchemyMiddleware, DBSession


SQLAlchemyMiddleware, db = create_middleware_and_session_proxy()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|