fastapi-async-sqlalchemy 0.7.1.post1__tar.gz → 0.7.1.post3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/PKG-INFO +1 -1
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/fastapi_async_sqlalchemy/__init__.py +1 -1
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/fastapi_async_sqlalchemy/middleware.py +26 -91
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/fastapi_async_sqlalchemy.egg-info/PKG-INFO +1 -1
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/fastapi_async_sqlalchemy.egg-info/SOURCES.txt +4 -2
- fastapi_async_sqlalchemy-0.7.1.post3/tests/test_backward_compat_gather.py +213 -0
- fastapi_async_sqlalchemy-0.7.1.post3/tests/test_concurrent_queries.py +480 -0
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/tests/test_coverage_improvements.py +10 -10
- fastapi_async_sqlalchemy-0.7.1.post3/tests/test_import_without_sqlmodel.py +57 -0
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/tests/test_maximum_coverage.py +4 -6
- fastapi_async_sqlalchemy-0.7.1.post3/tests/test_single_session_no_gather.py +104 -0
- fastapi_async_sqlalchemy-0.7.1.post1/tests/test_multi_sessions_cleanup.py +0 -89
- fastapi_async_sqlalchemy-0.7.1.post1/tests/test_multisession_pool.py +0 -82
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/LICENSE +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/README.md +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/fastapi_async_sqlalchemy/exceptions.py +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/fastapi_async_sqlalchemy/py.typed +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/fastapi_async_sqlalchemy.egg-info/dependency_links.txt +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/fastapi_async_sqlalchemy.egg-info/not-zip-safe +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/fastapi_async_sqlalchemy.egg-info/requires.txt +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/fastapi_async_sqlalchemy.egg-info/top_level.txt +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/pyproject.toml +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/setup.cfg +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/setup.py +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/tests/test_additional_coverage.py +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/tests/test_coverage_boost.py +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/tests/test_custom_engine_branch.py +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/tests/test_edge_cases_coverage.py +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/tests/test_import_fallback_simulation.py +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/tests/test_import_fallbacks.py +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/tests/test_session.py +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/tests/test_sqlmodel.py +0 -0
- {fastapi_async_sqlalchemy-0.7.1.post1 → fastapi_async_sqlalchemy-0.7.1.post3}/tests/test_type_hints_compatibility.py +0 -0
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
from contextvars import ContextVar
|
|
3
|
-
from dataclasses import dataclass, field
|
|
4
3
|
from typing import Optional, Union
|
|
5
4
|
|
|
6
5
|
from sqlalchemy.engine.url import URL
|
|
@@ -27,24 +26,14 @@ except ImportError:
|
|
|
27
26
|
DefaultAsyncSession: type[AsyncSession] = AsyncSession # type: ignore
|
|
28
27
|
|
|
29
28
|
|
|
30
|
-
@dataclass(slots=True)
|
|
31
|
-
class MultiSessionState:
|
|
32
|
-
"""State for multi_sessions mode."""
|
|
33
|
-
|
|
34
|
-
tracked: set[AsyncSession] = field(default_factory=set)
|
|
35
|
-
task_sessions: dict[int, AsyncSession] = field(default_factory=dict)
|
|
36
|
-
cleanup_tasks: list[asyncio.Task] = field(default_factory=list)
|
|
37
|
-
parent_task_id: int = 0
|
|
38
|
-
commit_on_exit: bool = False
|
|
39
|
-
|
|
40
|
-
|
|
41
29
|
def create_middleware_and_session_proxy() -> tuple:
|
|
42
30
|
_Session: Optional[async_sessionmaker] = None
|
|
43
31
|
_session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)
|
|
44
|
-
|
|
32
|
+
_multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False)
|
|
33
|
+
_commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False)
|
|
45
34
|
|
|
46
35
|
class _SQLAlchemyMiddleware(BaseHTTPMiddleware):
|
|
47
|
-
__test__ = False
|
|
36
|
+
__test__ = False
|
|
48
37
|
|
|
49
38
|
def __init__(
|
|
50
39
|
self,
|
|
@@ -86,48 +75,24 @@ def create_middleware_and_session_proxy() -> tuple:
|
|
|
86
75
|
if _Session is None:
|
|
87
76
|
raise SessionNotInitialisedError
|
|
88
77
|
|
|
89
|
-
|
|
90
|
-
if
|
|
91
|
-
|
|
92
|
-
if task is None:
|
|
93
|
-
raise RuntimeError("Cannot get current task")
|
|
94
|
-
task_id = id(task)
|
|
95
|
-
|
|
96
|
-
if task_id in state.task_sessions:
|
|
97
|
-
return state.task_sessions[task_id]
|
|
98
|
-
|
|
78
|
+
multi_sessions = _multi_sessions_ctx.get()
|
|
79
|
+
if multi_sessions:
|
|
80
|
+
commit_on_exit = _commit_on_exit_ctx.get()
|
|
99
81
|
session = _Session()
|
|
100
|
-
state.task_sessions[task_id] = session
|
|
101
|
-
state.tracked.add(session)
|
|
102
|
-
|
|
103
|
-
# Add cleanup callback only for child tasks
|
|
104
|
-
if task_id != state.parent_task_id:
|
|
105
|
-
|
|
106
|
-
def cleanup_callback(_task):
|
|
107
|
-
try:
|
|
108
|
-
loop = asyncio.get_running_loop()
|
|
109
|
-
if loop.is_closed():
|
|
110
|
-
return
|
|
111
|
-
except RuntimeError:
|
|
112
|
-
return
|
|
113
|
-
|
|
114
|
-
async def cleanup():
|
|
115
|
-
try:
|
|
116
|
-
if state.commit_on_exit:
|
|
117
|
-
try:
|
|
118
|
-
await session.commit()
|
|
119
|
-
except Exception:
|
|
120
|
-
await session.rollback()
|
|
121
|
-
finally:
|
|
122
|
-
await session.close()
|
|
123
|
-
state.tracked.discard(session)
|
|
124
|
-
state.task_sessions.pop(task_id, None)
|
|
125
|
-
|
|
126
|
-
t = loop.create_task(cleanup())
|
|
127
|
-
state.cleanup_tasks.append(t)
|
|
128
|
-
|
|
129
|
-
task.add_done_callback(cleanup_callback)
|
|
130
82
|
|
|
83
|
+
async def cleanup():
|
|
84
|
+
try:
|
|
85
|
+
if commit_on_exit:
|
|
86
|
+
await session.commit()
|
|
87
|
+
except Exception:
|
|
88
|
+
await session.rollback()
|
|
89
|
+
raise
|
|
90
|
+
finally:
|
|
91
|
+
await session.close()
|
|
92
|
+
|
|
93
|
+
task = asyncio.current_task()
|
|
94
|
+
if task is not None:
|
|
95
|
+
task.add_done_callback(lambda t: asyncio.create_task(cleanup()))
|
|
131
96
|
return session
|
|
132
97
|
else:
|
|
133
98
|
session = _session.get()
|
|
@@ -143,7 +108,8 @@ def create_middleware_and_session_proxy() -> tuple:
|
|
|
143
108
|
multi_sessions: bool = False,
|
|
144
109
|
):
|
|
145
110
|
self.token = None
|
|
146
|
-
self.
|
|
111
|
+
self.multi_sessions_token = None
|
|
112
|
+
self.commit_on_exit_token = None
|
|
147
113
|
self.session_args = session_args or {}
|
|
148
114
|
self.commit_on_exit = commit_on_exit
|
|
149
115
|
self.multi_sessions = multi_sessions
|
|
@@ -153,54 +119,23 @@ def create_middleware_and_session_proxy() -> tuple:
|
|
|
153
119
|
raise SessionNotInitialisedError
|
|
154
120
|
|
|
155
121
|
if self.multi_sessions:
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
commit_on_exit=self.commit_on_exit,
|
|
159
|
-
)
|
|
160
|
-
self.multi_state_token = _multi_state.set(state)
|
|
122
|
+
self.multi_sessions_token = _multi_sessions_ctx.set(True)
|
|
123
|
+
self.commit_on_exit_token = _commit_on_exit_ctx.set(self.commit_on_exit)
|
|
161
124
|
else:
|
|
162
125
|
self.token = _session.set(_Session(**self.session_args))
|
|
163
126
|
return type(self)
|
|
164
127
|
|
|
165
128
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
|
166
129
|
if self.multi_sessions:
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
# Wait for cleanup tasks
|
|
170
|
-
if state.cleanup_tasks:
|
|
171
|
-
await asyncio.sleep(0)
|
|
172
|
-
await asyncio.gather(*state.cleanup_tasks, return_exceptions=True)
|
|
173
|
-
|
|
174
|
-
# Clean up remaining sessions
|
|
175
|
-
for session in list(state.tracked):
|
|
176
|
-
try:
|
|
177
|
-
if exc_type is not None:
|
|
178
|
-
await session.rollback()
|
|
179
|
-
elif self.commit_on_exit:
|
|
180
|
-
try:
|
|
181
|
-
await session.commit()
|
|
182
|
-
except Exception:
|
|
183
|
-
await session.rollback()
|
|
184
|
-
except Exception:
|
|
185
|
-
pass
|
|
186
|
-
finally:
|
|
187
|
-
try:
|
|
188
|
-
await session.close()
|
|
189
|
-
except Exception:
|
|
190
|
-
pass
|
|
191
|
-
|
|
192
|
-
_multi_state.reset(self.multi_state_token)
|
|
130
|
+
_multi_sessions_ctx.reset(self.multi_sessions_token)
|
|
131
|
+
_commit_on_exit_ctx.reset(self.commit_on_exit_token)
|
|
193
132
|
else:
|
|
194
133
|
session = _session.get()
|
|
195
134
|
try:
|
|
196
135
|
if exc_type is not None:
|
|
197
136
|
await session.rollback()
|
|
198
137
|
elif self.commit_on_exit:
|
|
199
|
-
|
|
200
|
-
await session.commit()
|
|
201
|
-
except Exception:
|
|
202
|
-
await session.rollback()
|
|
203
|
-
raise
|
|
138
|
+
await session.commit()
|
|
204
139
|
finally:
|
|
205
140
|
await session.close()
|
|
206
141
|
_session.reset(self.token)
|
|
@@ -13,15 +13,17 @@ fastapi_async_sqlalchemy.egg-info/not-zip-safe
|
|
|
13
13
|
fastapi_async_sqlalchemy.egg-info/requires.txt
|
|
14
14
|
fastapi_async_sqlalchemy.egg-info/top_level.txt
|
|
15
15
|
tests/test_additional_coverage.py
|
|
16
|
+
tests/test_backward_compat_gather.py
|
|
17
|
+
tests/test_concurrent_queries.py
|
|
16
18
|
tests/test_coverage_boost.py
|
|
17
19
|
tests/test_coverage_improvements.py
|
|
18
20
|
tests/test_custom_engine_branch.py
|
|
19
21
|
tests/test_edge_cases_coverage.py
|
|
20
22
|
tests/test_import_fallback_simulation.py
|
|
21
23
|
tests/test_import_fallbacks.py
|
|
24
|
+
tests/test_import_without_sqlmodel.py
|
|
22
25
|
tests/test_maximum_coverage.py
|
|
23
|
-
tests/test_multi_sessions_cleanup.py
|
|
24
|
-
tests/test_multisession_pool.py
|
|
25
26
|
tests/test_session.py
|
|
27
|
+
tests/test_single_session_no_gather.py
|
|
26
28
|
tests/test_sqlmodel.py
|
|
27
29
|
tests/test_type_hints_compatibility.py
|
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
"""Test backward compatibility for asyncio.gather() without multi_sessions flag.
|
|
2
|
+
|
|
3
|
+
This test verifies that after the fix, the old code pattern works without
|
|
4
|
+
requiring multi_sessions=True explicitly.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import asyncio
|
|
8
|
+
|
|
9
|
+
import pytest
|
|
10
|
+
from sqlalchemy import text
|
|
11
|
+
|
|
12
|
+
db_url = "sqlite+aiosqlite://"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@pytest.mark.asyncio
|
|
16
|
+
async def test_gather_works_without_multi_sessions_flag(app, db, SQLAlchemyMiddleware):
|
|
17
|
+
"""
|
|
18
|
+
Verify that asyncio.gather() works in normal mode (without multi_sessions=True).
|
|
19
|
+
|
|
20
|
+
This is the backward compatibility fix - users shouldn't need to change their code.
|
|
21
|
+
"""
|
|
22
|
+
app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)
|
|
23
|
+
|
|
24
|
+
async with db(commit_on_exit=True):
|
|
25
|
+
await db.session.execute(
|
|
26
|
+
text("CREATE TABLE IF NOT EXISTS compat_test (id INTEGER PRIMARY KEY, value TEXT)")
|
|
27
|
+
)
|
|
28
|
+
for i in range(20):
|
|
29
|
+
await db.session.execute(
|
|
30
|
+
text("INSERT INTO compat_test (value) VALUES (:value)"),
|
|
31
|
+
{"value": f"value_{i}"},
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
# OLD CODE PATTERN - should work now without multi_sessions=True
|
|
35
|
+
async with db():
|
|
36
|
+
count_stmt = text("SELECT COUNT(*) FROM compat_test")
|
|
37
|
+
data_stmt = text("SELECT * FROM compat_test LIMIT 5")
|
|
38
|
+
|
|
39
|
+
# This should work! Each parallel query gets its own session
|
|
40
|
+
count_result, data_result = await asyncio.gather(
|
|
41
|
+
db.session.execute(count_stmt),
|
|
42
|
+
db.session.execute(data_stmt),
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
count = count_result.scalar()
|
|
46
|
+
data = data_result.fetchall()
|
|
47
|
+
|
|
48
|
+
assert count == 20
|
|
49
|
+
assert len(data) == 5
|
|
50
|
+
|
|
51
|
+
print("✅ Backward compatibility preserved: asyncio.gather() works without multi_sessions!")
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@pytest.mark.asyncio
|
|
55
|
+
async def test_gather_multiple_queries_parallel(app, db, SQLAlchemyMiddleware):
|
|
56
|
+
"""
|
|
57
|
+
Test that multiple parallel queries work correctly.
|
|
58
|
+
"""
|
|
59
|
+
app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)
|
|
60
|
+
|
|
61
|
+
async with db(commit_on_exit=True):
|
|
62
|
+
await db.session.execute(
|
|
63
|
+
text("CREATE TABLE IF NOT EXISTS parallel_test (id INTEGER PRIMARY KEY, status TEXT)")
|
|
64
|
+
)
|
|
65
|
+
for i in range(100):
|
|
66
|
+
await db.session.execute(
|
|
67
|
+
text("INSERT INTO parallel_test (status) VALUES (:status)"),
|
|
68
|
+
{"status": "active" if i % 3 == 0 else "inactive"},
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
# Multiple parallel queries without multi_sessions=True
|
|
72
|
+
async with db():
|
|
73
|
+
stmt1 = text("SELECT COUNT(*) FROM parallel_test WHERE status = 'active'")
|
|
74
|
+
stmt2 = text("SELECT COUNT(*) FROM parallel_test WHERE status = 'inactive'")
|
|
75
|
+
stmt3 = text("SELECT * FROM parallel_test LIMIT 10")
|
|
76
|
+
|
|
77
|
+
r1, r2, r3 = await asyncio.gather(
|
|
78
|
+
db.session.execute(stmt1),
|
|
79
|
+
db.session.execute(stmt2),
|
|
80
|
+
db.session.execute(stmt3),
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
active_count = r1.scalar()
|
|
84
|
+
inactive_count = r2.scalar()
|
|
85
|
+
data = r3.fetchall()
|
|
86
|
+
|
|
87
|
+
assert active_count == 34 # 100 / 3 rounded up
|
|
88
|
+
assert inactive_count == 66
|
|
89
|
+
assert len(data) == 10
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@pytest.mark.asyncio
|
|
93
|
+
async def test_production_pattern_without_changes(app, db, SQLAlchemyMiddleware):
|
|
94
|
+
"""
|
|
95
|
+
Verify the EXACT production pattern from the error report works.
|
|
96
|
+
|
|
97
|
+
This is the pattern from /app/api/repository/routes.py:186
|
|
98
|
+
"""
|
|
99
|
+
app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)
|
|
100
|
+
|
|
101
|
+
async with db(commit_on_exit=True):
|
|
102
|
+
await db.session.execute(
|
|
103
|
+
text("""
|
|
104
|
+
CREATE TABLE IF NOT EXISTS processes (
|
|
105
|
+
id INTEGER PRIMARY KEY,
|
|
106
|
+
name TEXT NOT NULL,
|
|
107
|
+
status TEXT,
|
|
108
|
+
created_at TEXT
|
|
109
|
+
)
|
|
110
|
+
""")
|
|
111
|
+
)
|
|
112
|
+
for i in range(100):
|
|
113
|
+
await db.session.execute(
|
|
114
|
+
text(
|
|
115
|
+
"""INSERT INTO
|
|
116
|
+
processes (name, status, created_at)
|
|
117
|
+
VALUES (:name, :status, :created_at)
|
|
118
|
+
"""
|
|
119
|
+
),
|
|
120
|
+
{
|
|
121
|
+
"name": f"process_{i}",
|
|
122
|
+
"status": "running" if i % 2 == 0 else "stopped",
|
|
123
|
+
"created_at": "2025-01-01T00:00:00",
|
|
124
|
+
},
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
# EXACT PRODUCTION CODE - should work now!
|
|
128
|
+
async with db():
|
|
129
|
+
count_stmt = text("SELECT COUNT(*) FROM processes WHERE status = :status")
|
|
130
|
+
processes_stmt = text(
|
|
131
|
+
"SELECT * FROM processes WHERE status = :status "
|
|
132
|
+
"ORDER BY created_at DESC LIMIT :limit OFFSET :offset"
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
count_stmt = count_stmt.bindparams(status="running")
|
|
136
|
+
processes_stmt = processes_stmt.bindparams(status="running", limit=10, offset=0)
|
|
137
|
+
|
|
138
|
+
# This is line 186 from production - should work without any changes!
|
|
139
|
+
total_result, processes_result = await asyncio.gather(
|
|
140
|
+
db.session.execute(count_stmt),
|
|
141
|
+
db.session.execute(processes_stmt),
|
|
142
|
+
)
|
|
143
|
+
|
|
144
|
+
total = total_result.scalar()
|
|
145
|
+
processes = processes_result.fetchall()
|
|
146
|
+
|
|
147
|
+
assert total == 50
|
|
148
|
+
assert len(processes) == 10
|
|
149
|
+
|
|
150
|
+
print("✅ Production code pattern works without any changes!")
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@pytest.mark.asyncio
|
|
154
|
+
async def test_commit_on_exit_with_parallel_queries(app, db, SQLAlchemyMiddleware):
|
|
155
|
+
"""
|
|
156
|
+
Verify that commit_on_exit works correctly with parallel queries.
|
|
157
|
+
"""
|
|
158
|
+
app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)
|
|
159
|
+
|
|
160
|
+
# Create table first
|
|
161
|
+
async with db(commit_on_exit=True):
|
|
162
|
+
await db.session.execute(
|
|
163
|
+
text("CREATE TABLE IF NOT EXISTS commit_test (id INTEGER PRIMARY KEY, value TEXT)")
|
|
164
|
+
)
|
|
165
|
+
|
|
166
|
+
# Insert data with parallel queries and commit_on_exit
|
|
167
|
+
async with db(commit_on_exit=True):
|
|
168
|
+
# These should all be committed automatically
|
|
169
|
+
await asyncio.gather(
|
|
170
|
+
db.session.execute(text("INSERT INTO commit_test (value) VALUES ('a')")),
|
|
171
|
+
db.session.execute(text("INSERT INTO commit_test (value) VALUES ('b')")),
|
|
172
|
+
db.session.execute(text("INSERT INTO commit_test (value) VALUES ('c')")),
|
|
173
|
+
)
|
|
174
|
+
|
|
175
|
+
# Verify data was committed
|
|
176
|
+
async with db():
|
|
177
|
+
result = await db.session.execute(text("SELECT COUNT(*) FROM commit_test"))
|
|
178
|
+
count = result.scalar()
|
|
179
|
+
assert count == 3
|
|
180
|
+
|
|
181
|
+
print("✅ commit_on_exit works correctly with parallel queries!")
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
@pytest.mark.asyncio
|
|
185
|
+
async def test_rollback_on_error_with_parallel_queries(app, db, SQLAlchemyMiddleware):
|
|
186
|
+
"""
|
|
187
|
+
Verify that rollback works correctly when error occurs in parallel queries.
|
|
188
|
+
"""
|
|
189
|
+
app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)
|
|
190
|
+
|
|
191
|
+
async with db(commit_on_exit=True):
|
|
192
|
+
await db.session.execute(
|
|
193
|
+
text("CREATE TABLE IF NOT EXISTS rollback_test (id INTEGER PRIMARY KEY, value TEXT)")
|
|
194
|
+
)
|
|
195
|
+
|
|
196
|
+
# Try to insert with error - should rollback all
|
|
197
|
+
try:
|
|
198
|
+
async with db(commit_on_exit=True):
|
|
199
|
+
await db.session.execute(
|
|
200
|
+
text("INSERT INTO rollback_test (value) VALUES ('should_rollback')")
|
|
201
|
+
)
|
|
202
|
+
# Force an error
|
|
203
|
+
raise RuntimeError("Simulated error")
|
|
204
|
+
except RuntimeError:
|
|
205
|
+
pass
|
|
206
|
+
|
|
207
|
+
# Verify data was rolled back
|
|
208
|
+
async with db():
|
|
209
|
+
result = await db.session.execute(text("SELECT COUNT(*) FROM rollback_test"))
|
|
210
|
+
count = result.scalar()
|
|
211
|
+
assert count == 0
|
|
212
|
+
|
|
213
|
+
print("✅ Rollback works correctly on error!")
|