fastapi-async-sqlalchemy 0.6.1__tar.gz → 0.7.0.dev2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/PKG-INFO +26 -6
- {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/README.md +23 -4
- {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/fastapi_async_sqlalchemy/__init__.py +1 -1
- fastapi_async_sqlalchemy-0.7.0.dev2/fastapi_async_sqlalchemy/middleware.py +164 -0
- {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/fastapi_async_sqlalchemy.egg-info/PKG-INFO +26 -6
- {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/setup.py +2 -1
- {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/tests/test_session.py +50 -0
- fastapi-async-sqlalchemy-0.6.1/fastapi_async_sqlalchemy/middleware.py +0 -100
- {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/LICENSE +0 -0
- {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/fastapi_async_sqlalchemy/exceptions.py +0 -0
- {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/fastapi_async_sqlalchemy/py.typed +0 -0
- {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/fastapi_async_sqlalchemy.egg-info/SOURCES.txt +0 -0
- {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/fastapi_async_sqlalchemy.egg-info/dependency_links.txt +0 -0
- {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/fastapi_async_sqlalchemy.egg-info/not-zip-safe +0 -0
- {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/fastapi_async_sqlalchemy.egg-info/requires.txt +0 -0
- {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/fastapi_async_sqlalchemy.egg-info/top_level.txt +0 -0
- {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/pyproject.toml +0 -0
- {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: fastapi-async-sqlalchemy
|
|
3
|
-
Version: 0.6.1
|
|
3
|
+
Version: 0.7.0.dev2
|
|
4
4
|
Summary: SQLAlchemy middleware for FastAPI
|
|
5
5
|
Home-page: https://github.com/h0rn3t/fastapi-async-sqlalchemy.git
|
|
6
6
|
Author: Eugene Shershen
|
|
@@ -8,7 +8,7 @@ Author-email: h0rn3t.null@gmail.com
|
|
|
8
8
|
License: MIT
|
|
9
9
|
Project-URL: Code, https://github.com/h0rn3t/fastapi-async-sqlalchemy
|
|
10
10
|
Project-URL: Issue tracker, https://github.com/h0rn3t/fastapi-async-sqlalchemy/issues
|
|
11
|
-
Classifier: Development Status ::
|
|
11
|
+
Classifier: Development Status :: 5 - Production/Stable
|
|
12
12
|
Classifier: Environment :: Web Environment
|
|
13
13
|
Classifier: Framework :: AsyncIO
|
|
14
14
|
Classifier: Intended Audience :: Developers
|
|
@@ -19,6 +19,7 @@ Classifier: Programming Language :: Python :: 3.8
|
|
|
19
19
|
Classifier: Programming Language :: Python :: 3.9
|
|
20
20
|
Classifier: Programming Language :: Python :: 3.10
|
|
21
21
|
Classifier: Programming Language :: Python :: 3.11
|
|
22
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
22
23
|
Classifier: Programming Language :: Python :: 3 :: Only
|
|
23
24
|
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
24
25
|
Classifier: Topic :: Internet :: WWW/HTTP :: HTTP Servers
|
|
@@ -50,8 +51,8 @@ Provides SQLAlchemy middleware for FastAPI using AsyncSession and async engine.
|
|
|
50
51
|
pip install fastapi-async-sqlalchemy
|
|
51
52
|
```
|
|
52
53
|
|
|
53
|
-
|
|
54
|
-
|
|
54
|
+
|
|
55
|
+
It also works with ```sqlmodel```
|
|
55
56
|
|
|
56
57
|
|
|
57
58
|
### Examples
|
|
@@ -159,9 +160,10 @@ app.add_middleware(
|
|
|
159
160
|
routes.py
|
|
160
161
|
|
|
161
162
|
```python
|
|
163
|
+
import asyncio
|
|
164
|
+
|
|
162
165
|
from fastapi import APIRouter
|
|
163
|
-
from sqlalchemy import column
|
|
164
|
-
from sqlalchemy import table
|
|
166
|
+
from sqlalchemy import column, table, text
|
|
165
167
|
|
|
166
168
|
from databases import first_db, second_db
|
|
167
169
|
|
|
@@ -179,4 +181,22 @@ async def get_files_from_first_db():
|
|
|
179
181
|
async def get_files_from_second_db():
|
|
180
182
|
result = await second_db.session.execute(foo.select())
|
|
181
183
|
return result.fetchall()
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
@router.get("/concurrent-queries")
|
|
187
|
+
async def parallel_select():
|
|
188
|
+
async with first_db(multi_sessions=True):
|
|
189
|
+
async def execute_query(query):
|
|
190
|
+
return await first_db.session.execute(text(query))
|
|
191
|
+
|
|
192
|
+
tasks = [
|
|
193
|
+
asyncio.create_task(execute_query("SELECT 1")),
|
|
194
|
+
asyncio.create_task(execute_query("SELECT 2")),
|
|
195
|
+
asyncio.create_task(execute_query("SELECT 3")),
|
|
196
|
+
asyncio.create_task(execute_query("SELECT 4")),
|
|
197
|
+
asyncio.create_task(execute_query("SELECT 5")),
|
|
198
|
+
asyncio.create_task(execute_query("SELECT 6")),
|
|
199
|
+
]
|
|
200
|
+
|
|
201
|
+
await asyncio.gather(*tasks)
|
|
182
202
|
```
|
|
@@ -18,8 +18,8 @@ Provides SQLAlchemy middleware for FastAPI using AsyncSession and async engine.
|
|
|
18
18
|
pip install fastapi-async-sqlalchemy
|
|
19
19
|
```
|
|
20
20
|
|
|
21
|
-
|
|
22
|
-
|
|
21
|
+
|
|
22
|
+
It also works with ```sqlmodel```
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
### Examples
|
|
@@ -127,9 +127,10 @@ app.add_middleware(
|
|
|
127
127
|
routes.py
|
|
128
128
|
|
|
129
129
|
```python
|
|
130
|
+
import asyncio
|
|
131
|
+
|
|
130
132
|
from fastapi import APIRouter
|
|
131
|
-
from sqlalchemy import column
|
|
132
|
-
from sqlalchemy import table
|
|
133
|
+
from sqlalchemy import column, table, text
|
|
133
134
|
|
|
134
135
|
from databases import first_db, second_db
|
|
135
136
|
|
|
@@ -147,4 +148,22 @@ async def get_files_from_first_db():
|
|
|
147
148
|
async def get_files_from_second_db():
|
|
148
149
|
result = await second_db.session.execute(foo.select())
|
|
149
150
|
return result.fetchall()
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@router.get("/concurrent-queries")
|
|
154
|
+
async def parallel_select():
|
|
155
|
+
async with first_db(multi_sessions=True):
|
|
156
|
+
async def execute_query(query):
|
|
157
|
+
return await first_db.session.execute(text(query))
|
|
158
|
+
|
|
159
|
+
tasks = [
|
|
160
|
+
asyncio.create_task(execute_query("SELECT 1")),
|
|
161
|
+
asyncio.create_task(execute_query("SELECT 2")),
|
|
162
|
+
asyncio.create_task(execute_query("SELECT 3")),
|
|
163
|
+
asyncio.create_task(execute_query("SELECT 4")),
|
|
164
|
+
asyncio.create_task(execute_query("SELECT 5")),
|
|
165
|
+
asyncio.create_task(execute_query("SELECT 6")),
|
|
166
|
+
]
|
|
167
|
+
|
|
168
|
+
await asyncio.gather(*tasks)
|
|
150
169
|
```
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from contextvars import ContextVar
|
|
3
|
+
from typing import Dict, Optional, Union
|
|
4
|
+
|
|
5
|
+
from sqlalchemy.engine import Engine
|
|
6
|
+
from sqlalchemy.engine.url import URL
|
|
7
|
+
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
8
|
+
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
9
|
+
from starlette.requests import Request
|
|
10
|
+
from starlette.types import ASGIApp
|
|
11
|
+
|
|
12
|
+
from fastapi_async_sqlalchemy.exceptions import MissingSessionError, SessionNotInitialisedError
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
from sqlalchemy.ext.asyncio import async_sessionmaker # noqa: F811
|
|
16
|
+
except ImportError:
|
|
17
|
+
from sqlalchemy.orm import sessionmaker as async_sessionmaker
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def create_middleware_and_session_proxy():
|
|
21
|
+
_Session: Optional[async_sessionmaker] = None
|
|
22
|
+
_session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)
|
|
23
|
+
_multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False)
|
|
24
|
+
_task_session_ctx: ContextVar[Optional[AsyncSession]] = ContextVar(
|
|
25
|
+
"_task_session_ctx", default=None
|
|
26
|
+
)
|
|
27
|
+
_commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False)
|
|
28
|
+
# Usage of context vars inside closures is not recommended, since they are not properly
|
|
29
|
+
# garbage collected, but in our use case context var is created on program startup and
|
|
30
|
+
# is used throughout the whole its lifecycle.
|
|
31
|
+
|
|
32
|
+
class SQLAlchemyMiddleware(BaseHTTPMiddleware):
|
|
33
|
+
def __init__(
|
|
34
|
+
self,
|
|
35
|
+
app: ASGIApp,
|
|
36
|
+
db_url: Optional[Union[str, URL]] = None,
|
|
37
|
+
custom_engine: Optional[Engine] = None,
|
|
38
|
+
engine_args: Dict = None,
|
|
39
|
+
session_args: Dict = None,
|
|
40
|
+
commit_on_exit: bool = False,
|
|
41
|
+
):
|
|
42
|
+
super().__init__(app)
|
|
43
|
+
self.commit_on_exit = commit_on_exit
|
|
44
|
+
engine_args = engine_args or {}
|
|
45
|
+
session_args = session_args or {}
|
|
46
|
+
|
|
47
|
+
if not custom_engine and not db_url:
|
|
48
|
+
raise ValueError("You need to pass a db_url or a custom_engine parameter.")
|
|
49
|
+
if not custom_engine:
|
|
50
|
+
engine = create_async_engine(db_url, **engine_args)
|
|
51
|
+
else:
|
|
52
|
+
engine = custom_engine
|
|
53
|
+
|
|
54
|
+
nonlocal _Session
|
|
55
|
+
_Session = async_sessionmaker(
|
|
56
|
+
engine, class_=AsyncSession, expire_on_commit=False, **session_args
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
|
|
60
|
+
async with DBSession(commit_on_exit=self.commit_on_exit):
|
|
61
|
+
return await call_next(request)
|
|
62
|
+
|
|
63
|
+
class DBSessionMeta(type):
|
|
64
|
+
@property
|
|
65
|
+
def session(self) -> AsyncSession:
|
|
66
|
+
"""Return an instance of Session local to the current async context."""
|
|
67
|
+
if _Session is None:
|
|
68
|
+
raise SessionNotInitialisedError
|
|
69
|
+
|
|
70
|
+
multi_sessions = _multi_sessions_ctx.get()
|
|
71
|
+
if multi_sessions:
|
|
72
|
+
"""In this case, we need to create a new session for each task.
|
|
73
|
+
We also need to commit the session on exit if commit_on_exit is True.
|
|
74
|
+
This is useful when we need to run multiple queries in parallel.
|
|
75
|
+
For example, when we need to run multiple queries in parallel in a route handler.
|
|
76
|
+
Example:
|
|
77
|
+
```python
|
|
78
|
+
async with db(multi_sessions=True):
|
|
79
|
+
async def execute_query(query):
|
|
80
|
+
return await db.session.execute(text(query))
|
|
81
|
+
|
|
82
|
+
tasks = [
|
|
83
|
+
asyncio.create_task(execute_query("SELECT 1")),
|
|
84
|
+
asyncio.create_task(execute_query("SELECT 2")),
|
|
85
|
+
asyncio.create_task(execute_query("SELECT 3")),
|
|
86
|
+
asyncio.create_task(execute_query("SELECT 4")),
|
|
87
|
+
asyncio.create_task(execute_query("SELECT 5")),
|
|
88
|
+
asyncio.create_task(execute_query("SELECT 6")),
|
|
89
|
+
]
|
|
90
|
+
|
|
91
|
+
await asyncio.gather(*tasks)
|
|
92
|
+
```
|
|
93
|
+
"""
|
|
94
|
+
commit_on_exit = _commit_on_exit_ctx.get()
|
|
95
|
+
session = _task_session_ctx.get()
|
|
96
|
+
if session is None:
|
|
97
|
+
session = _Session()
|
|
98
|
+
_task_session_ctx.set(session)
|
|
99
|
+
|
|
100
|
+
async def cleanup():
|
|
101
|
+
try:
|
|
102
|
+
if commit_on_exit:
|
|
103
|
+
await session.commit()
|
|
104
|
+
except Exception:
|
|
105
|
+
await session.rollback()
|
|
106
|
+
raise
|
|
107
|
+
finally:
|
|
108
|
+
await session.close()
|
|
109
|
+
_task_session_ctx.set(None)
|
|
110
|
+
|
|
111
|
+
task = asyncio.current_task()
|
|
112
|
+
if task is not None:
|
|
113
|
+
task.add_done_callback(lambda t: asyncio.create_task(cleanup()))
|
|
114
|
+
return session
|
|
115
|
+
else:
|
|
116
|
+
session = _session.get()
|
|
117
|
+
if session is None:
|
|
118
|
+
raise MissingSessionError
|
|
119
|
+
return session
|
|
120
|
+
|
|
121
|
+
class DBSession(metaclass=DBSessionMeta):
|
|
122
|
+
def __init__(
|
|
123
|
+
self,
|
|
124
|
+
session_args: Dict = None,
|
|
125
|
+
commit_on_exit: bool = False,
|
|
126
|
+
multi_sessions: bool = False,
|
|
127
|
+
):
|
|
128
|
+
self.token = None
|
|
129
|
+
self.multi_sessions_token = None
|
|
130
|
+
self.commit_on_exit_token = None
|
|
131
|
+
self.session_args = session_args or {}
|
|
132
|
+
self.commit_on_exit = commit_on_exit
|
|
133
|
+
self.multi_sessions = multi_sessions
|
|
134
|
+
|
|
135
|
+
async def __aenter__(self):
|
|
136
|
+
if not isinstance(_Session, async_sessionmaker):
|
|
137
|
+
raise SessionNotInitialisedError
|
|
138
|
+
|
|
139
|
+
if self.multi_sessions:
|
|
140
|
+
self.multi_sessions_token = _multi_sessions_ctx.set(True)
|
|
141
|
+
self.commit_on_exit_token = _commit_on_exit_ctx.set(self.commit_on_exit)
|
|
142
|
+
else:
|
|
143
|
+
self.token = _session.set(_Session(**self.session_args))
|
|
144
|
+
return type(self)
|
|
145
|
+
|
|
146
|
+
async def __aexit__(self, exc_type, exc_value, traceback):
|
|
147
|
+
if self.multi_sessions:
|
|
148
|
+
_multi_sessions_ctx.reset(self.multi_sessions_token)
|
|
149
|
+
_commit_on_exit_ctx.reset(self.commit_on_exit_token)
|
|
150
|
+
else:
|
|
151
|
+
session = _session.get()
|
|
152
|
+
try:
|
|
153
|
+
if exc_type is not None:
|
|
154
|
+
await session.rollback()
|
|
155
|
+
elif self.commit_on_exit:
|
|
156
|
+
await session.commit()
|
|
157
|
+
finally:
|
|
158
|
+
await session.close()
|
|
159
|
+
_session.reset(self.token)
|
|
160
|
+
|
|
161
|
+
return SQLAlchemyMiddleware, DBSession
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
SQLAlchemyMiddleware, db = create_middleware_and_session_proxy()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: fastapi-async-sqlalchemy
|
|
3
|
-
Version: 0.6.1
|
|
3
|
+
Version: 0.7.0.dev2
|
|
4
4
|
Summary: SQLAlchemy middleware for FastAPI
|
|
5
5
|
Home-page: https://github.com/h0rn3t/fastapi-async-sqlalchemy.git
|
|
6
6
|
Author: Eugene Shershen
|
|
@@ -8,7 +8,7 @@ Author-email: h0rn3t.null@gmail.com
|
|
|
8
8
|
License: MIT
|
|
9
9
|
Project-URL: Code, https://github.com/h0rn3t/fastapi-async-sqlalchemy
|
|
10
10
|
Project-URL: Issue tracker, https://github.com/h0rn3t/fastapi-async-sqlalchemy/issues
|
|
11
|
-
Classifier: Development Status ::
|
|
11
|
+
Classifier: Development Status :: 5 - Production/Stable
|
|
12
12
|
Classifier: Environment :: Web Environment
|
|
13
13
|
Classifier: Framework :: AsyncIO
|
|
14
14
|
Classifier: Intended Audience :: Developers
|
|
@@ -19,6 +19,7 @@ Classifier: Programming Language :: Python :: 3.8
|
|
|
19
19
|
Classifier: Programming Language :: Python :: 3.9
|
|
20
20
|
Classifier: Programming Language :: Python :: 3.10
|
|
21
21
|
Classifier: Programming Language :: Python :: 3.11
|
|
22
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
22
23
|
Classifier: Programming Language :: Python :: 3 :: Only
|
|
23
24
|
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
24
25
|
Classifier: Topic :: Internet :: WWW/HTTP :: HTTP Servers
|
|
@@ -50,8 +51,8 @@ Provides SQLAlchemy middleware for FastAPI using AsyncSession and async engine.
|
|
|
50
51
|
pip install fastapi-async-sqlalchemy
|
|
51
52
|
```
|
|
52
53
|
|
|
53
|
-
|
|
54
|
-
|
|
54
|
+
|
|
55
|
+
It also works with ```sqlmodel```
|
|
55
56
|
|
|
56
57
|
|
|
57
58
|
### Examples
|
|
@@ -159,9 +160,10 @@ app.add_middleware(
|
|
|
159
160
|
routes.py
|
|
160
161
|
|
|
161
162
|
```python
|
|
163
|
+
import asyncio
|
|
164
|
+
|
|
162
165
|
from fastapi import APIRouter
|
|
163
|
-
from sqlalchemy import column
|
|
164
|
-
from sqlalchemy import table
|
|
166
|
+
from sqlalchemy import column, table, text
|
|
165
167
|
|
|
166
168
|
from databases import first_db, second_db
|
|
167
169
|
|
|
@@ -179,4 +181,22 @@ async def get_files_from_first_db():
|
|
|
179
181
|
async def get_files_from_second_db():
|
|
180
182
|
result = await second_db.session.execute(foo.select())
|
|
181
183
|
return result.fetchall()
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
@router.get("/concurrent-queries")
|
|
187
|
+
async def parallel_select():
|
|
188
|
+
async with first_db(multi_sessions=True):
|
|
189
|
+
async def execute_query(query):
|
|
190
|
+
return await first_db.session.execute(text(query))
|
|
191
|
+
|
|
192
|
+
tasks = [
|
|
193
|
+
asyncio.create_task(execute_query("SELECT 1")),
|
|
194
|
+
asyncio.create_task(execute_query("SELECT 2")),
|
|
195
|
+
asyncio.create_task(execute_query("SELECT 3")),
|
|
196
|
+
asyncio.create_task(execute_query("SELECT 4")),
|
|
197
|
+
asyncio.create_task(execute_query("SELECT 5")),
|
|
198
|
+
asyncio.create_task(execute_query("SELECT 6")),
|
|
199
|
+
]
|
|
200
|
+
|
|
201
|
+
await asyncio.gather(*tasks)
|
|
182
202
|
```
|
|
@@ -29,7 +29,7 @@ setup(
|
|
|
29
29
|
python_requires=">=3.7",
|
|
30
30
|
install_requires=["starlette>=0.13.6", "SQLAlchemy>=1.4.19"],
|
|
31
31
|
classifiers=[
|
|
32
|
-
"Development Status ::
|
|
32
|
+
"Development Status :: 5 - Production/Stable",
|
|
33
33
|
"Environment :: Web Environment",
|
|
34
34
|
"Framework :: AsyncIO",
|
|
35
35
|
"Intended Audience :: Developers",
|
|
@@ -40,6 +40,7 @@ setup(
|
|
|
40
40
|
"Programming Language :: Python :: 3.9",
|
|
41
41
|
"Programming Language :: Python :: 3.10",
|
|
42
42
|
"Programming Language :: Python :: 3.11",
|
|
43
|
+
"Programming Language :: Python :: 3.12",
|
|
43
44
|
"Programming Language :: Python :: 3 :: Only",
|
|
44
45
|
"Programming Language :: Python :: Implementation :: CPython",
|
|
45
46
|
"Topic :: Internet :: WWW/HTTP :: HTTP Servers",
|
{fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/tests/test_session.py
RENAMED
|
@@ -1,4 +1,7 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
|
|
1
3
|
import pytest
|
|
4
|
+
from sqlalchemy import text
|
|
2
5
|
from sqlalchemy.exc import IntegrityError
|
|
3
6
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
|
4
7
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
@@ -148,3 +151,50 @@ async def test_db_context_session_args(app, db, SQLAlchemyMiddleware, commit_on_
|
|
|
148
151
|
session_args = {"expire_on_commit": False}
|
|
149
152
|
async with db(session_args=session_args):
|
|
150
153
|
db.session
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
@pytest.mark.asyncio
|
|
157
|
+
async def test_multi_sessions(app, db, SQLAlchemyMiddleware):
|
|
158
|
+
app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)
|
|
159
|
+
|
|
160
|
+
async with db(multi_sessions=True):
|
|
161
|
+
|
|
162
|
+
async def execute_query(query):
|
|
163
|
+
return await db.session.execute(text(query))
|
|
164
|
+
|
|
165
|
+
tasks = [
|
|
166
|
+
asyncio.create_task(execute_query("SELECT 1")),
|
|
167
|
+
asyncio.create_task(execute_query("SELECT 2")),
|
|
168
|
+
asyncio.create_task(execute_query("SELECT 3")),
|
|
169
|
+
asyncio.create_task(execute_query("SELECT 4")),
|
|
170
|
+
asyncio.create_task(execute_query("SELECT 5")),
|
|
171
|
+
asyncio.create_task(execute_query("SELECT 6")),
|
|
172
|
+
]
|
|
173
|
+
|
|
174
|
+
res = await asyncio.gather(*tasks)
|
|
175
|
+
assert len(res) == 6
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
@pytest.mark.asyncio
|
|
179
|
+
async def test_concurrent_inserts(app, db, SQLAlchemyMiddleware):
|
|
180
|
+
app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)
|
|
181
|
+
|
|
182
|
+
async with db(multi_sessions=True, commit_on_exit=True):
|
|
183
|
+
await db.session.execute(
|
|
184
|
+
text("CREATE TABLE IF NOT EXISTS my_model (id INTEGER PRIMARY KEY, value TEXT)")
|
|
185
|
+
)
|
|
186
|
+
|
|
187
|
+
async def insert_data(value):
|
|
188
|
+
await db.session.execute(
|
|
189
|
+
text("INSERT INTO my_model (value) VALUES (:value)"), {"value": value}
|
|
190
|
+
)
|
|
191
|
+
await db.session.flush()
|
|
192
|
+
|
|
193
|
+
tasks = [asyncio.create_task(insert_data(f"value_{i}")) for i in range(10)]
|
|
194
|
+
|
|
195
|
+
result_ids = await asyncio.gather(*tasks)
|
|
196
|
+
assert len(result_ids) == 10
|
|
197
|
+
|
|
198
|
+
records = await db.session.execute(text("SELECT * FROM my_model"))
|
|
199
|
+
records = records.scalars().all()
|
|
200
|
+
assert len(records) == 10
|
|
@@ -1,100 +0,0 @@
|
|
|
1
|
-
from contextvars import ContextVar
|
|
2
|
-
from typing import Dict, Optional, Union
|
|
3
|
-
|
|
4
|
-
from sqlalchemy.engine import Engine
|
|
5
|
-
from sqlalchemy.engine.url import URL
|
|
6
|
-
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
|
7
|
-
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
8
|
-
from starlette.requests import Request
|
|
9
|
-
from starlette.types import ASGIApp
|
|
10
|
-
|
|
11
|
-
from fastapi_async_sqlalchemy.exceptions import MissingSessionError, SessionNotInitialisedError
|
|
12
|
-
|
|
13
|
-
try:
|
|
14
|
-
from sqlalchemy.ext.asyncio import async_sessionmaker
|
|
15
|
-
except ImportError:
|
|
16
|
-
from sqlalchemy.orm import sessionmaker as async_sessionmaker
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
def create_middleware_and_session_proxy():
|
|
20
|
-
_Session: Optional[async_sessionmaker] = None
|
|
21
|
-
# Usage of context vars inside closures is not recommended, since they are not properly
|
|
22
|
-
# garbage collected, but in our use case context var is created on program startup and
|
|
23
|
-
# is used throughout the whole its lifecycle.
|
|
24
|
-
_session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)
|
|
25
|
-
|
|
26
|
-
class SQLAlchemyMiddleware(BaseHTTPMiddleware):
|
|
27
|
-
def __init__(
|
|
28
|
-
self,
|
|
29
|
-
app: ASGIApp,
|
|
30
|
-
db_url: Optional[Union[str, URL]] = None,
|
|
31
|
-
custom_engine: Optional[Engine] = None,
|
|
32
|
-
engine_args: Dict = None,
|
|
33
|
-
session_args: Dict = None,
|
|
34
|
-
commit_on_exit: bool = False,
|
|
35
|
-
):
|
|
36
|
-
super().__init__(app)
|
|
37
|
-
self.commit_on_exit = commit_on_exit
|
|
38
|
-
engine_args = engine_args or {}
|
|
39
|
-
session_args = session_args or {}
|
|
40
|
-
|
|
41
|
-
if not custom_engine and not db_url:
|
|
42
|
-
raise ValueError("You need to pass a db_url or a custom_engine parameter.")
|
|
43
|
-
if not custom_engine:
|
|
44
|
-
engine = create_async_engine(db_url, **engine_args)
|
|
45
|
-
else:
|
|
46
|
-
engine = custom_engine
|
|
47
|
-
|
|
48
|
-
nonlocal _Session
|
|
49
|
-
_Session = async_sessionmaker(
|
|
50
|
-
engine, class_=AsyncSession, expire_on_commit=False, **session_args
|
|
51
|
-
)
|
|
52
|
-
|
|
53
|
-
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
|
|
54
|
-
async with DBSession(commit_on_exit=self.commit_on_exit):
|
|
55
|
-
return await call_next(request)
|
|
56
|
-
|
|
57
|
-
class DBSessionMeta(type):
|
|
58
|
-
@property
|
|
59
|
-
def session(self) -> AsyncSession:
|
|
60
|
-
"""Return an instance of Session local to the current async context."""
|
|
61
|
-
if _Session is None:
|
|
62
|
-
raise SessionNotInitialisedError
|
|
63
|
-
|
|
64
|
-
session = _session.get()
|
|
65
|
-
if session is None:
|
|
66
|
-
raise MissingSessionError
|
|
67
|
-
|
|
68
|
-
return session
|
|
69
|
-
|
|
70
|
-
class DBSession(metaclass=DBSessionMeta):
|
|
71
|
-
def __init__(self, session_args: Dict = None, commit_on_exit: bool = False):
|
|
72
|
-
self.token = None
|
|
73
|
-
self.session_args = session_args or {}
|
|
74
|
-
self.commit_on_exit = commit_on_exit
|
|
75
|
-
|
|
76
|
-
async def __aenter__(self):
|
|
77
|
-
if not isinstance(_Session, async_sessionmaker):
|
|
78
|
-
raise SessionNotInitialisedError
|
|
79
|
-
|
|
80
|
-
self.token = _session.set(_Session(**self.session_args)) # type: ignore
|
|
81
|
-
return type(self)
|
|
82
|
-
|
|
83
|
-
async def __aexit__(self, exc_type, exc_value, traceback):
|
|
84
|
-
session = _session.get()
|
|
85
|
-
|
|
86
|
-
try:
|
|
87
|
-
if exc_type is not None:
|
|
88
|
-
await session.rollback()
|
|
89
|
-
elif (
|
|
90
|
-
self.commit_on_exit
|
|
91
|
-
): # Note: Changed this to elif to avoid commit after rollback
|
|
92
|
-
await session.commit()
|
|
93
|
-
finally:
|
|
94
|
-
await session.close()
|
|
95
|
-
_session.reset(self.token)
|
|
96
|
-
|
|
97
|
-
return SQLAlchemyMiddleware, DBSession
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
SQLAlchemyMiddleware, db = create_middleware_and_session_proxy()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|