iceaxe 0.7.1__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of iceaxe might be problematic. Click here for more details.
- iceaxe/__init__.py +20 -0
- iceaxe/__tests__/__init__.py +0 -0
- iceaxe/__tests__/benchmarks/__init__.py +0 -0
- iceaxe/__tests__/benchmarks/test_bulk_insert.py +45 -0
- iceaxe/__tests__/benchmarks/test_select.py +114 -0
- iceaxe/__tests__/conf_models.py +133 -0
- iceaxe/__tests__/conftest.py +204 -0
- iceaxe/__tests__/docker_helpers.py +208 -0
- iceaxe/__tests__/helpers.py +268 -0
- iceaxe/__tests__/migrations/__init__.py +0 -0
- iceaxe/__tests__/migrations/conftest.py +36 -0
- iceaxe/__tests__/migrations/test_action_sorter.py +237 -0
- iceaxe/__tests__/migrations/test_generator.py +140 -0
- iceaxe/__tests__/migrations/test_generics.py +91 -0
- iceaxe/__tests__/mountaineer/__init__.py +0 -0
- iceaxe/__tests__/mountaineer/dependencies/__init__.py +0 -0
- iceaxe/__tests__/mountaineer/dependencies/test_core.py +76 -0
- iceaxe/__tests__/schemas/__init__.py +0 -0
- iceaxe/__tests__/schemas/test_actions.py +1264 -0
- iceaxe/__tests__/schemas/test_cli.py +25 -0
- iceaxe/__tests__/schemas/test_db_memory_serializer.py +1525 -0
- iceaxe/__tests__/schemas/test_db_serializer.py +398 -0
- iceaxe/__tests__/schemas/test_db_stubs.py +190 -0
- iceaxe/__tests__/test_alias.py +83 -0
- iceaxe/__tests__/test_base.py +52 -0
- iceaxe/__tests__/test_comparison.py +383 -0
- iceaxe/__tests__/test_field.py +11 -0
- iceaxe/__tests__/test_helpers.py +9 -0
- iceaxe/__tests__/test_modifications.py +151 -0
- iceaxe/__tests__/test_queries.py +605 -0
- iceaxe/__tests__/test_queries_str.py +173 -0
- iceaxe/__tests__/test_session.py +1511 -0
- iceaxe/__tests__/test_text_search.py +287 -0
- iceaxe/alias_values.py +67 -0
- iceaxe/base.py +350 -0
- iceaxe/comparison.py +560 -0
- iceaxe/field.py +250 -0
- iceaxe/functions.py +906 -0
- iceaxe/generics.py +140 -0
- iceaxe/io.py +107 -0
- iceaxe/logging.py +91 -0
- iceaxe/migrations/__init__.py +5 -0
- iceaxe/migrations/action_sorter.py +98 -0
- iceaxe/migrations/cli.py +228 -0
- iceaxe/migrations/client_io.py +62 -0
- iceaxe/migrations/generator.py +404 -0
- iceaxe/migrations/migration.py +86 -0
- iceaxe/migrations/migrator.py +101 -0
- iceaxe/modifications.py +176 -0
- iceaxe/mountaineer/__init__.py +10 -0
- iceaxe/mountaineer/cli.py +74 -0
- iceaxe/mountaineer/config.py +46 -0
- iceaxe/mountaineer/dependencies/__init__.py +6 -0
- iceaxe/mountaineer/dependencies/core.py +67 -0
- iceaxe/postgres.py +133 -0
- iceaxe/py.typed +0 -0
- iceaxe/queries.py +1455 -0
- iceaxe/queries_str.py +294 -0
- iceaxe/schemas/__init__.py +0 -0
- iceaxe/schemas/actions.py +864 -0
- iceaxe/schemas/cli.py +30 -0
- iceaxe/schemas/db_memory_serializer.py +705 -0
- iceaxe/schemas/db_serializer.py +346 -0
- iceaxe/schemas/db_stubs.py +525 -0
- iceaxe/session.py +860 -0
- iceaxe/session_optimized.c +12035 -0
- iceaxe/session_optimized.cpython-313-darwin.so +0 -0
- iceaxe/session_optimized.pyx +212 -0
- iceaxe/sql_types.py +148 -0
- iceaxe/typing.py +73 -0
- iceaxe-0.7.1.dist-info/METADATA +261 -0
- iceaxe-0.7.1.dist-info/RECORD +75 -0
- iceaxe-0.7.1.dist-info/WHEEL +6 -0
- iceaxe-0.7.1.dist-info/licenses/LICENSE +21 -0
- iceaxe-0.7.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1511 @@
|
|
|
1
|
+
from contextlib import asynccontextmanager
|
|
2
|
+
from enum import StrEnum
|
|
3
|
+
from typing import Any, Type
|
|
4
|
+
from unittest.mock import AsyncMock, patch
|
|
5
|
+
|
|
6
|
+
import asyncpg
|
|
7
|
+
import pytest
|
|
8
|
+
from asyncpg.connection import Connection
|
|
9
|
+
|
|
10
|
+
from iceaxe.__tests__.conf_models import (
|
|
11
|
+
ArtifactDemo,
|
|
12
|
+
ComplexDemo,
|
|
13
|
+
DemoModelA,
|
|
14
|
+
DemoModelB,
|
|
15
|
+
JsonDemo,
|
|
16
|
+
UserDemo,
|
|
17
|
+
)
|
|
18
|
+
from iceaxe.base import INTERNAL_TABLE_FIELDS, TableBase
|
|
19
|
+
from iceaxe.field import Field
|
|
20
|
+
from iceaxe.functions import func
|
|
21
|
+
from iceaxe.queries import QueryBuilder
|
|
22
|
+
from iceaxe.schemas.cli import create_all
|
|
23
|
+
from iceaxe.session import (
|
|
24
|
+
PG_MAX_PARAMETERS,
|
|
25
|
+
TYPE_CACHE,
|
|
26
|
+
DBConnection,
|
|
27
|
+
)
|
|
28
|
+
from iceaxe.typing import column
|
|
29
|
+
|
|
30
|
+
#
|
|
31
|
+
# Insert / Update / Delete with ORM objects
|
|
32
|
+
#
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@pytest.mark.asyncio
async def test_db_connection_insert(db_connection: DBConnection):
    """Inserting one ORM object writes a row, backfills its id, and clears dirty state."""
    new_user = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([new_user])

    rows = await db_connection.conn.fetch(
        "SELECT * FROM userdemo WHERE name = $1", "John Doe"
    )
    assert len(rows) == 1
    row = rows[0]
    assert row["id"] == new_user.id
    assert row["name"] == "John Doe"
    assert row["email"] == "john@example.com"
    # Insert should leave the object with no pending modifications.
    assert new_user.get_modified_attributes() == {}
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@pytest.mark.asyncio
async def test_db_connection_update(db_connection: DBConnection):
    """Updating a modified object persists the change and resets dirty tracking."""
    user = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([user])

    user.name = "Jane Doe"
    await db_connection.update([user])

    rows = await db_connection.conn.fetch(
        "SELECT * FROM userdemo WHERE id = $1", user.id
    )
    assert len(rows) == 1
    assert rows[0]["name"] == "Jane Doe"
    # A successful update clears the tracked modifications.
    assert user.get_modified_attributes() == {}
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@pytest.mark.asyncio
async def test_db_obj_mixin_track_modifications():
    """Attribute assignments are tracked per-field until explicitly cleared."""
    user = UserDemo(name="John Doe", email="john@example.com")
    assert user.get_modified_attributes() == {}

    user.name = "Jane Doe"
    assert user.get_modified_attributes() == {"name": "Jane Doe"}

    user.email = "jane@example.com"
    expected_changes = {
        "name": "Jane Doe",
        "email": "jane@example.com",
    }
    assert user.get_modified_attributes() == expected_changes

    # Clearing wipes the slate without touching the attribute values.
    user.clear_modified_attributes()
    assert user.get_modified_attributes() == {}
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@pytest.mark.asyncio
async def test_db_connection_delete_query(db_connection: DBConnection):
    """A DELETE built via QueryBuilder removes only the rows matching its WHERE."""
    seeded = [
        UserDemo(name="John Doe", email="john@example.com"),
        UserDemo(name="Jane Doe", email="jane@example.com"),
    ]
    await db_connection.insert(seeded)

    delete_query = QueryBuilder().delete(UserDemo).where(UserDemo.name == "John Doe")
    await db_connection.exec(delete_query)

    # Only the non-matching row should survive.
    remaining = await db_connection.conn.fetch("SELECT * FROM userdemo")
    assert len(remaining) == 1
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@pytest.mark.asyncio
async def test_db_connection_insert_multiple(db_connection: DBConnection):
    """Bulk insert writes rows in order, backfills ids, and leaves objects clean."""
    new_users = [
        UserDemo(name="John Doe", email="john@example.com"),
        UserDemo(name="Jane Doe", email="jane@example.com"),
    ]

    await db_connection.insert(new_users)

    rows = await db_connection.conn.fetch("SELECT * FROM userdemo ORDER BY id")
    assert len(rows) == 2
    assert rows[0]["name"] == "John Doe"
    assert rows[1]["name"] == "Jane Doe"
    # Ids are assigned back onto the in-memory objects in insertion order.
    for obj, row in zip(new_users, rows):
        assert obj.id == row["id"]
        assert obj.get_modified_attributes() == {}
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@pytest.mark.asyncio
async def test_db_connection_update_multiple(db_connection: DBConnection):
    """Bulk update applies each object's own modified fields and clears dirty state."""
    users = [
        UserDemo(name="John Doe", email="john@example.com"),
        UserDemo(name="Jane Doe", email="jane@example.com"),
    ]
    await db_connection.insert(users)

    # Each object carries a different pending change.
    users[0].name = "Johnny Doe"
    users[1].email = "janey@example.com"

    await db_connection.update(users)

    rows = await db_connection.conn.fetch("SELECT * FROM userdemo ORDER BY id")
    assert len(rows) == 2
    assert rows[0]["name"] == "Johnny Doe"
    assert rows[1]["email"] == "janey@example.com"
    assert all(u.get_modified_attributes() == {} for u in users)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
@pytest.mark.asyncio
async def test_db_connection_insert_empty_list(db_connection: DBConnection):
    """Inserting an empty list is a no-op: no rows appear and nothing raises."""
    await db_connection.insert([])

    rows = await db_connection.conn.fetch("SELECT * FROM userdemo")
    assert len(rows) == 0
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@pytest.mark.asyncio
async def test_db_connection_update_empty_list(db_connection: DBConnection):
    """An empty update must be a no-op: it neither raises nor touches existing rows.

    The original version performed the empty update but asserted nothing; a
    seeded row is now verified to be unchanged afterwards.
    """
    user = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([user])

    await db_connection.update([])

    # The pre-existing row is untouched by the no-op update.
    result = await db_connection.conn.fetch("SELECT * FROM userdemo")
    assert len(result) == 1
    assert result[0]["name"] == "John Doe"
    assert result[0]["email"] == "john@example.com"
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@pytest.mark.asyncio
async def test_db_connection_update_no_modifications(db_connection: DBConnection):
    """Updating an object with no pending changes leaves its row unchanged."""
    user = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([user])

    # No attribute was touched between insert and update.
    await db_connection.update([user])

    rows = await db_connection.conn.fetch(
        "SELECT * FROM userdemo WHERE id = $1", user.id
    )
    assert len(rows) == 1
    assert rows[0]["name"] == "John Doe"
    assert rows[0]["email"] == "john@example.com"
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@pytest.mark.asyncio
async def test_delete_object(db_connection: DBConnection):
    """Deleting an ORM object removes its backing row from the table."""
    user = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([user])

    fetch_by_id = "SELECT * FROM userdemo WHERE id = $1"

    # Sanity check: the row exists before deletion.
    assert len(await db_connection.conn.fetch(fetch_by_id, user.id)) == 1

    await db_connection.delete([user])

    assert len(await db_connection.conn.fetch(fetch_by_id, user.id)) == 0
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
#
|
|
185
|
+
# Select into ORM objects
|
|
186
|
+
#
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
@pytest.mark.asyncio
async def test_select(db_connection: DBConnection):
    """Select a whole table, a single column, multiple columns, and a mix of both."""
    user = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([user])

    expected_user = UserDemo(id=user.id, name="John Doe", email="john@example.com")

    # Whole-table selection hydrates model instances.
    full_rows = await db_connection.exec(QueryBuilder().select(UserDemo))
    assert full_rows == [expected_user]

    # A single column comes back as a flat list of values.
    emails = await db_connection.exec(QueryBuilder().select(UserDemo.email))
    assert emails == ["john@example.com"]

    # Multiple columns come back as tuples.
    pairs = await db_connection.exec(
        QueryBuilder().select((UserDemo.name, UserDemo.email))
    )
    assert pairs == [("John Doe", "john@example.com")]

    # A table mixed with a column yields (instance, value) tuples.
    mixed = await db_connection.exec(
        QueryBuilder().select((UserDemo, UserDemo.email))
    )
    assert mixed == [(expected_user, "john@example.com")]
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
@pytest.mark.asyncio
async def test_is_null(db_connection: DBConnection):
    """`== None` / `!= None` comparisons behave as IS NULL / IS NOT NULL filters."""
    user = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([user])

    # IS NULL matches nothing: every inserted row has an id.
    null_matches = await db_connection.exec(
        QueryBuilder()
        .select(UserDemo)
        .where(
            UserDemo.id == None,  # noqa: E711
        )
    )
    assert null_matches == []

    # IS NOT NULL matches the inserted row.
    not_null_matches = await db_connection.exec(
        QueryBuilder()
        .select(UserDemo)
        .where(
            UserDemo.id != None,  # noqa: E711
        )
    )
    assert not_null_matches == [
        UserDemo(id=user.id, name="John Doe", email="john@example.com")
    ]
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
@pytest.mark.asyncio
async def test_select_complex(db_connection: DBConnection):
    """
    Round-trip a model containing complex column types (array and JSON).
    """
    seeded = ComplexDemo(id=1, string_list=["a", "b", "c"], json_data={"a": "a"})
    await db_connection.insert([seeded])

    fetched = await db_connection.exec(QueryBuilder().select(ComplexDemo))
    assert fetched == [
        ComplexDemo(id=1, string_list=["a", "b", "c"], json_data={"a": "a"})
    ]
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
@pytest.mark.asyncio
async def test_select_where(db_connection: DBConnection):
    """A simple equality WHERE clause filters the selection to matching rows."""
    user = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([user])

    matches = await db_connection.exec(
        QueryBuilder().select(UserDemo).where(UserDemo.name == "John Doe")
    )
    assert matches == [
        UserDemo(id=user.id, name="John Doe", email="john@example.com"),
    ]
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
@pytest.mark.asyncio
async def test_select_join(db_connection: DBConnection):
    """Joining a second table and selecting a (model, column) pair."""
    user = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([user])
    assert user.id is not None

    artifact = ArtifactDemo(title="Artifact 1", user_id=user.id)
    await db_connection.insert([artifact])

    joined = await db_connection.exec(
        QueryBuilder()
        .select((ArtifactDemo, UserDemo.email))
        .join(UserDemo, UserDemo.id == ArtifactDemo.user_id)
        .where(UserDemo.name == "John Doe")
    )
    assert joined == [
        (
            ArtifactDemo(id=artifact.id, title="Artifact 1", user_id=user.id),
            "john@example.com",
        )
    ]
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
@pytest.mark.asyncio
async def test_select_join_multiple_tables(db_connection: DBConnection):
    """Selecting two full models through a join returns paired instances."""
    user = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([user])
    assert user.id is not None

    artifact = ArtifactDemo(title="Artifact 1", user_id=user.id)
    await db_connection.insert([artifact])

    joined = await db_connection.exec(
        QueryBuilder()
        .select((ArtifactDemo, UserDemo))
        .join(UserDemo, UserDemo.id == ArtifactDemo.user_id)
        .where(UserDemo.name == "John Doe")
    )
    assert joined == [
        (
            ArtifactDemo(id=artifact.id, title="Artifact 1", user_id=user.id),
            UserDemo(id=user.id, name="John Doe", email="john@example.com"),
        )
    ]
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
@pytest.mark.asyncio
async def test_select_with_limit_and_offset(db_connection: DBConnection):
    """LIMIT and OFFSET page through rows in id order."""
    seeded = [
        UserDemo(name=f"User {n}", email=f"user{n}@example.com")
        for n in range(1, 6)
    ]
    await db_connection.insert(seeded)

    # Skip the first row, then take the next two.
    page = await db_connection.exec(
        QueryBuilder().select(UserDemo).order_by(UserDemo.id, "ASC").limit(2).offset(1)
    )
    assert len(page) == 2
    assert page[0].name == "User 2"
    assert page[1].name == "User 3"
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
@pytest.mark.asyncio
async def test_select_with_multiple_where_conditions(db_connection: DBConnection):
    """Multiple conditions passed to .where() are combined with AND."""
    seeded = [
        UserDemo(name="John Doe", email="john@example.com"),
        UserDemo(name="Jane Doe", email="jane@example.com"),
        UserDemo(name="Bob Smith", email="bob@example.com"),
    ]
    await db_connection.insert(seeded)

    # LIKE '%Doe%' matches John and Jane; the email filter excludes John.
    matches = await db_connection.exec(
        QueryBuilder()
        .select(UserDemo)
        .where(
            column(UserDemo.name).like("%Doe%"), UserDemo.email != "john@example.com"
        )
    )
    assert len(matches) == 1
    assert matches[0].name == "Jane Doe"
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
@pytest.mark.asyncio
async def test_select_with_list_filter(db_connection: DBConnection):
    """IN and NOT IN filters against literal value lists."""
    user = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([user])

    expected = [UserDemo(id=user.id, name="John Doe", email="john@example.com")]

    # IN with a matching value returns the row.
    in_matches = await db_connection.exec(
        QueryBuilder()
        .select(UserDemo)
        .where(
            column(UserDemo.name).in_(["John Doe"]),
        )
    )
    assert in_matches == expected

    # NOT IN with a non-matching value also returns the row.
    not_in_matches = await db_connection.exec(
        QueryBuilder()
        .select(UserDemo)
        .where(
            column(UserDemo.name).not_in(["John A"]),
        )
    )
    assert not_in_matches == expected
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
@pytest.mark.asyncio
async def test_select_with_order_by_multiple_columns(db_connection: DBConnection):
    """Chained order_by calls sort by name first, then email as a tiebreaker."""
    seeded = [
        UserDemo(name="Alice", email="alice@example.com"),
        UserDemo(name="Bob", email="bob@example.com"),
        UserDemo(name="Charlie", email="charlie@example.com"),
        UserDemo(name="Alice", email="alice2@example.com"),
    ]
    await db_connection.insert(seeded)

    ordered = await db_connection.exec(
        QueryBuilder()
        .select(UserDemo)
        .order_by(UserDemo.name, "ASC")
        .order_by(UserDemo.email, "ASC")
    )
    assert len(ordered) == 4
    # The two Alices are disambiguated by email ('alice2' sorts before 'alice@').
    assert (ordered[0].name, ordered[0].email) == ("Alice", "alice2@example.com")
    assert (ordered[1].name, ordered[1].email) == ("Alice", "alice@example.com")
    assert ordered[2].name == "Bob"
    assert ordered[3].name == "Charlie"
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
@pytest.mark.asyncio
async def test_select_with_group_by_and_having(db_connection: DBConnection):
    """GROUP BY with HAVING keeps only names that appear more than once."""
    seeded = [
        UserDemo(name="John", email="john@example.com"),
        UserDemo(name="Jane", email="jane@example.com"),
        UserDemo(name="John", email="john2@example.com"),
        UserDemo(name="Bob", email="bob@example.com"),
    ]
    await db_connection.insert(seeded)

    grouped = await db_connection.exec(
        QueryBuilder()
        .select((UserDemo.name, func.count(UserDemo.id)))
        .group_by(UserDemo.name)
        .having(func.count(UserDemo.id) > 1)
    )
    # Only "John" occurs twice; Jane and Bob are filtered out by HAVING.
    assert grouped == [("John", 2)]
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
@pytest.mark.asyncio
async def test_select_with_left_join(db_connection: DBConnection):
    """A LEFT join keeps users with zero artifacts, reporting a count of 0."""
    users = [
        UserDemo(name="John", email="john@example.com"),
        UserDemo(name="Jane", email="jane@example.com"),
    ]
    await db_connection.insert(users)

    # Only John has artifacts; Jane must still appear via the LEFT join.
    artifacts = [
        ArtifactDemo(title="John's Post", user_id=users[0].id),
        ArtifactDemo(title="Another Post", user_id=users[0].id),
    ]
    await db_connection.insert(artifacts)

    counts = await db_connection.exec(
        QueryBuilder()
        .select((UserDemo.name, func.count(ArtifactDemo.id)))
        .join(ArtifactDemo, UserDemo.id == ArtifactDemo.user_id, "LEFT")
        .group_by(UserDemo.name)
        .order_by(UserDemo.name, "ASC")
    )
    assert counts == [("Jane", 0), ("John", 2)]
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
@pytest.mark.asyncio
async def test_select_with_left_join_object(db_connection: DBConnection):
    """A LEFT join selecting whole models yields None for the missing right side."""
    users = [
        UserDemo(name="John", email="john@example.com"),
        UserDemo(name="Jane", email="jane@example.com"),
    ]
    await db_connection.insert(users)

    posts = [
        ArtifactDemo(title="John's Post", user_id=users[0].id),
        ArtifactDemo(title="Another Post", user_id=users[0].id),
    ]
    await db_connection.insert(posts)

    pairs = await db_connection.exec(
        QueryBuilder()
        .select((UserDemo, ArtifactDemo))
        .join(ArtifactDemo, UserDemo.id == ArtifactDemo.user_id, "LEFT")
    )
    # Jane has no posts, so her artifact slot is None.
    assert pairs == [
        (users[0], posts[0]),
        (users[0], posts[1]),
        (users[1], None),
    ]
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
# @pytest.mark.asyncio
|
|
486
|
+
# async def test_select_with_subquery(db_connection: DBConnection):
|
|
487
|
+
# users = [
|
|
488
|
+
# UserDemo(name="John", email="john@example.com"),
|
|
489
|
+
# UserDemo(name="Jane", email="jane@example.com"),
|
|
490
|
+
# UserDemo(name="Bob", email="bob@example.com"),
|
|
491
|
+
# ]
|
|
492
|
+
# await db_connection.insert(users)
|
|
493
|
+
|
|
494
|
+
# posts = [
|
|
495
|
+
# ArtifactDemo(title="John's Post", content="Hello", user_id=users[0].id),
|
|
496
|
+
# ArtifactDemo(title="Jane's Post", content="World", user_id=users[1].id),
|
|
497
|
+
# ArtifactDemo(title="John's Second Post", content="!", user_id=users[0].id),
|
|
498
|
+
# ]
|
|
499
|
+
# await db_connection.insert(posts)
|
|
500
|
+
|
|
501
|
+
# subquery = QueryBuilder().select(ArtifactDemo.user_id).where(func.count(ArtifactDemo.id) > 1).group_by(PostDemo.user_id)
|
|
502
|
+
# query = QueryBuilder().select(UserDemo).where(is_column(UserDemo.id).in_(subquery))
|
|
503
|
+
# result = await db_connection.exec(query)
|
|
504
|
+
# assert len(result) == 1
|
|
505
|
+
# assert result[0].name == "John"
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
@pytest.mark.asyncio
async def test_select_with_distinct(db_connection: DBConnection):
    """DISTINCT collapses repeated name values into one row each."""
    seeded = [
        UserDemo(name="John", email="john@example.com"),
        UserDemo(name="Jane", email="jane@example.com"),
        UserDemo(name="John", email="john2@example.com"),
    ]
    await db_connection.insert(seeded)

    distinct_names = await db_connection.exec(
        QueryBuilder()
        .select(func.distinct(UserDemo.name))
        .order_by(UserDemo.name, "ASC")
    )
    assert distinct_names == ["Jane", "John"]
|
|
524
|
+
|
|
525
|
+
|
|
526
|
+
@pytest.mark.asyncio
async def test_refresh(db_connection: DBConnection):
    """refresh() re-reads attributes that another writer changed in the database."""
    user = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([user])

    # Simulate a concurrent process mutating the row out from under us.
    await db_connection.conn.execute(
        "UPDATE userdemo SET name = 'Jane Doe' WHERE id = $1", user.id
    )

    # The in-memory object is stale until refreshed.
    assert user.name == "John Doe"

    # Refreshing pulls the new attributes from the database.
    await db_connection.refresh([user])
    assert user.name == "Jane Doe"
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
@pytest.mark.asyncio
async def test_get(db_connection: DBConnection):
    """
    get() fetches a single record by primary key, or returns None when absent.
    """
    user = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([user])
    assert user.id is not None

    # Existing primary key: the full record comes back.
    found = await db_connection.get(UserDemo, user.id)
    assert found is not None
    assert (found.id, found.name, found.email) == (
        user.id,
        "John Doe",
        "john@example.com",
    )

    # Missing primary key: None instead of an exception.
    assert await db_connection.get(UserDemo, 9999) is None
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
@pytest.mark.asyncio
async def test_db_connection_insert_update_enum(db_connection: DBConnection):
    """
    Test that enum values are cast correctly for both inserts and updates.

    Defines a throwaway StrEnum-backed table inline so the test controls the
    full lifecycle of the postgres enum type it creates.
    """

    class EnumValue(StrEnum):
        A = "a"
        B = "b"

    class EnumDemo(TableBase):
        id: int | None = Field(default=None, primary_key=True)
        value: EnumValue

    # Clear out previous tables. DDL order matters here: the table depends on
    # the postgres enum type, so the table must be dropped before the type.
    await db_connection.conn.execute("DROP TABLE IF EXISTS enumdemo")
    await db_connection.conn.execute("DROP TYPE IF EXISTS enumvalue")
    await create_all(db_connection, [EnumDemo])

    # Insert: the Python enum member should land as its string value.
    userdemo = EnumDemo(value=EnumValue.A)
    await db_connection.insert([userdemo])

    result = await db_connection.conn.fetch("SELECT * FROM enumdemo")
    assert len(result) == 1
    assert result[0]["value"] == "a"

    # Update: changing the enum attribute should cast the same way.
    userdemo.value = EnumValue.B
    await db_connection.update([userdemo])

    result = await db_connection.conn.fetch("SELECT * FROM enumdemo")
    assert len(result) == 1
    assert result[0]["value"] == "b"
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
#
|
|
604
|
+
# Upsert
|
|
605
|
+
#
|
|
606
|
+
|
|
607
|
+
|
|
608
|
+
@pytest.mark.asyncio
async def test_upsert_basic_insert(db_connection: DBConnection):
    """
    With no conflicting row present, upsert behaves like a plain insert.
    """
    # The conflict target requires a unique constraint on email.
    await db_connection.conn.execute(
        """
        ALTER TABLE userdemo
        ADD CONSTRAINT email_unique UNIQUE (email)
        """
    )

    user = UserDemo(name="John Doe", email="john@example.com")
    returned = await db_connection.upsert(
        [user],
        conflict_fields=(UserDemo.email,),
        returning_fields=(UserDemo.id, UserDemo.name, UserDemo.email),
    )

    assert returned is not None
    assert len(returned) == 1
    assert returned[0][1] == "John Doe"
    assert returned[0][2] == "john@example.com"

    # Verify the row actually landed in the database.
    rows = await db_connection.conn.fetch("SELECT * FROM userdemo")
    assert len(rows) == 1
    assert rows[0][1] == "John Doe"
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
@pytest.mark.asyncio
async def test_upsert_update_on_conflict(db_connection: DBConnection):
    """
    On a unique-key conflict, the listed update_fields are overwritten in place.
    """
    await db_connection.conn.execute(
        """
        ALTER TABLE userdemo
        ADD CONSTRAINT email_unique UNIQUE (email)
        """
    )

    # Seed the row that will trigger the conflict.
    original = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([original])

    # Same email, different name: should update rather than insert.
    conflicting = UserDemo(name="Johnny Doe", email="john@example.com")
    returned = await db_connection.upsert(
        [conflicting],
        conflict_fields=(UserDemo.email,),
        update_fields=(UserDemo.name,),
        returning_fields=(UserDemo.id, UserDemo.name, UserDemo.email),
    )

    assert returned is not None
    assert len(returned) == 1
    assert returned[0][1] == "Johnny Doe"

    # Still exactly one record, now carrying the new name.
    rows = await db_connection.conn.fetch("SELECT * FROM userdemo")
    assert len(rows) == 1
    assert rows[0]["name"] == "Johnny Doe"
|
|
673
|
+
|
|
674
|
+
|
|
675
|
+
@pytest.mark.asyncio
async def test_upsert_do_nothing_on_conflict(db_connection: DBConnection):
    """
    Without update_fields, a conflicting upsert falls back to DO NOTHING.
    """
    await db_connection.conn.execute(
        """
        ALTER TABLE userdemo
        ADD CONSTRAINT email_unique UNIQUE (email)
        """
    )

    existing = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([existing])

    # Same email, different name — but no update_fields are given.
    conflicting = UserDemo(name="Johnny Doe", email="john@example.com")
    returned = await db_connection.upsert(
        [conflicting],
        conflict_fields=(UserDemo.email,),
        returning_fields=(UserDemo.id, UserDemo.name, UserDemo.email),
    )

    # Nothing was inserted or updated, so nothing is returned.
    assert returned == []

    # The original record is untouched.
    rows = await db_connection.conn.fetch("SELECT * FROM userdemo")
    assert len(rows) == 1
    assert rows[0][1] == "John Doe"
|
|
707
|
+
|
|
708
|
+
|
|
709
|
+
@pytest.mark.asyncio
async def test_upsert_multiple_objects(db_connection: DBConnection):
    """
    Several objects can be upserted in a single call.
    """
    await db_connection.conn.execute(
        """
        ALTER TABLE userdemo
        ADD CONSTRAINT email_unique UNIQUE (email)
        """
    )

    batch = [
        UserDemo(name="John Doe", email="john@example.com"),
        UserDemo(name="Jane Doe", email="jane@example.com"),
    ]
    returned = await db_connection.upsert(
        batch,
        conflict_fields=(UserDemo.email,),
        returning_fields=(UserDemo.name, UserDemo.email),
    )

    assert returned is not None
    assert len(returned) == 2
    assert {row[1] for row in returned} == {"john@example.com", "jane@example.com"}
|
|
735
|
+
|
|
736
|
+
|
|
737
|
+
@pytest.mark.asyncio
async def test_upsert_empty_list(db_connection: DBConnection):
    """
    Test upserting an empty list: it should be a no-op that returns None.

    NOTE: the docstring was previously a stray bare-string expression in the
    middle of the function body; it is now in docstring position.
    """
    await db_connection.conn.execute(
        """
        ALTER TABLE userdemo
        ADD CONSTRAINT email_unique UNIQUE (email)
        """
    )

    result = await db_connection.upsert(
        [], conflict_fields=(UserDemo.email,), returning_fields=(UserDemo.id,)
    )
    assert result is None
|
|
751
|
+
|
|
752
|
+
|
|
753
|
+
@pytest.mark.asyncio
async def test_upsert_multiple_conflict_fields(db_connection: DBConnection):
    """
    Test upserting with multiple conflict fields: rows that collide on the
    composite (name, email) key collapse into a single stored row.
    """
    # Constraint renamed from the misleading "email_unique" — it actually
    # covers the composite (name, email) pair. The name is test-local, so the
    # rename has no effect on behavior.
    await db_connection.conn.execute(
        """
        ALTER TABLE userdemo
        ADD CONSTRAINT name_email_unique UNIQUE (name, email)
        """
    )

    users = [
        UserDemo(name="John Doe", email="john@example.com"),
        UserDemo(name="John Doe", email="john@example.com"),
        UserDemo(name="Jane Doe", email="jane@example.com"),
    ]
    result = await db_connection.upsert(
        users,
        conflict_fields=(UserDemo.name, UserDemo.email),
        returning_fields=(UserDemo.name, UserDemo.email),
    )

    assert result is not None
    # Three inputs, but the two identical John Doe rows collapse to one.
    assert len(result) == 2
    assert {r[1] for r in result} == {"john@example.com", "jane@example.com"}
|
|
780
|
+
|
|
781
|
+
|
|
782
|
+
@pytest.mark.asyncio
async def test_for_update_prevents_concurrent_modification(
    db_connection: DBConnection, docker_postgres
):
    """
    Test that FOR UPDATE actually locks the row for concurrent modifications.
    """
    # Create initial user
    user = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([user])

    # The row lock only exists for the lifetime of this transaction.
    async with db_connection.transaction():
        # Lock the row with FOR UPDATE
        [locked_user] = await db_connection.exec(
            QueryBuilder().select(UserDemo).where(UserDemo.id == user.id).for_update()
        )
        assert locked_user.name == "John Doe"

        # Try to update from another connection - this should block
        # until our transaction is done.
        # A second, independent asyncpg connection is required: the lock is
        # held per-transaction, so reusing db_connection would not contend.
        other_conn = DBConnection(
            await asyncpg.connect(
                host=docker_postgres["host"],
                port=docker_postgres["port"],
                user=docker_postgres["user"],
                password=docker_postgres["password"],
                database=docker_postgres["database"],
            )
        )
        try:
            with pytest.raises(asyncpg.exceptions.LockNotAvailableError):
                # This should raise an error since we're using NOWAIT
                # (NOWAIT turns "block until released" into an immediate
                # LockNotAvailableError, keeping the test fast).
                await other_conn.exec(
                    QueryBuilder()
                    .select(UserDemo)
                    .where(UserDemo.id == user.id)
                    .for_update(nowait=True)
                )
        finally:
            # Always release the second connection, even on assertion failure.
            await other_conn.conn.close()
|
|
822
|
+
|
|
823
|
+
|
|
824
|
+
@pytest.mark.asyncio
async def test_for_update_skip_locked(db_connection: DBConnection, docker_postgres):
    """
    Test that SKIP LOCKED works as expected.
    """
    # Create test users
    users = [
        UserDemo(name="User 1", email="user1@example.com"),
        UserDemo(name="User 2", email="user2@example.com"),
    ]
    await db_connection.insert(users)

    # The lock on User 1 is held for the duration of this transaction.
    async with db_connection.transaction():
        # Lock the first user
        [locked_user] = await db_connection.exec(
            QueryBuilder()
            .select(UserDemo)
            .where(UserDemo.id == users[0].id)
            .for_update()
        )
        assert locked_user.name == "User 1"

        # From another connection, try to select both users with SKIP LOCKED.
        # A separate connection is needed because locks don't contend within
        # the same transaction.
        other_conn = DBConnection(
            await asyncpg.connect(
                host=docker_postgres["host"],
                port=docker_postgres["port"],
                user=docker_postgres["user"],
                password=docker_postgres["password"],
                database=docker_postgres["database"],
            )
        )
        try:
            # This should only return User 2 since User 1 is locked
            # (SKIP LOCKED silently drops locked rows instead of blocking).
            result = await other_conn.exec(
                QueryBuilder()
                .select(UserDemo)
                .order_by(UserDemo.id, "ASC")
                .for_update(skip_locked=True)
            )
            assert len(result) == 1
            assert result[0].name == "User 2"
        finally:
            # Always release the second connection, even on assertion failure.
            await other_conn.conn.close()
|
|
868
|
+
|
|
869
|
+
|
|
870
|
+
@pytest.mark.asyncio
async def test_for_update_of_with_join(db_connection: DBConnection, docker_postgres):
    """
    Test FOR UPDATE OF with JOINed tables.

    Only the table named in ``of=`` should be row-locked; the joined table
    must remain lockable from other connections.
    """
    # Create test data: one user owning one artifact.
    user = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([user])

    artifact = ArtifactDemo(title="Test Artifact", user_id=user.id)
    await db_connection.insert([artifact])

    async with db_connection.transaction():
        # Lock only the artifacts table in a join query
        [(selected_artifact, selected_user)] = await db_connection.exec(
            QueryBuilder()
            .select((ArtifactDemo, UserDemo))
            .join(UserDemo, UserDemo.id == ArtifactDemo.user_id)
            .for_update(of=(ArtifactDemo,))
        )
        assert selected_artifact.title == "Test Artifact"
        assert selected_user.name == "John Doe"

        # In another connection, we should be able to lock the user
        # but not the artifact
        other_conn = DBConnection(
            await asyncpg.connect(
                host=docker_postgres["host"],
                port=docker_postgres["port"],
                user=docker_postgres["user"],
                password=docker_postgres["password"],
                database=docker_postgres["database"],
            )
        )
        try:
            # Should succeed since user table isn't locked
            [other_user] = await other_conn.exec(
                QueryBuilder()
                .select(UserDemo)
                .where(UserDemo.id == user.id)
                .for_update(nowait=True)
            )
            assert other_user.name == "John Doe"

            # Should fail since artifact table is locked. The redundant
            # pytest.fail() that used to follow the exec() call was removed:
            # pytest.raises already fails the test if no exception is raised.
            with pytest.raises(asyncpg.exceptions.LockNotAvailableError):
                await other_conn.exec(
                    QueryBuilder()
                    .select(ArtifactDemo)
                    .where(ArtifactDemo.id == artifact.id)
                    .for_update(nowait=True)
                )
        finally:
            # Always release the second connection, even on assertion failure.
            await other_conn.conn.close()
|
|
925
|
+
|
|
926
|
+
|
|
927
|
+
@pytest.mark.asyncio
async def test_select_same_column_name_from_different_tables(
    db_connection: DBConnection,
):
    """
    Columns sharing a name ("name" exists on both models) must be returned in
    selection order, in both directions of the join, with no cross-talk.
    """
    # Recreate both tables from scratch for isolation.
    await db_connection.conn.execute("DROP TABLE IF EXISTS demomodela")
    await db_connection.conn.execute("DROP TABLE IF EXISTS demomodelb")
    await create_all(db_connection, [DemoModelA, DemoModelB])

    row_a = DemoModelA(name="Name from A", description="Description A", code="ABC123")
    row_b = DemoModelB(
        name="Name from B",
        category="Category B",
        code="ABC123",  # Same code to join on
    )
    await db_connection.insert([row_a, row_b])

    # A-first selection: the result tuple must follow selection order.
    forward_rows = await db_connection.exec(
        QueryBuilder()
        .select((DemoModelA.name, DemoModelB.name))
        .join(DemoModelB, DemoModelA.code == DemoModelB.code)
    )
    assert len(forward_rows) == 1
    assert forward_rows[0] == ("Name from A", "Name from B")

    # B-first selection: the tuple order must flip accordingly.
    reverse_rows = await db_connection.exec(
        QueryBuilder()
        .select((DemoModelB.name, DemoModelA.name))
        .join(DemoModelA, DemoModelA.code == DemoModelB.code)
    )
    assert len(reverse_rows) == 1
    assert reverse_rows[0] == ("Name from B", "Name from A")
|
|
971
|
+
|
|
972
|
+
|
|
973
|
+
@pytest.mark.asyncio
async def test_select_with_order_by_func_count(db_connection: DBConnection):
    """
    Ordering by func.count() over a LEFT JOIN should sort users by how many
    artifacts they own, with zero-artifact users included via the LEFT join.
    """
    # Three users who will end up with 2, 1, and 0 artifacts respectively.
    authors = [
        UserDemo(name="John", email="john@example.com"),
        UserDemo(name="Jane", email="jane@example.com"),
        UserDemo(name="Bob", email="bob@example.com"),
    ]
    await db_connection.insert(authors)

    posts = [
        ArtifactDemo(title="John's Post 1", user_id=authors[0].id),
        ArtifactDemo(title="John's Post 2", user_id=authors[0].id),
        ArtifactDemo(title="Jane's Post", user_id=authors[1].id),
        # Bob has no posts
    ]
    await db_connection.insert(posts)

    count_query = (
        QueryBuilder()
        .select((UserDemo.name, func.count(ArtifactDemo.id)))
        .join(ArtifactDemo, UserDemo.id == ArtifactDemo.user_id, "LEFT")
        .group_by(UserDemo.name)
        .order_by(func.count(ArtifactDemo.id), "DESC")
    )
    rows = await db_connection.exec(count_query)

    # Descending by artifact count: John (2), Jane (1), Bob (0).
    assert len(rows) == 3
    assert rows[0] == ("John", 2)
    assert rows[1] == ("Jane", 1)
    assert rows[2] == ("Bob", 0)
|
|
1008
|
+
|
|
1009
|
+
|
|
1010
|
+
@pytest.mark.asyncio
async def test_json_update(db_connection: DBConnection):
    """
    JSON columns must round-trip through update(): assigning fresh dicts to
    the fields and calling update() should persist the serialized values.
    """
    # Rebuild the table so the test starts from a clean slate.
    await db_connection.conn.execute("DROP TABLE IF EXISTS jsondemo")
    await create_all(db_connection, [JsonDemo])

    record = JsonDemo(
        settings={"theme": "dark", "notifications": True},
        metadata={"version": 1},
        unique_val="1",
    )
    await db_connection.insert([record])

    # Mutate both JSON fields and push the changes.
    record.settings = {"theme": "light", "notifications": False}
    record.metadata = {"version": 2, "last_updated": "2024-01-01"}
    await db_connection.update([record])

    # Re-read from the database to confirm the serialized values landed.
    fetched = await db_connection.exec(
        QueryBuilder().select(JsonDemo).where(JsonDemo.id == record.id)
    )
    assert len(fetched) == 1
    assert fetched[0].settings == {"theme": "light", "notifications": False}
    assert fetched[0].metadata == {"version": 2, "last_updated": "2024-01-01"}
|
|
1039
|
+
|
|
1040
|
+
|
|
1041
|
+
@pytest.mark.asyncio
async def test_json_upsert(db_connection: DBConnection):
    """
    JSON columns must round-trip through upsert(): both the initial insert
    and the conflict-triggered update should serialize dict payloads.
    """
    # Rebuild the table so the test starts from a clean slate.
    await db_connection.conn.execute("DROP TABLE IF EXISTS jsondemo")
    await create_all(db_connection, [JsonDemo])

    # First upsert acts as a plain insert.
    initial_row = JsonDemo(
        settings={"theme": "dark", "notifications": True},
        metadata={"version": 1},
        unique_val="1",
    )
    returned = await db_connection.upsert(
        [initial_row],
        conflict_fields=(JsonDemo.unique_val,),
        update_fields=(JsonDemo.metadata,),
        returning_fields=(JsonDemo.unique_val, JsonDemo.metadata),
    )

    assert returned is not None
    assert len(returned) == 1
    assert returned[0][0] == "1"
    assert returned[0][1] == {"version": 1}

    # Second upsert collides on unique_val and should only update metadata.
    conflicting_row = JsonDemo(
        settings={"theme": "dark", "notifications": True},
        metadata={"version": 2, "last_updated": "2024-01-01"},  # New metadata
        unique_val="1",  # Same value to trigger update
    )
    returned = await db_connection.upsert(
        [conflicting_row],
        conflict_fields=(JsonDemo.unique_val,),
        update_fields=(JsonDemo.metadata,),
        returning_fields=(JsonDemo.unique_val, JsonDemo.metadata),
    )

    assert returned is not None
    assert len(returned) == 1
    assert returned[0][0] == "1"
    assert returned[0][1] == {"version": 2, "last_updated": "2024-01-01"}

    # A fresh select confirms the stored state: settings untouched, metadata
    # replaced by the conflicting row's payload.
    stored = await db_connection.exec(QueryBuilder().select(JsonDemo))
    assert len(stored) == 1
    assert stored[0].settings == {"theme": "dark", "notifications": True}
    assert stored[0].metadata == {"version": 2, "last_updated": "2024-01-01"}
|
|
1091
|
+
|
|
1092
|
+
|
|
1093
|
+
@pytest.mark.asyncio
async def test_db_connection_update_batched(db_connection: DBConnection):
    """
    update() should batch correctly when many objects carry different
    combinations of modified fields.
    """

    def build_users(start: int, stop: int) -> list[UserDemo]:
        # Contiguous block of users with predictable names/emails.
        return [
            UserDemo(name=f"User{i}", email=f"user{i}@example.com")
            for i in range(start, stop)
        ]

    name_only = build_users(0, 10)
    email_only = build_users(10, 20)
    both_fields = build_users(20, 30)
    everyone = name_only + email_only + both_fields
    await db_connection.insert(everyone)

    # Touch a different field combination in each group so the update layer
    # has to batch by modified-field signature.
    for user in name_only:
        user.name = f"Updated{user.name}"
    for user in email_only:
        user.email = f"updated_{user.email}"
    for user in both_fields:
        user.name = f"Updated{user.name}"
        user.email = f"updated_{user.email}"

    await db_connection.update(everyone)

    rows = await db_connection.conn.fetch("SELECT * FROM userdemo ORDER BY id")
    assert len(rows) == 30

    # Group 1: only names were rewritten.
    for i, row in enumerate(rows[:10]):
        assert row["name"] == f"UpdatedUser{i}"
        assert row["email"] == f"user{i}@example.com"

    # Group 2: only emails were rewritten.
    for i, row in enumerate(rows[10:20]):
        assert row["name"] == f"User{i + 10}"
        assert row["email"] == f"updated_user{i + 10}@example.com"

    # Group 3: both fields were rewritten.
    for i, row in enumerate(rows[20:30]):
        assert row["name"] == f"UpdatedUser{i + 20}"
        assert row["email"] == f"updated_user{i + 20}@example.com"

    # Modification tracking must be cleared after a successful update.
    assert all(user.get_modified_attributes() == {} for user in everyone)
|
|
1143
|
+
|
|
1144
|
+
|
|
1145
|
+
#
|
|
1146
|
+
# Batch query construction
|
|
1147
|
+
#
|
|
1148
|
+
|
|
1149
|
+
|
|
1150
|
+
def assert_expected_user_fields(user: Type[UserDemo]):
    """
    Sanity-check that the given model still has the expected UserDemo layout.

    The batching tests below hand-calculate bound-parameter counts from this
    field layout; if this assertion fails, update those calculations.

    Fix: the ``user`` parameter was previously accepted but ignored (the body
    hard-coded UserDemo); it is now actually inspected. Behavior is identical
    for the existing callers, which all pass UserDemo.

    :param user: the model class to inspect.
    :return: True, so the helper can be used inside an ``assert``.
    """
    assert {
        key for key in user.model_fields.keys() if key not in INTERNAL_TABLE_FIELDS
    } == {"id", "name", "email"}
    assert user.model_fields["id"].primary_key
    assert user.model_fields["id"].default is None
    return True
|
|
1158
|
+
|
|
1159
|
+
|
|
1160
|
+
@asynccontextmanager
async def mock_transaction():
    # No-op async context manager used to stub out ``conn.transaction`` on
    # mocked connections: the real asyncpg method returns an async context
    # manager, which a plain AsyncMock attribute cannot emulate.
    yield
|
|
1163
|
+
|
|
1164
|
+
|
|
1165
|
+
@pytest.mark.asyncio
async def test_batch_insert_exceeds_parameters():
    """
    insert() must split its INSERT statements into multiple batches once the
    total bound-parameter count would exceed PG_MAX_PARAMETERS.
    """
    assert assert_expected_user_fields(UserDemo)

    # Stub out the connection; fetchmany hands back fake primary keys.
    fake_conn = AsyncMock()
    fake_conn.fetchmany = AsyncMock(return_value=[{"id": i} for i in range(1000)])
    fake_conn.executemany = AsyncMock()
    fake_conn.transaction = mock_transaction

    db = DBConnection(fake_conn)

    # UserDemo binds two parameters per row (name, email), so this object
    # count guarantees the limit is crossed.
    total_objects = (PG_MAX_PARAMETERS // 2) + 1
    users = [
        UserDemo(name=f"User {i}", email=f"user{i}@example.com")
        for i in range(total_objects)
    ]

    await db.insert(users)

    # Crossing the limit forces at least two separate fetchmany round trips.
    assert len(fake_conn.fetchmany.mock_calls) >= 2

    # Spot-check the SQL shape of the first batch.
    first_query = fake_conn.fetchmany.mock_calls[0].args[0]
    assert "INSERT INTO" in first_query
    assert '"name"' in first_query
    assert '"email"' in first_query
    assert "RETURNING" in first_query
|
|
1202
|
+
|
|
1203
|
+
|
|
1204
|
+
@pytest.mark.asyncio
async def test_batch_update_exceeds_parameters():
    """
    update() must split its UPDATE statements into multiple executemany
    batches once the bound-parameter count would exceed PG_MAX_PARAMETERS.
    """
    assert assert_expected_user_fields(UserDemo)

    # Stub out the connection.
    fake_conn = AsyncMock()
    fake_conn.executemany = AsyncMock()
    fake_conn.transaction = mock_transaction

    db = DBConnection(fake_conn)

    # Each updated row binds three parameters: the id in the WHERE clause
    # plus name and email in the SET clause.
    total_objects = (PG_MAX_PARAMETERS // 3) + 1
    users: list[UserDemo] = []

    for i in range(total_objects):
        user = UserDemo(id=i, name=f"User {i}", email=f"user{i}@example.com")
        user.clear_modified_attributes()

        # Dirty both columns so every row carries the same field signature.
        user.name = f"New User {i}"
        user.email = f"newuser{i}@example.com"

        users.append(user)

    await db.update(users)

    # Crossing the limit forces at least two executemany round trips.
    assert len(fake_conn.executemany.mock_calls) >= 2

    # Spot-check the SQL shape of the first batch.
    first_query = fake_conn.executemany.mock_calls[0].args[0]
    assert "UPDATE" in first_query
    assert "SET" in first_query
    assert "WHERE" in first_query
    assert '"id"' in first_query
|
|
1250
|
+
|
|
1251
|
+
|
|
1252
|
+
@pytest.mark.asyncio
async def test_batch_upsert_exceeds_parameters():
    """
    upsert() must split its INSERT ... ON CONFLICT statements into multiple
    batches once the bound-parameter count would exceed PG_MAX_PARAMETERS.
    """
    assert assert_expected_user_fields(UserDemo)

    # Two bound parameters per object (name, email), so this object count
    # guarantees the limit is crossed.
    total_objects = (PG_MAX_PARAMETERS // 2) + 1
    users = [
        UserDemo(name=f"User {i}", email=f"user{i}@example.com")
        for i in range(total_objects)
    ]

    # The mock echoes back one row per submitted value tuple so the combined
    # result count can be validated across batches.
    fake_conn = AsyncMock()
    fake_conn.fetchmany = AsyncMock(
        side_effect=lambda query, values_list: [
            {"id": i, "name": f"User {i}", "email": f"user{i}@example.com"}
            for i in range(len(values_list))
        ]
    )
    fake_conn.executemany = AsyncMock()
    fake_conn.transaction = mock_transaction

    db = DBConnection(fake_conn)

    # Exercise every upsert kwarg at once.
    result = await db.upsert(
        users,
        conflict_fields=(UserDemo.email,),
        update_fields=(UserDemo.name,),
        returning_fields=(UserDemo.id, UserDemo.name, UserDemo.email),
    )

    # Crossing the limit forces at least two fetchmany round trips.
    assert len(fake_conn.fetchmany.mock_calls) >= 2

    # Spot-check the SQL shape of the first batch.
    first_query = fake_conn.fetchmany.mock_calls[0].args[0]
    assert "INSERT INTO" in first_query
    assert "ON CONFLICT" in first_query
    assert "DO UPDATE SET" in first_query
    assert "RETURNING" in first_query

    # Every object should come back, each with (id, name, email).
    assert result is not None
    assert len(result) == total_objects
    assert all(len(row) == 3 for row in result)
|
|
1304
|
+
|
|
1305
|
+
|
|
1306
|
+
@pytest.mark.asyncio
async def test_batch_upsert_multiple_with_real_db(db_connection: DBConnection):
    """
    Integration test: a single upsert batch containing both existing and new
    emails should update the former and insert the latter.
    """
    # The upsert needs a unique constraint to detect conflicts against.
    await db_connection.conn.execute(
        """
        ALTER TABLE userdemo
        ADD CONSTRAINT email_unique UNIQUE (email)
        """
    )

    # Seed two rows for the upsert to collide with.
    seeded = [
        UserDemo(name="User 1", email="user1@example.com"),
        UserDemo(name="User 2", email="user2@example.com"),
    ]
    await db_connection.insert(seeded)

    # Mix of colliding rows (should update) and fresh rows (should insert).
    batch = [
        UserDemo(name="Updated User 1", email="user1@example.com"),
        UserDemo(name="Updated User 2", email="user2@example.com"),
        UserDemo(name="User 3", email="user3@example.com"),
        UserDemo(name="User 4", email="user4@example.com"),
    ]

    returned = await db_connection.upsert(
        batch,
        conflict_fields=(UserDemo.email,),
        update_fields=(UserDemo.name,),
        returning_fields=(UserDemo.name, UserDemo.email),
    )

    # Every row in the batch should produce a returned tuple.
    assert returned is not None
    assert len(returned) == 4

    rows = await db_connection.conn.fetch("SELECT * FROM userdemo ORDER BY email")
    assert len(rows) == 4

    # Colliding emails were updated in place...
    assert rows[0]["name"] == "Updated User 1"
    assert rows[1]["name"] == "Updated User 2"
    # ...and the fresh emails were inserted.
    assert rows[2]["name"] == "User 3"
    assert rows[3]["name"] == "User 4"
|
|
1358
|
+
|
|
1359
|
+
|
|
1360
|
+
@pytest.mark.asyncio
async def test_initialize_types_caching(docker_postgres):
    """
    Type introspection results should be cached globally: a second connection
    to the same database must not trigger another ``_introspect_types`` call,
    yet must still encode/decode custom types correctly.
    """
    # Clear the global cache for isolation.
    TYPE_CACHE.clear()

    # Define a sample enum and model that require type introspection.
    class StatusEnum(StrEnum):
        ACTIVE = "active"
        INACTIVE = "inactive"
        PENDING = "pending"

    class ComplexTypeDemo(TableBase):
        id: int = Field(primary_key=True)
        status: StatusEnum
        tags: list[str]
        metadata: dict[Any, Any] = Field(is_json=True)

    # Establish the first connection.
    conn1 = await asyncpg.connect(
        host=docker_postgres["host"],
        port=docker_postgres["port"],
        user=docker_postgres["user"],
        password=docker_postgres["password"],
        database=docker_postgres["database"],
    )
    db1 = DBConnection(conn1)

    # Prepare the database schema (drop leftovers from prior runs first).
    await db1.conn.execute("DROP TYPE IF EXISTS statusenum CASCADE")
    await db1.conn.execute("DROP TABLE IF EXISTS complextypedemo")
    await create_all(db1, [ComplexTypeDemo])

    # Save the original method so the wrapper can delegate to it.
    original_introspect = Connection._introspect_types

    # Default value
    introspect_wrapper_call_count = 0

    # Define a wrapper that counts calls and then calls through.
    async def introspect_wrapper(self, types_with_missing_codecs, timeout):
        nonlocal introspect_wrapper_call_count
        introspect_wrapper_call_count += 1
        return await original_introspect(self, types_with_missing_codecs, timeout)

    # Patch the _introspect_types method on the Connection class.
    # NOTE: this is a class-level patch, so it observes introspection on
    # every asyncpg connection created while the patch is active.
    with patch.object(Connection, "_introspect_types", new=introspect_wrapper):
        # For the first connection, initialize types.
        await db1.initialize_types()
        # Verify that introspection was called.
        assert introspect_wrapper_call_count == 1

        # Insert test data via the first connection.
        demo1 = ComplexTypeDemo(
            id=1,
            status=StatusEnum.ACTIVE,
            tags=["test", "demo"],
            metadata={"version": 1},
        )
        await db1.insert([demo1])

        # Create a second connection to the same database.
        conn2 = await asyncpg.connect(
            host=docker_postgres["host"],
            port=docker_postgres["port"],
            user=docker_postgres["user"],
            password=docker_postgres["password"],
            database=docker_postgres["database"],
        )
        db2 = DBConnection(conn2)

        # For the second connection, initializing types should use the cache.
        await db2.initialize_types()

        # The call count should remain unchanged — this is the core caching
        # assertion of the test.
        assert introspect_wrapper_call_count == 1

        # Verify that we can query the inserted record via the second connection.
        results = await db2.exec(
            QueryBuilder().select(ComplexTypeDemo).order_by(ComplexTypeDemo.id, "ASC")
        )
        assert len(results) == 1
        assert results[0].status == StatusEnum.ACTIVE

        # Insert additional data via the second connection.
        demo2 = ComplexTypeDemo(
            id=2,
            status=StatusEnum.PENDING,
            tags=["test2", "demo2"],
            metadata={"version": 2},
        )
        await db2.insert([demo2])

        # Retrieve and verify data from both connections: cached codecs must
        # round-trip enum, array, and JSON values on both.
        result1 = await db1.exec(
            QueryBuilder().select(ComplexTypeDemo).order_by(ComplexTypeDemo.id, "ASC")
        )
        result2 = await db2.exec(
            QueryBuilder().select(ComplexTypeDemo).order_by(ComplexTypeDemo.id, "ASC")
        )

        assert len(result1) == 2
        assert len(result2) == 2
        assert result1[0].status == StatusEnum.ACTIVE
        assert result1[1].status == StatusEnum.PENDING
        assert result2[0].tags == ["test", "demo"]
        assert result2[1].tags == ["test2", "demo2"]

        await conn2.close()
        await conn1.close()
|
|
1469
|
+
|
|
1470
|
+
|
|
1471
|
+
@pytest.mark.asyncio
async def test_get_dsn(db_connection: DBConnection):
    """
    get_dsn() should render the active connection parameters as a
    postgresql:// DSN containing the user, host, a port separator, and the
    database name.
    """
    dsn = db_connection.get_dsn()
    assert dsn.startswith("postgresql://")
    # Each fragment must appear somewhere in the rendered DSN; ":" merely
    # verifies that a port separator is present.
    for fragment in ("iceaxe", "localhost", ":", "iceaxe_test_db"):
        assert fragment in dsn
|
|
1482
|
+
|
|
1483
|
+
|
|
1484
|
+
@pytest.mark.asyncio
|
|
1485
|
+
async def test_nested_transactions(db_connection):
|
|
1486
|
+
"""
|
|
1487
|
+
Test that nested transactions raise an error by default, but work with ensure=True.
|
|
1488
|
+
"""
|
|
1489
|
+
# Start an outer transaction
|
|
1490
|
+
async with db_connection.transaction():
|
|
1491
|
+
# This should work fine
|
|
1492
|
+
assert db_connection.in_transaction is True
|
|
1493
|
+
|
|
1494
|
+
# Nested transaction with ensure=True should work
|
|
1495
|
+
async with db_connection.transaction(ensure=True):
|
|
1496
|
+
assert db_connection.in_transaction is True
|
|
1497
|
+
|
|
1498
|
+
# Nested transaction without ensure should fail
|
|
1499
|
+
with pytest.raises(
|
|
1500
|
+
RuntimeError,
|
|
1501
|
+
match="Cannot start a new transaction while already in a transaction",
|
|
1502
|
+
):
|
|
1503
|
+
async with db_connection.transaction():
|
|
1504
|
+
pass # Should not reach here
|
|
1505
|
+
|
|
1506
|
+
# After outer transaction ends, we should be out of transaction
|
|
1507
|
+
assert db_connection.in_transaction is False
|
|
1508
|
+
|
|
1509
|
+
# Now a new transaction should start without error
|
|
1510
|
+
async with db_connection.transaction():
|
|
1511
|
+
assert db_connection.in_transaction is True
|