fastapi-async-sqlalchemy 0.7.1.post3__tar.gz → 0.7.2a0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/PKG-INFO +1 -1
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/fastapi_async_sqlalchemy/__init__.py +1 -1
- fastapi_async_sqlalchemy-0.7.2a0/fastapi_async_sqlalchemy/middleware.py +484 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/fastapi_async_sqlalchemy.egg-info/PKG-INFO +1 -1
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/fastapi_async_sqlalchemy.egg-info/SOURCES.txt +1 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/tests/test_coverage_improvements.py +4 -2
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/tests/test_edge_cases_coverage.py +49 -50
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/tests/test_maximum_coverage.py +4 -11
- fastapi_async_sqlalchemy-0.7.2a0/tests/test_pool_throttling.py +625 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/tests/test_single_session_no_gather.py +3 -1
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/tests/test_sqlmodel.py +2 -4
- fastapi_async_sqlalchemy-0.7.1.post3/fastapi_async_sqlalchemy/middleware.py +0 -146
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/LICENSE +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/README.md +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/fastapi_async_sqlalchemy/exceptions.py +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/fastapi_async_sqlalchemy/py.typed +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/fastapi_async_sqlalchemy.egg-info/dependency_links.txt +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/fastapi_async_sqlalchemy.egg-info/not-zip-safe +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/fastapi_async_sqlalchemy.egg-info/requires.txt +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/fastapi_async_sqlalchemy.egg-info/top_level.txt +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/pyproject.toml +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/setup.cfg +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/setup.py +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/tests/test_additional_coverage.py +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/tests/test_backward_compat_gather.py +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/tests/test_concurrent_queries.py +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/tests/test_coverage_boost.py +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/tests/test_custom_engine_branch.py +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/tests/test_import_fallback_simulation.py +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/tests/test_import_fallbacks.py +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/tests/test_import_without_sqlmodel.py +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/tests/test_session.py +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post3 → fastapi_async_sqlalchemy-0.7.2a0}/tests/test_type_hints_compatibility.py +0 -0
|
@@ -0,0 +1,484 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import warnings
|
|
5
|
+
from contextvars import ContextVar
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
|
|
8
|
+
from sqlalchemy.engine.url import URL
|
|
9
|
+
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
|
|
10
|
+
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
11
|
+
from starlette.requests import Request
|
|
12
|
+
from starlette.types import ASGIApp
|
|
13
|
+
|
|
14
|
+
from fastapi_async_sqlalchemy.exceptions import (
|
|
15
|
+
MissingSessionError,
|
|
16
|
+
SessionNotInitialisedError,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
from sqlalchemy.ext.asyncio import async_sessionmaker
|
|
21
|
+
except ImportError:
|
|
22
|
+
from sqlalchemy.orm import sessionmaker as async_sessionmaker # type: ignore
|
|
23
|
+
|
|
24
|
+
try:
|
|
25
|
+
from sqlmodel.ext.asyncio.session import AsyncSession as SQLModelAsyncSession
|
|
26
|
+
|
|
27
|
+
DefaultAsyncSession: type[AsyncSession] = SQLModelAsyncSession # type: ignore
|
|
28
|
+
except ImportError:
|
|
29
|
+
DefaultAsyncSession: type[AsyncSession] = AsyncSession # type: ignore
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def create_middleware_and_session_proxy() -> tuple:
    """Create an independent ``(middleware class, DBSession class)`` pair.

    Every call builds its own session factory and context variables, so the
    classes returned by one call never share state with those returned by
    another call.

    Returns:
        ``(_SQLAlchemyMiddleware, DBSession)``. Instantiating the middleware
        configures the session factory; ``DBSession`` opens session scopes via
        ``async with DBSession(...)`` and exposes ``session``, ``connection()``
        and ``gather()`` through its metaclass.
    """
    # Session factory shared by both returned classes; set by the middleware.
    _Session: async_sessionmaker | None = None
    # Session of the current context in single-session mode.
    _session: ContextVar[AsyncSession | None] = ContextVar("_session", default=None)
    # True while inside a ``DBSession(multi_sessions=True)`` scope.
    _multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False)
    # Bookkeeping of the innermost active multi-session scope.
    _multi_state: ContextVar[_MultiSessionState | None] = ContextVar(
        "_multi_sessions_state",
        default=None,
    )

    @dataclass
    class _MultiSessionState:
        """Mutable bookkeeping shared by all tasks of one multi-session scope."""

        # Sessions created in the scope that have not been finalized yet.
        tracked: set[AsyncSession] = field(default_factory=set)
        # Session per task, keyed by ``id()`` of the owning task.
        task_sessions: dict[int, AsyncSession] = field(default_factory=dict)
        # Cleanup tasks scheduled by done-callbacks; awaited on scope exit.
        cleanup_tasks: list[asyncio.Task] = field(default_factory=list)
        # ``id()`` of the task that entered the scope (finalized on scope exit).
        parent_task_id: int | None = None
        commit_on_exit: bool = False
        # Keyword arguments forwarded to the session factory.
        session_args: dict = field(default_factory=dict)
        # Throttling semaphore; ``None`` when ``max_concurrent`` was not given.
        semaphore: asyncio.Semaphore | None = None
        # ``id()`` of tasks currently holding a semaphore slot.
        slot_holders: set[int] = field(default_factory=set)

    def _cleanup_error(error: BaseException) -> str:
        """Format *error* as ``"TypeName: message"`` for aggregated reports."""
        return f"{type(error).__name__}: {error}"

    def _raise_cleanup_errors(errors: list[BaseException]) -> None:
        """Re-raise errors collected while finalizing sessions.

        Nothing happens for an empty list and a single error is re-raised
        unchanged. With several errors, the first one that is not an
        ``Exception`` (``asyncio.CancelledError``, ``KeyboardInterrupt``,
        ``SystemExit``) is re-raised as-is so cancellation/shutdown is never
        masked; otherwise a ``RuntimeError`` summarizing all errors is raised,
        chained to the first one.
        """
        if not errors:
            return
        if len(errors) == 1:
            raise errors[0]

        for error in errors:
            if not isinstance(error, Exception):
                raise error

        details = "; ".join(_cleanup_error(error) for error in errors)
        raise RuntimeError(
            f"Session cleanup failed with {len(errors)} errors: {details}"
        ) from errors[0]

    async def _finalize_session(
        session: AsyncSession,
        commit_on_exit: bool,
        exc: BaseException | None,
    ) -> None:
        """Roll back or commit *session*, then always close it.

        An active session is rolled back when *exc* is set, or committed when
        *commit_on_exit* is true (rolling back if the commit fails). Errors from
        every step are collected and re-raised via ``_raise_cleanup_errors``
        only after ``close()`` has been attempted.
        """
        errors: list[BaseException] = []

        # Rollback/commit must surface errors to caller; otherwise writes can be lost silently.
        if session.is_active:
            if exc is not None:
                try:
                    await session.rollback()
                except BaseException as rollback_error:
                    errors.append(rollback_error)
            elif commit_on_exit:
                try:
                    await session.commit()
                except BaseException as commit_error:
                    errors.append(commit_error)
                    try:
                        await session.rollback()
                    except BaseException as rollback_error:
                        errors.append(rollback_error)

        try:
            await session.close()
        except BaseException as close_error:
            errors.append(close_error)

        _raise_cleanup_errors(errors)

    class _ConnectionContextManager:
        """Async context manager for throttled session access in multi_sessions mode.

        When used within ``db(multi_sessions=True, max_concurrent=N)``, this
        context manager awaits a semaphore slot before creating a session.
        When the block exits, the session is finalized (committed or rolled back)
        and the slot is released, allowing the next waiting task to proceed.

        This prevents ``TimeoutError: QueuePool limit ... reached`` by ensuring
        no more than *max_concurrent* sessions hold connections simultaneously.
        """

        __slots__ = ("_session", "_state", "_semaphore", "_owns_session", "_acquired_slot")

        def __init__(self) -> None:
            self._session: AsyncSession | None = None
            self._state: _MultiSessionState | None = None
            self._semaphore: asyncio.Semaphore | None = None
            # True when this context created the session and must finalize it.
            self._owns_session: bool = False
            # True while this context holds a semaphore slot it must release.
            self._acquired_slot: bool = False

        async def __aenter__(self) -> AsyncSession:
            """Return the task's session, creating (and throttling) it if needed.

            Raises:
                SessionNotInitialisedError: the middleware was never created.
                MissingSessionError: outside multi-session mode with no
                    session in the current context.
            """
            if _Session is None:
                raise SessionNotInitialisedError

            self._state = _multi_state.get()
            self._semaphore = self._state.semaphore if self._state else None

            multi_sessions = _multi_sessions_ctx.get()
            if multi_sessions and self._state is not None:
                task = asyncio.current_task()
                task_id = id(task) if task is not None else None

                # Reuse existing session for this task (no new slot needed).
                if task_id is not None and task_id in self._state.task_sessions:
                    self._session = self._state.task_sessions[task_id]
                    self._owns_session = False
                    return self._session

                # Acquire pool slot only when this context creates a new session.
                if self._semaphore:
                    await self._semaphore.acquire()
                    self._acquired_slot = True
                    if task_id is not None:
                        self._state.slot_holders.add(task_id)

                # Create new session — we own it and will close it in __aexit__
                try:
                    session = _Session(**self._state.session_args)
                except BaseException:
                    # Give the slot back so a failed factory call cannot leak it.
                    if self._acquired_slot and self._semaphore:
                        self._semaphore.release()
                        self._acquired_slot = False
                        if task_id is not None:
                            self._state.slot_holders.discard(task_id)
                    raise
                self._state.tracked.add(session)
                if task_id is not None:
                    self._state.task_sessions[task_id] = session
                self._session = session
                self._owns_session = True
                return session
            else:
                # Not in multi_sessions mode — return the context session
                session = _session.get()
                if session is None:
                    raise MissingSessionError
                self._session = session
                self._owns_session = False
                return session

        async def __aexit__(self, exc_type, exc_val, exc_tb):
            """Finalize an owned session and release a slot this context acquired."""
            task = asyncio.current_task()
            task_id = id(task) if task else None
            try:
                if self._owns_session and self._session is not None and self._state is not None:
                    # Remove from state to prevent double-cleanup by done_callback
                    if task_id is not None:
                        self._state.task_sessions.pop(task_id, None)
                    self._state.tracked.discard(self._session)

                    await _finalize_session(
                        self._session,
                        commit_on_exit=self._state.commit_on_exit,
                        exc=exc_val if exc_type is not None else None,
                    )
            finally:
                # Only the context that acquired the slot may give it back. A
                # nested context that merely reused the task's session must not
                # touch the outer context's slot bookkeeping.
                if self._acquired_slot:
                    self._acquired_slot = False
                    if task_id is not None and self._state is not None:
                        self._state.slot_holders.discard(task_id)
                    if self._semaphore:
                        self._semaphore.release()

    class _SQLAlchemyMiddleware(BaseHTTPMiddleware):
        """Starlette middleware that wraps every request in a ``DBSession`` scope."""

        __test__ = False

        def __init__(
            self,
            app: ASGIApp,
            db_url: str | URL | None = None,
            custom_engine: AsyncEngine | None = None,
            engine_args: dict | None = None,
            session_args: dict | None = None,
            commit_on_exit: bool = False,
        ):
            """Create the engine (unless *custom_engine* is given) and session factory.

            Raises:
                ValueError: neither *db_url* nor *custom_engine* was supplied.
            """
            nonlocal _Session
            super().__init__(app)
            self.commit_on_exit = commit_on_exit
            engine_args = engine_args or {}
            session_args = session_args or {}

            if not custom_engine and not db_url:
                raise ValueError("You need to pass a db_url or a custom_engine parameter.")
            if custom_engine:
                engine = custom_engine
            else:
                engine = create_async_engine(db_url, **engine_args)

            # Defaults first so ``session_args`` may override them instead of
            # failing with "got multiple values for keyword argument".
            factory_kwargs = {
                "class_": DefaultAsyncSession,
                "expire_on_commit": False,
                **session_args,
            }
            _Session = async_sessionmaker(engine, **factory_kwargs)

        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
            async with DBSession(commit_on_exit=self.commit_on_exit):
                return await call_next(request)

    class DBSessionMeta(type):
        @property
        def session(self) -> AsyncSession:
            """Return an instance of Session local to the current async context.

            In multi-session mode each task gets its own session. Sessions of
            child tasks are finalized from a done-callback once the task ends;
            the parent task's session is finalized when the scope exits.
            """
            if _Session is None:
                raise SessionNotInitialisedError

            multi_sessions = _multi_sessions_ctx.get()
            if multi_sessions:
                state = _multi_state.get()
                if state is None:
                    raise RuntimeError("Multi-session state is not initialized")

                task = asyncio.current_task()
                task_id = id(task) if task is not None else None

                if task_id is not None and task_id in state.task_sessions:
                    return state.task_sessions[task_id]

                if (
                    state.semaphore is not None
                    and task_id is not None
                    and task_id != state.parent_task_id
                    and task_id not in state.slot_holders
                ):
                    raise RuntimeError(
                        "When `max_concurrent` is set, child tasks must access DB via "
                        "`db.connection()` or `db.gather()`; direct `db.session` access "
                        "from child tasks is not throttled."
                    )

                session = _Session(**state.session_args)
                state.tracked.add(session)
                if task_id is not None:
                    state.task_sessions[task_id] = session

                # Capture loop from current context
                try:
                    current_loop = asyncio.get_running_loop()
                except RuntimeError:
                    current_loop = None

                def cleanup_callback(finished_task: asyncio.Task) -> None:
                    async def cleanup() -> None:
                        # exception() raises CancelledError for a cancelled
                        # task; treat that like any other failure.
                        task_exception: BaseException | None
                        try:
                            task_exception = finished_task.exception()
                        except BaseException as error:
                            task_exception = error

                        try:
                            # The scope's __aexit__ may already have claimed it.
                            if session in state.tracked:
                                state.tracked.discard(session)
                                await _finalize_session(
                                    session,
                                    commit_on_exit=state.commit_on_exit,
                                    exc=task_exception,
                                )
                        finally:
                            if task_id is not None:
                                state.task_sessions.pop(task_id, None)
                                state.slot_holders.discard(task_id)

                    if current_loop and not current_loop.is_closed():

                        def schedule_cleanup() -> None:
                            cleanup_task = current_loop.create_task(cleanup())
                            state.cleanup_tasks.append(cleanup_task)

                        current_loop.call_soon(schedule_cleanup)
                    else:
                        warnings.warn("No running event loop during cleanup", stacklevel=2)

                if task is not None and task_id != state.parent_task_id:
                    task.add_done_callback(cleanup_callback)
                return session
            else:
                session = _session.get()
                if session is None:
                    raise MissingSessionError
                return session

        def connection(cls) -> _ConnectionContextManager:
            """Return an async context manager that respects pool throttling.

            When ``max_concurrent`` is set on the enclosing ``db(...)`` context,
            ``connection()`` waits for a free semaphore slot before creating a
            session. The session is automatically closed when the block exits,
            releasing the slot for the next waiting task.

            Usage::

                async with db(multi_sessions=True, max_concurrent=10):
                    async def work(n):
                        async with db.connection() as session:
                            return await session.execute(text(f"SELECT {n}"))
                    tasks = [asyncio.create_task(work(i)) for i in range(100)]
                    results = await asyncio.gather(*tasks)

            Without ``max_concurrent`` the method still works — it simply
            creates a session without throttling and cleans it up on exit.
            """
            return _ConnectionContextManager()

        async def gather(cls, *coros_or_futures, return_exceptions: bool = False):
            """Drop-in replacement for ``asyncio.gather`` with pool throttling.

            Each coroutine is wrapped so that it acquires a semaphore slot
            (and thus a session) before running, and releases it after.
            This guarantees that no more than ``max_concurrent`` connections
            are in use at any time.

            Usage::

                async with db(multi_sessions=True, max_concurrent=10):
                    results = await db.gather(
                        do_work(1), do_work(2), ..., do_work(100),
                    )

            When ``max_concurrent`` is not set, delegates directly to
            ``asyncio.gather`` with no extra overhead.
            """
            state = _multi_state.get()
            semaphore = state.semaphore if state else None

            if semaphore is None:
                return await asyncio.gather(
                    *coros_or_futures,
                    return_exceptions=return_exceptions,
                )

            async def _throttled(coro):
                async with _ConnectionContextManager():
                    return await coro

            return await asyncio.gather(
                *[_throttled(c) for c in coros_or_futures],
                return_exceptions=return_exceptions,
            )

    class DBSession(metaclass=DBSessionMeta):
        """Async context manager opening a session scope (single or multi-session)."""

        def __init__(
            self,
            session_args: dict | None = None,
            commit_on_exit: bool = False,
            multi_sessions: bool = False,
            max_concurrent: int | None = None,
        ):
            if max_concurrent is not None and max_concurrent < 1:
                raise ValueError("`max_concurrent` must be greater than 0.")

            self.token = None
            self.multi_sessions_token = None
            self.multi_state_token = None
            self.session_args = session_args or {}
            self.commit_on_exit = commit_on_exit
            self.multi_sessions = multi_sessions
            self.max_concurrent = max_concurrent

        async def __aenter__(self):
            if not isinstance(_Session, async_sessionmaker):
                raise SessionNotInitialisedError

            if self.multi_sessions:
                self.multi_sessions_token = _multi_sessions_ctx.set(True)
                parent_task = asyncio.current_task()
                semaphore = (
                    asyncio.Semaphore(self.max_concurrent)
                    if self.max_concurrent is not None
                    else None
                )
                self.multi_state_token = _multi_state.set(
                    _MultiSessionState(
                        parent_task_id=id(parent_task) if parent_task else None,
                        commit_on_exit=self.commit_on_exit,
                        session_args=self.session_args,
                        semaphore=semaphore,
                    )
                )
            else:
                self.token = _session.set(_Session(**self.session_args))
            return type(self)

        async def __aexit__(self, exc_type, exc_value, traceback):
            if self.multi_sessions:
                _multi_sessions_ctx.reset(self.multi_sessions_token)
                cleanup_errors: list[BaseException] = []
                try:
                    state = _multi_state.get()
                    if state is not None:
                        # Await done-callback cleanups, including any that get
                        # scheduled while earlier ones are being awaited.
                        awaited = 0
                        while awaited < len(state.cleanup_tasks):
                            pending = state.cleanup_tasks[awaited:]
                            awaited = len(state.cleanup_tasks)
                            cleanup_results = await asyncio.gather(
                                *pending,
                                return_exceptions=True,
                            )
                            cleanup_errors.extend(
                                result
                                for result in cleanup_results
                                if isinstance(result, BaseException)
                            )
                        exc = exc_value if exc_type is not None else None
                        # Claim all remaining sessions to prevent concurrent cleanup
                        sessions_to_finalize = list(state.tracked)
                        state.tracked.clear()

                        for session in sessions_to_finalize:
                            try:
                                await _finalize_session(
                                    session,
                                    commit_on_exit=state.commit_on_exit,
                                    exc=exc,
                                )
                            except BaseException as caught_cleanup_error:
                                cleanup_errors.append(caught_cleanup_error)
                finally:
                    # Restore the outer state even if awaiting cleanups was interrupted.
                    _multi_state.reset(self.multi_state_token)

                if cleanup_errors:
                    if exc_type is None:
                        _raise_cleanup_errors(cleanup_errors)
                    for cleanup_exc in cleanup_errors:
                        warnings.warn(
                            "Suppressed session cleanup error because another exception is already "
                            f"being raised: {_cleanup_error(cleanup_exc)}",
                            stacklevel=2,
                        )
            else:
                session = _session.get()
                try:
                    if exc_type is not None:
                        await session.rollback()
                    elif self.commit_on_exit:
                        try:
                            await session.commit()
                        except Exception:
                            await session.rollback()
                            raise
                finally:
                    try:
                        await session.close()
                    finally:
                        # Always restore the outer context so a failed
                        # commit/rollback/close cannot leave this closed session
                        # visible to later ``db.session`` lookups.
                        _session.reset(self.token)

    return _SQLAlchemyMiddleware, DBSession
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
# Default module-wide pair: ``SQLAlchemyMiddleware`` is the middleware class and
# ``db`` is the DBSession class used as ``async with db(): ... db.session``.
# Calling create_middleware_and_session_proxy() again yields an independent pair.
(SQLAlchemyMiddleware, db) = create_middleware_and_session_proxy()
|
|
@@ -23,6 +23,7 @@ tests/test_import_fallback_simulation.py
|
|
|
23
23
|
tests/test_import_fallbacks.py
|
|
24
24
|
tests/test_import_without_sqlmodel.py
|
|
25
25
|
tests/test_maximum_coverage.py
|
|
26
|
+
tests/test_pool_throttling.py
|
|
26
27
|
tests/test_session.py
|
|
27
28
|
tests/test_single_session_no_gather.py
|
|
28
29
|
tests/test_sqlmodel.py
|
|
@@ -242,7 +242,8 @@ async def test_cleanup_callback_with_mocked_closed_loop():
|
|
|
242
242
|
return {"done": True}
|
|
243
243
|
|
|
244
244
|
client = TestClient(app)
|
|
245
|
-
|
|
245
|
+
with pytest.warns(UserWarning, match="No running event loop during cleanup"):
|
|
246
|
+
response = client.get("/test_mock_closed")
|
|
246
247
|
assert response.status_code == 200
|
|
247
248
|
|
|
248
249
|
|
|
@@ -281,5 +282,6 @@ async def test_cleanup_callback_with_runtime_error():
|
|
|
281
282
|
return {"done": True}
|
|
282
283
|
|
|
283
284
|
client = TestClient(app)
|
|
284
|
-
|
|
285
|
+
with pytest.warns(UserWarning, match="No running event loop during cleanup"):
|
|
286
|
+
response = client.get("/test_runtime_error")
|
|
285
287
|
assert response.status_code == 200
|