fastapi-async-sqlalchemy 0.6.1__tar.gz → 0.7.0.dev2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (18) hide show
  1. {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/PKG-INFO +26 -6
  2. {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/README.md +23 -4
  3. {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/fastapi_async_sqlalchemy/__init__.py +1 -1
  4. fastapi_async_sqlalchemy-0.7.0.dev2/fastapi_async_sqlalchemy/middleware.py +164 -0
  5. {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/fastapi_async_sqlalchemy.egg-info/PKG-INFO +26 -6
  6. {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/setup.py +2 -1
  7. {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/tests/test_session.py +50 -0
  8. fastapi-async-sqlalchemy-0.6.1/fastapi_async_sqlalchemy/middleware.py +0 -100
  9. {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/LICENSE +0 -0
  10. {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/fastapi_async_sqlalchemy/exceptions.py +0 -0
  11. {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/fastapi_async_sqlalchemy/py.typed +0 -0
  12. {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/fastapi_async_sqlalchemy.egg-info/SOURCES.txt +0 -0
  13. {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/fastapi_async_sqlalchemy.egg-info/dependency_links.txt +0 -0
  14. {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/fastapi_async_sqlalchemy.egg-info/not-zip-safe +0 -0
  15. {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/fastapi_async_sqlalchemy.egg-info/requires.txt +0 -0
  16. {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/fastapi_async_sqlalchemy.egg-info/top_level.txt +0 -0
  17. {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/pyproject.toml +0 -0
  18. {fastapi-async-sqlalchemy-0.6.1 → fastapi_async_sqlalchemy-0.7.0.dev2}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: fastapi-async-sqlalchemy
3
- Version: 0.6.1
3
+ Version: 0.7.0.dev2
4
4
  Summary: SQLAlchemy middleware for FastAPI
5
5
  Home-page: https://github.com/h0rn3t/fastapi-async-sqlalchemy.git
6
6
  Author: Eugene Shershen
@@ -8,7 +8,7 @@ Author-email: h0rn3t.null@gmail.com
8
8
  License: MIT
9
9
  Project-URL: Code, https://github.com/h0rn3t/fastapi-async-sqlalchemy
10
10
  Project-URL: Issue tracker, https://github.com/h0rn3t/fastapi-async-sqlalchemy/issues
11
- Classifier: Development Status :: 4 - Beta
11
+ Classifier: Development Status :: 5 - Production/Stable
12
12
  Classifier: Environment :: Web Environment
13
13
  Classifier: Framework :: AsyncIO
14
14
  Classifier: Intended Audience :: Developers
@@ -19,6 +19,7 @@ Classifier: Programming Language :: Python :: 3.8
19
19
  Classifier: Programming Language :: Python :: 3.9
20
20
  Classifier: Programming Language :: Python :: 3.10
21
21
  Classifier: Programming Language :: Python :: 3.11
22
+ Classifier: Programming Language :: Python :: 3.12
22
23
  Classifier: Programming Language :: Python :: 3 :: Only
23
24
  Classifier: Programming Language :: Python :: Implementation :: CPython
24
25
  Classifier: Topic :: Internet :: WWW/HTTP :: HTTP Servers
@@ -50,8 +51,8 @@ Provides SQLAlchemy middleware for FastAPI using AsyncSession and async engine.
50
51
  pip install fastapi-async-sqlalchemy
51
52
  ```
52
53
 
53
- ### Important !!!
54
- If you use ```sqlmodel``` install ```sqlalchemy<=1.4.41```
54
+
55
+ It also works with ```sqlmodel```
55
56
 
56
57
 
57
58
  ### Examples
@@ -159,9 +160,10 @@ app.add_middleware(
159
160
  routes.py
160
161
 
161
162
  ```python
163
+ import asyncio
164
+
162
165
  from fastapi import APIRouter
163
- from sqlalchemy import column
164
- from sqlalchemy import table
166
+ from sqlalchemy import column, table, text
165
167
 
166
168
  from databases import first_db, second_db
167
169
 
@@ -179,4 +181,22 @@ async def get_files_from_first_db():
179
181
  async def get_files_from_second_db():
180
182
  result = await second_db.session.execute(foo.select())
181
183
  return result.fetchall()
184
+
185
+
186
+ @router.get("/concurrent-queries")
187
+ async def parallel_select():
188
+ async with first_db(multi_sessions=True):
189
+ async def execute_query(query):
190
+ return await first_db.session.execute(text(query))
191
+
192
+ tasks = [
193
+ asyncio.create_task(execute_query("SELECT 1")),
194
+ asyncio.create_task(execute_query("SELECT 2")),
195
+ asyncio.create_task(execute_query("SELECT 3")),
196
+ asyncio.create_task(execute_query("SELECT 4")),
197
+ asyncio.create_task(execute_query("SELECT 5")),
198
+ asyncio.create_task(execute_query("SELECT 6")),
199
+ ]
200
+
201
+ await asyncio.gather(*tasks)
182
202
  ```
@@ -18,8 +18,8 @@ Provides SQLAlchemy middleware for FastAPI using AsyncSession and async engine.
18
18
  pip install fastapi-async-sqlalchemy
19
19
  ```
20
20
 
21
- ### Important !!!
22
- If you use ```sqlmodel``` install ```sqlalchemy<=1.4.41```
21
+
22
+ It also works with ```sqlmodel```
23
23
 
24
24
 
25
25
  ### Examples
@@ -127,9 +127,10 @@ app.add_middleware(
127
127
  routes.py
128
128
 
129
129
  ```python
130
+ import asyncio
131
+
130
132
  from fastapi import APIRouter
131
- from sqlalchemy import column
132
- from sqlalchemy import table
133
+ from sqlalchemy import column, table, text
133
134
 
134
135
  from databases import first_db, second_db
135
136
 
@@ -147,4 +148,22 @@ async def get_files_from_first_db():
147
148
  async def get_files_from_second_db():
148
149
  result = await second_db.session.execute(foo.select())
149
150
  return result.fetchall()
151
+
152
+
153
+ @router.get("/concurrent-queries")
154
+ async def parallel_select():
155
+ async with first_db(multi_sessions=True):
156
+ async def execute_query(query):
157
+ return await first_db.session.execute(text(query))
158
+
159
+ tasks = [
160
+ asyncio.create_task(execute_query("SELECT 1")),
161
+ asyncio.create_task(execute_query("SELECT 2")),
162
+ asyncio.create_task(execute_query("SELECT 3")),
163
+ asyncio.create_task(execute_query("SELECT 4")),
164
+ asyncio.create_task(execute_query("SELECT 5")),
165
+ asyncio.create_task(execute_query("SELECT 6")),
166
+ ]
167
+
168
+ await asyncio.gather(*tasks)
150
169
  ```
@@ -2,4 +2,4 @@ from fastapi_async_sqlalchemy.middleware import SQLAlchemyMiddleware, db
2
2
 
3
3
  __all__ = ["db", "SQLAlchemyMiddleware"]
4
4
 
5
- __version__ = "0.6.1"
5
+ __version__ = "0.7.0.dev2"
@@ -0,0 +1,164 @@
import asyncio
from contextvars import ContextVar
from typing import Dict, Optional, Set, Union

from sqlalchemy.engine import Engine
from sqlalchemy.engine.url import URL
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.types import ASGIApp

from fastapi_async_sqlalchemy.exceptions import MissingSessionError, SessionNotInitialisedError

try:
    # SQLAlchemy 2.x provides a dedicated async_sessionmaker.
    from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError:
    # SQLAlchemy 1.4 (still permitted by install_requires "SQLAlchemy>=1.4.19") has no
    # async_sessionmaker; sessionmaker(class_=AsyncSession) plays the same role. This
    # fallback is only reachable because async_sessionmaker is NOT also imported
    # unconditionally in the import block above.
    from sqlalchemy.orm import sessionmaker as async_sessionmaker


def create_middleware_and_session_proxy():
    """Create a ``(SQLAlchemyMiddleware, DBSession)`` pair sharing one private session factory.

    The factory (``_Session``) is configured when the middleware is instantiated.
    ``DBSession`` (exported as ``db``) hands out sessions from it:

    * ``async with db(...)`` / every request through the middleware: one session for
      the block, committed on exit when ``commit_on_exit`` is set, rolled back when the
      block raises, and always closed.
    * ``async with db(multi_sessions=True, ...)``: one session per asyncio task, so
      several queries can run concurrently. Each session is finalised when its task
      finishes, and leaving the block waits for those finalisers.

    Calling this factory again yields an independent pair (e.g. one per database).
    """
    _Session: Optional[async_sessionmaker] = None
    # Context vars inside closures are usually discouraged because they are never garbage
    # collected, but these are created once at import time and are used for the whole
    # lifetime of the process.
    _session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)

    class _MultiSessionScope:
        """Book-keeping for one ``db(multi_sessions=True)`` block.

        ``asyncio.create_task`` copies the caller's context, so a per-task session kept in
        a context var would be inherited by child tasks and then used concurrently by
        parent and children. Instead, the scope object (which children do share) maps
        each task to the session it owns.
        """

        def __init__(self, session_args: Dict, commit_on_exit: bool):
            self.session_args = session_args
            self.commit_on_exit = commit_on_exit
            # Live sessions keyed by the task that requested them (None outside a task).
            self.sessions: Dict[Optional[asyncio.Task], AsyncSession] = {}
            # The event loop only keeps weak references to tasks, so scheduled
            # finalisers are held here until they complete.
            self.pending: Set[asyncio.Future] = set()

        def get_session(self) -> AsyncSession:
            """Return the session owned by the current task, creating it on first use."""
            task = asyncio.current_task()
            session = self.sessions.get(task)
            if session is None:
                session = _Session(**self.session_args)
                self.sessions[task] = session
                if task is not None:
                    # Registered once per task, so a task that touches ``db.session``
                    # several times still finalises its session exactly once.
                    task.add_done_callback(self._on_task_done)
            return session

        @staticmethod
        def _task_failed(task: asyncio.Task) -> bool:
            # NOTE: Task.exception() marks the exception as retrieved.
            return task.cancelled() or task.exception() is not None

        def _schedule(self, session: AsyncSession, failed: bool) -> None:
            finaliser = asyncio.create_task(self._finalise(session, failed))
            self.pending.add(finaliser)
            finaliser.add_done_callback(self.pending.discard)

        def _on_task_done(self, task: asyncio.Task) -> None:
            session = self.sessions.pop(task, None)
            if session is None:
                # Already finalised by close() when the block exited.
                return
            self._schedule(session, self._task_failed(task))

        async def _finalise(self, session: AsyncSession, failed: bool) -> None:
            """Roll back if the owner failed, else commit if requested; always close."""
            try:
                if failed:
                    await session.rollback()
                elif self.commit_on_exit:
                    await session.commit()
            except Exception:
                await session.rollback()
                raise
            finally:
                await session.close()

        async def close(self, exc_raised: bool) -> None:
            """Finalise sessions whose owners are done and wait for all finalisers.

            The current task's session (and one requested outside any task) is finalised
            now, honouring whether the block raised. Sessions of tasks that are still
            running are left to their done-callbacks.

            Raises:
                Exception: the first commit/rollback/close error of a finaliser.
            """
            current = asyncio.current_task()
            for task in list(self.sessions):
                if task is None or task is current:
                    failed = exc_raised
                    if task is not None:
                        # The block is over; don't finalise again when the
                        # (possibly long-lived) enclosing task ends.
                        task.remove_done_callback(self._on_task_done)
                elif task.done():
                    # Finished, but its done-callback may not have run yet.
                    failed = self._task_failed(task)
                else:
                    continue
                self._schedule(self.sessions.pop(task), failed)

            if self.pending:
                # asyncio.wait (unlike gather) does not cancel the finalisers if this
                # task is cancelled while waiting.
                done, _ = await asyncio.wait(set(self.pending))
                for finaliser in done:
                    if not finaliser.cancelled() and finaliser.exception() is not None:
                        raise finaliser.exception()

    # Innermost active ``db(multi_sessions=True)`` scope in this context, or None.
    _multi_scope_ctx: ContextVar[Optional[_MultiSessionScope]] = ContextVar(
        "_multi_scope_ctx", default=None
    )

    class SQLAlchemyMiddleware(BaseHTTPMiddleware):
        """Configure the session factory and wrap every request in a ``DBSession``."""

        def __init__(
            self,
            app: ASGIApp,
            db_url: Optional[Union[str, URL]] = None,
            custom_engine: Optional[Engine] = None,
            engine_args: Optional[Dict] = None,
            session_args: Optional[Dict] = None,
            commit_on_exit: bool = False,
        ):
            super().__init__(app)
            self.commit_on_exit = commit_on_exit
            engine_args = engine_args or {}
            session_args = session_args or {}

            if not custom_engine and not db_url:
                raise ValueError("You need to pass a db_url or a custom_engine parameter.")
            if not custom_engine:
                engine = create_async_engine(db_url, **engine_args)
            else:
                engine = custom_engine

            nonlocal _Session
            _Session = async_sessionmaker(
                engine, class_=AsyncSession, expire_on_commit=False, **session_args
            )

        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
            async with DBSession(commit_on_exit=self.commit_on_exit):
                return await call_next(request)

    class DBSessionMeta(type):
        @property
        def session(self) -> AsyncSession:
            """Return the session for the current async context.

            Inside ``db(multi_sessions=True)`` every asyncio task gets its own session,
            which allows queries to run in parallel, e.g. in a route handler::

                async with db(multi_sessions=True):
                    async def execute_query(query):
                        return await db.session.execute(text(query))

                    tasks = [
                        asyncio.create_task(execute_query("SELECT 1")),
                        asyncio.create_task(execute_query("SELECT 2")),
                    ]
                    await asyncio.gather(*tasks)

            Raises:
                SessionNotInitialisedError: the middleware has not been created yet.
                MissingSessionError: used outside a request / ``db()`` block.
            """
            if _Session is None:
                raise SessionNotInitialisedError

            scope = _multi_scope_ctx.get()
            if scope is not None:
                return scope.get_session()

            session = _session.get()
            if session is None:
                raise MissingSessionError
            return session

    class DBSession(metaclass=DBSessionMeta):
        """Async context manager scoping the session(s) returned by ``db.session``."""

        def __init__(
            self,
            session_args: Optional[Dict] = None,
            commit_on_exit: bool = False,
            multi_sessions: bool = False,
        ):
            self.token = None
            self.multi_sessions_token = None
            self.session_args = session_args or {}
            self.commit_on_exit = commit_on_exit
            self.multi_sessions = multi_sessions
            self._scope: Optional[_MultiSessionScope] = None

        async def __aenter__(self):
            if not isinstance(_Session, async_sessionmaker):
                raise SessionNotInitialisedError

            if self.multi_sessions:
                self._scope = _MultiSessionScope(self.session_args, self.commit_on_exit)
                self.multi_sessions_token = _multi_scope_ctx.set(self._scope)
            else:
                self.token = _session.set(_Session(**self.session_args))
                # A plain block nested in a multi-session block uses its own session.
                self.multi_sessions_token = _multi_scope_ctx.set(None)
            return type(self)

        async def __aexit__(self, exc_type, exc_value, traceback):
            _multi_scope_ctx.reset(self.multi_sessions_token)
            if self.multi_sessions:
                scope, self._scope = self._scope, None
                await scope.close(exc_raised=exc_type is not None)
                return

            session = _session.get()
            try:
                if exc_type is not None:
                    await session.rollback()
                elif self.commit_on_exit:
                    await session.commit()
            finally:
                await session.close()
                _session.reset(self.token)

    return SQLAlchemyMiddleware, DBSession


SQLAlchemyMiddleware, db = create_middleware_and_session_proxy()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: fastapi-async-sqlalchemy
3
- Version: 0.6.1
3
+ Version: 0.7.0.dev2
4
4
  Summary: SQLAlchemy middleware for FastAPI
5
5
  Home-page: https://github.com/h0rn3t/fastapi-async-sqlalchemy.git
6
6
  Author: Eugene Shershen
@@ -8,7 +8,7 @@ Author-email: h0rn3t.null@gmail.com
8
8
  License: MIT
9
9
  Project-URL: Code, https://github.com/h0rn3t/fastapi-async-sqlalchemy
10
10
  Project-URL: Issue tracker, https://github.com/h0rn3t/fastapi-async-sqlalchemy/issues
11
- Classifier: Development Status :: 4 - Beta
11
+ Classifier: Development Status :: 5 - Production/Stable
12
12
  Classifier: Environment :: Web Environment
13
13
  Classifier: Framework :: AsyncIO
14
14
  Classifier: Intended Audience :: Developers
@@ -19,6 +19,7 @@ Classifier: Programming Language :: Python :: 3.8
19
19
  Classifier: Programming Language :: Python :: 3.9
20
20
  Classifier: Programming Language :: Python :: 3.10
21
21
  Classifier: Programming Language :: Python :: 3.11
22
+ Classifier: Programming Language :: Python :: 3.12
22
23
  Classifier: Programming Language :: Python :: 3 :: Only
23
24
  Classifier: Programming Language :: Python :: Implementation :: CPython
24
25
  Classifier: Topic :: Internet :: WWW/HTTP :: HTTP Servers
@@ -50,8 +51,8 @@ Provides SQLAlchemy middleware for FastAPI using AsyncSession and async engine.
50
51
  pip install fastapi-async-sqlalchemy
51
52
  ```
52
53
 
53
- ### Important !!!
54
- If you use ```sqlmodel``` install ```sqlalchemy<=1.4.41```
54
+
55
+ It also works with ```sqlmodel```
55
56
 
56
57
 
57
58
  ### Examples
@@ -159,9 +160,10 @@ app.add_middleware(
159
160
  routes.py
160
161
 
161
162
  ```python
163
+ import asyncio
164
+
162
165
  from fastapi import APIRouter
163
- from sqlalchemy import column
164
- from sqlalchemy import table
166
+ from sqlalchemy import column, table, text
165
167
 
166
168
  from databases import first_db, second_db
167
169
 
@@ -179,4 +181,22 @@ async def get_files_from_first_db():
179
181
  async def get_files_from_second_db():
180
182
  result = await second_db.session.execute(foo.select())
181
183
  return result.fetchall()
184
+
185
+
186
+ @router.get("/concurrent-queries")
187
+ async def parallel_select():
188
+ async with first_db(multi_sessions=True):
189
+ async def execute_query(query):
190
+ return await first_db.session.execute(text(query))
191
+
192
+ tasks = [
193
+ asyncio.create_task(execute_query("SELECT 1")),
194
+ asyncio.create_task(execute_query("SELECT 2")),
195
+ asyncio.create_task(execute_query("SELECT 3")),
196
+ asyncio.create_task(execute_query("SELECT 4")),
197
+ asyncio.create_task(execute_query("SELECT 5")),
198
+ asyncio.create_task(execute_query("SELECT 6")),
199
+ ]
200
+
201
+ await asyncio.gather(*tasks)
182
202
  ```
@@ -29,7 +29,7 @@ setup(
29
29
  python_requires=">=3.7",
30
30
  install_requires=["starlette>=0.13.6", "SQLAlchemy>=1.4.19"],
31
31
  classifiers=[
32
- "Development Status :: 4 - Beta",
32
+ "Development Status :: 5 - Production/Stable",
33
33
  "Environment :: Web Environment",
34
34
  "Framework :: AsyncIO",
35
35
  "Intended Audience :: Developers",
@@ -40,6 +40,7 @@ setup(
40
40
  "Programming Language :: Python :: 3.9",
41
41
  "Programming Language :: Python :: 3.10",
42
42
  "Programming Language :: Python :: 3.11",
43
+ "Programming Language :: Python :: 3.12",
43
44
  "Programming Language :: Python :: 3 :: Only",
44
45
  "Programming Language :: Python :: Implementation :: CPython",
45
46
  "Topic :: Internet :: WWW/HTTP :: HTTP Servers",
@@ -1,4 +1,7 @@
1
+ import asyncio
2
+
1
3
  import pytest
4
+ from sqlalchemy import text
2
5
  from sqlalchemy.exc import IntegrityError
3
6
  from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
4
7
  from starlette.middleware.base import BaseHTTPMiddleware
@@ -148,3 +151,50 @@ async def test_db_context_session_args(app, db, SQLAlchemyMiddleware, commit_on_
148
151
  session_args = {"expire_on_commit": False}
149
152
  async with db(session_args=session_args):
150
153
  db.session
154
+
155
+
156
+ @pytest.mark.asyncio
157
+ async def test_multi_sessions(app, db, SQLAlchemyMiddleware):
158
+ app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)
159
+
160
+ async with db(multi_sessions=True):
161
+
162
+ async def execute_query(query):
163
+ return await db.session.execute(text(query))
164
+
165
+ tasks = [
166
+ asyncio.create_task(execute_query("SELECT 1")),
167
+ asyncio.create_task(execute_query("SELECT 2")),
168
+ asyncio.create_task(execute_query("SELECT 3")),
169
+ asyncio.create_task(execute_query("SELECT 4")),
170
+ asyncio.create_task(execute_query("SELECT 5")),
171
+ asyncio.create_task(execute_query("SELECT 6")),
172
+ ]
173
+
174
+ res = await asyncio.gather(*tasks)
175
+ assert len(res) == 6
176
+
177
+
@pytest.mark.asyncio
async def test_concurrent_inserts(app, db, SQLAlchemyMiddleware):
    """Ten concurrent tasks insert one row each inside a committing multi-session block."""
    app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)

    async with db(multi_sessions=True, commit_on_exit=True):
        await db.session.execute(
            text("CREATE TABLE IF NOT EXISTS my_model (id INTEGER PRIMARY KEY, value TEXT)")
        )

        async def insert_row(payload):
            await db.session.execute(
                text("INSERT INTO my_model (value) VALUES (:value)"), {"value": payload}
            )
            await db.session.flush()

        writers = [asyncio.create_task(insert_row(f"value_{i}")) for i in range(10)]

        # insert_row returns nothing; only the number of completed tasks matters.
        finished = await asyncio.gather(*writers)
        assert len(finished) == 10

        cursor = await db.session.execute(text("SELECT * FROM my_model"))
        rows = cursor.scalars().all()
        assert len(rows) == 10
@@ -1,100 +0,0 @@
from contextvars import ContextVar
from typing import Dict, Optional, Union

from sqlalchemy.engine import Engine
from sqlalchemy.engine.url import URL
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.types import ASGIApp

from fastapi_async_sqlalchemy.exceptions import MissingSessionError, SessionNotInitialisedError

try:
    from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError:  # SQLAlchemy 1.4: sessionmaker(class_=AsyncSession) fills the same role
    from sqlalchemy.orm import sessionmaker as async_sessionmaker


def create_middleware_and_session_proxy():
    """Build a middleware class and a ``db`` proxy bound to one private session factory.

    Returns ``(SQLAlchemyMiddleware, DBSession)``. Each request handled by the
    middleware, and each ``async with db(...)`` block, gets its own ``AsyncSession``
    held in a context var.
    """
    _Session: Optional[async_sessionmaker] = None
    # A context var captured by a closure is normally a garbage-collection concern; here it
    # is created once at startup and kept for the whole lifetime of the process.
    current_session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)

    class SQLAlchemyMiddleware(BaseHTTPMiddleware):
        """Configure the session factory and open a session around every request."""

        def __init__(
            self,
            app: ASGIApp,
            db_url: Optional[Union[str, URL]] = None,
            custom_engine: Optional[Engine] = None,
            engine_args: Dict = None,
            session_args: Dict = None,
            commit_on_exit: bool = False,
        ):
            nonlocal _Session
            super().__init__(app)
            self.commit_on_exit = commit_on_exit

            if custom_engine:
                engine = custom_engine
            elif db_url:
                engine = create_async_engine(db_url, **(engine_args or {}))
            else:
                raise ValueError("You need to pass a db_url or a custom_engine parameter.")

            _Session = async_sessionmaker(
                engine, class_=AsyncSession, expire_on_commit=False, **(session_args or {})
            )

        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
            async with DBSession(commit_on_exit=self.commit_on_exit):
                return await call_next(request)

    class DBSessionMeta(type):
        @property
        def session(cls) -> AsyncSession:
            """The ``AsyncSession`` bound to the current async context.

            Raises SessionNotInitialisedError before the middleware exists and
            MissingSessionError outside a request / ``db()`` block.
            """
            if _Session is None:
                raise SessionNotInitialisedError
            active = current_session.get()
            if active is None:
                raise MissingSessionError
            return active

    class DBSession(metaclass=DBSessionMeta):
        """Async context manager scoping one session; ``db.session`` returns it."""

        def __init__(self, session_args: Dict = None, commit_on_exit: bool = False):
            self.token = None
            self.session_args = session_args or {}
            self.commit_on_exit = commit_on_exit

        async def __aenter__(self):
            if isinstance(_Session, async_sessionmaker):
                self.token = current_session.set(_Session(**self.session_args))  # type: ignore
                return type(self)
            raise SessionNotInitialisedError

        async def __aexit__(self, exc_type, exc_value, traceback):
            active = current_session.get()
            try:
                if exc_type is None:
                    if self.commit_on_exit:
                        await active.commit()
                else:
                    # Work from a block that raised is discarded, never committed.
                    await active.rollback()
            finally:
                await active.close()
                current_session.reset(self.token)

    return SQLAlchemyMiddleware, DBSession


SQLAlchemyMiddleware, db = create_middleware_and_session_proxy()