iceaxe 0.7.1__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of iceaxe might be problematic. Click here for more details.

Files changed (75) hide show
  1. iceaxe/__init__.py +20 -0
  2. iceaxe/__tests__/__init__.py +0 -0
  3. iceaxe/__tests__/benchmarks/__init__.py +0 -0
  4. iceaxe/__tests__/benchmarks/test_bulk_insert.py +45 -0
  5. iceaxe/__tests__/benchmarks/test_select.py +114 -0
  6. iceaxe/__tests__/conf_models.py +133 -0
  7. iceaxe/__tests__/conftest.py +204 -0
  8. iceaxe/__tests__/docker_helpers.py +208 -0
  9. iceaxe/__tests__/helpers.py +268 -0
  10. iceaxe/__tests__/migrations/__init__.py +0 -0
  11. iceaxe/__tests__/migrations/conftest.py +36 -0
  12. iceaxe/__tests__/migrations/test_action_sorter.py +237 -0
  13. iceaxe/__tests__/migrations/test_generator.py +140 -0
  14. iceaxe/__tests__/migrations/test_generics.py +91 -0
  15. iceaxe/__tests__/mountaineer/__init__.py +0 -0
  16. iceaxe/__tests__/mountaineer/dependencies/__init__.py +0 -0
  17. iceaxe/__tests__/mountaineer/dependencies/test_core.py +76 -0
  18. iceaxe/__tests__/schemas/__init__.py +0 -0
  19. iceaxe/__tests__/schemas/test_actions.py +1264 -0
  20. iceaxe/__tests__/schemas/test_cli.py +25 -0
  21. iceaxe/__tests__/schemas/test_db_memory_serializer.py +1525 -0
  22. iceaxe/__tests__/schemas/test_db_serializer.py +398 -0
  23. iceaxe/__tests__/schemas/test_db_stubs.py +190 -0
  24. iceaxe/__tests__/test_alias.py +83 -0
  25. iceaxe/__tests__/test_base.py +52 -0
  26. iceaxe/__tests__/test_comparison.py +383 -0
  27. iceaxe/__tests__/test_field.py +11 -0
  28. iceaxe/__tests__/test_helpers.py +9 -0
  29. iceaxe/__tests__/test_modifications.py +151 -0
  30. iceaxe/__tests__/test_queries.py +605 -0
  31. iceaxe/__tests__/test_queries_str.py +173 -0
  32. iceaxe/__tests__/test_session.py +1511 -0
  33. iceaxe/__tests__/test_text_search.py +287 -0
  34. iceaxe/alias_values.py +67 -0
  35. iceaxe/base.py +350 -0
  36. iceaxe/comparison.py +560 -0
  37. iceaxe/field.py +250 -0
  38. iceaxe/functions.py +906 -0
  39. iceaxe/generics.py +140 -0
  40. iceaxe/io.py +107 -0
  41. iceaxe/logging.py +91 -0
  42. iceaxe/migrations/__init__.py +5 -0
  43. iceaxe/migrations/action_sorter.py +98 -0
  44. iceaxe/migrations/cli.py +228 -0
  45. iceaxe/migrations/client_io.py +62 -0
  46. iceaxe/migrations/generator.py +404 -0
  47. iceaxe/migrations/migration.py +86 -0
  48. iceaxe/migrations/migrator.py +101 -0
  49. iceaxe/modifications.py +176 -0
  50. iceaxe/mountaineer/__init__.py +10 -0
  51. iceaxe/mountaineer/cli.py +74 -0
  52. iceaxe/mountaineer/config.py +46 -0
  53. iceaxe/mountaineer/dependencies/__init__.py +6 -0
  54. iceaxe/mountaineer/dependencies/core.py +67 -0
  55. iceaxe/postgres.py +133 -0
  56. iceaxe/py.typed +0 -0
  57. iceaxe/queries.py +1455 -0
  58. iceaxe/queries_str.py +294 -0
  59. iceaxe/schemas/__init__.py +0 -0
  60. iceaxe/schemas/actions.py +864 -0
  61. iceaxe/schemas/cli.py +30 -0
  62. iceaxe/schemas/db_memory_serializer.py +705 -0
  63. iceaxe/schemas/db_serializer.py +346 -0
  64. iceaxe/schemas/db_stubs.py +525 -0
  65. iceaxe/session.py +860 -0
  66. iceaxe/session_optimized.c +12035 -0
  67. iceaxe/session_optimized.cpython-313-darwin.so +0 -0
  68. iceaxe/session_optimized.pyx +212 -0
  69. iceaxe/sql_types.py +148 -0
  70. iceaxe/typing.py +73 -0
  71. iceaxe-0.7.1.dist-info/METADATA +261 -0
  72. iceaxe-0.7.1.dist-info/RECORD +75 -0
  73. iceaxe-0.7.1.dist-info/WHEEL +6 -0
  74. iceaxe-0.7.1.dist-info/licenses/LICENSE +21 -0
  75. iceaxe-0.7.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1511 @@
1
+ from contextlib import asynccontextmanager
2
+ from enum import StrEnum
3
+ from typing import Any, Type
4
+ from unittest.mock import AsyncMock, patch
5
+
6
+ import asyncpg
7
+ import pytest
8
+ from asyncpg.connection import Connection
9
+
10
+ from iceaxe.__tests__.conf_models import (
11
+ ArtifactDemo,
12
+ ComplexDemo,
13
+ DemoModelA,
14
+ DemoModelB,
15
+ JsonDemo,
16
+ UserDemo,
17
+ )
18
+ from iceaxe.base import INTERNAL_TABLE_FIELDS, TableBase
19
+ from iceaxe.field import Field
20
+ from iceaxe.functions import func
21
+ from iceaxe.queries import QueryBuilder
22
+ from iceaxe.schemas.cli import create_all
23
+ from iceaxe.session import (
24
+ PG_MAX_PARAMETERS,
25
+ TYPE_CACHE,
26
+ DBConnection,
27
+ )
28
+ from iceaxe.typing import column
29
+
30
+ #
31
+ # Insert / Update / Delete with ORM objects
32
+ #
33
+
34
+
35
@pytest.mark.asyncio
async def test_db_connection_insert(db_connection: DBConnection):
    """Inserting one ORM object writes a matching row to ``userdemo``."""
    john = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([john])

    rows = await db_connection.conn.fetch(
        "SELECT * FROM userdemo WHERE name = $1", "John Doe"
    )
    assert len(rows) == 1

    stored = rows[0]
    # The id on the in-memory object must match the stored row.
    assert stored["id"] == john.id
    assert stored["name"] == "John Doe"
    assert stored["email"] == "john@example.com"

    # After the insert there should be no pending modifications.
    assert john.get_modified_attributes() == {}
48
+
49
+
50
@pytest.mark.asyncio
async def test_db_connection_update(db_connection: DBConnection):
    """Changing an attribute and calling ``update`` persists the new value."""
    person = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([person])

    person.name = "Jane Doe"
    await db_connection.update([person])

    rows = await db_connection.conn.fetch(
        "SELECT * FROM userdemo WHERE id = $1", person.id
    )
    assert len(rows) == 1
    assert rows[0]["name"] == "Jane Doe"

    # The update should leave no pending modifications on the object.
    assert person.get_modified_attributes() == {}
64
+
65
+
66
@pytest.mark.asyncio
async def test_db_obj_mixin_track_modifications():
    """Attribute assignments accumulate in the modified set until cleared."""
    person = UserDemo(name="John Doe", email="john@example.com")

    # Construction alone does not count as a modification.
    assert person.get_modified_attributes() == {}

    expected_changes: dict[str, Any] = {}

    person.name = "Jane Doe"
    expected_changes["name"] = "Jane Doe"
    assert person.get_modified_attributes() == expected_changes

    person.email = "jane@example.com"
    expected_changes["email"] = "jane@example.com"
    assert person.get_modified_attributes() == expected_changes

    person.clear_modified_attributes()
    assert person.get_modified_attributes() == {}
82
+
83
+
84
@pytest.mark.asyncio
async def test_db_connection_delete_query(db_connection: DBConnection):
    """A filtered DELETE query removes only the rows matching the filter."""
    userdemo = [
        UserDemo(name="John Doe", email="john@example.com"),
        UserDemo(name="Jane Doe", email="jane@example.com"),
    ]

    await db_connection.insert(userdemo)

    query = QueryBuilder().delete(UserDemo).where(UserDemo.name == "John Doe")
    await db_connection.exec(query)

    result = await db_connection.conn.fetch("SELECT * FROM userdemo")
    assert len(result) == 1
    # Checking the count alone would also pass if the wrong row had been
    # deleted, so confirm the survivor is the row the filter did not match.
    assert result[0]["name"] == "Jane Doe"
98
+
99
+
100
@pytest.mark.asyncio
async def test_db_connection_insert_multiple(db_connection: DBConnection):
    """Inserting several objects in one call stores one row per object."""
    john = UserDemo(name="John Doe", email="john@example.com")
    jane = UserDemo(name="Jane Doe", email="jane@example.com")
    people = [john, jane]

    await db_connection.insert(people)

    rows = await db_connection.conn.fetch("SELECT * FROM userdemo ORDER BY id")
    assert len(rows) == 2

    first_row, second_row = rows[0], rows[1]
    assert first_row["name"] == "John Doe"
    assert second_row["name"] == "Jane Doe"

    # Each in-memory object carries the id of its stored row.
    assert john.id == first_row["id"]
    assert jane.id == second_row["id"]

    for person in people:
        assert person.get_modified_attributes() == {}
116
+
117
+
118
@pytest.mark.asyncio
async def test_db_connection_update_multiple(db_connection: DBConnection):
    """One ``update`` call persists different changes on different objects."""
    john = UserDemo(name="John Doe", email="john@example.com")
    jane = UserDemo(name="Jane Doe", email="jane@example.com")
    people = [john, jane]
    await db_connection.insert(people)

    # Touch a different column on each object.
    john.name = "Johnny Doe"
    jane.email = "janey@example.com"

    await db_connection.update(people)

    rows = await db_connection.conn.fetch("SELECT * FROM userdemo ORDER BY id")
    assert len(rows) == 2
    assert rows[0]["name"] == "Johnny Doe"
    assert rows[1]["email"] == "janey@example.com"

    for person in people:
        assert person.get_modified_attributes() == {}
136
+
137
+
138
@pytest.mark.asyncio
async def test_db_connection_insert_empty_list(db_connection: DBConnection):
    """Inserting an empty list leaves ``userdemo`` without any rows."""
    await db_connection.insert([])

    rows = await db_connection.conn.fetch("SELECT * FROM userdemo")
    assert len(rows) == 0
143
+
144
+
145
@pytest.mark.asyncio
async def test_db_connection_update_empty_list(db_connection: DBConnection):
    """Updating an empty list succeeds and leaves ``userdemo`` untouched."""
    # The call itself must not raise for an empty input.
    await db_connection.update([])

    # Mirror the empty-insert test: an empty update must not create rows.
    result = await db_connection.conn.fetch("SELECT * FROM userdemo")
    assert len(result) == 0
149
+
150
+
151
@pytest.mark.asyncio
async def test_db_connection_update_no_modifications(db_connection: DBConnection):
    """Calling ``update`` on an unmodified object keeps the stored row as-is."""
    person = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([person])

    # Nothing was changed between the insert and this update.
    await db_connection.update([person])

    rows = await db_connection.conn.fetch(
        "SELECT * FROM userdemo WHERE id = $1", person.id
    )
    assert len(rows) == 1

    stored = rows[0]
    assert stored["name"] == "John Doe"
    assert stored["email"] == "john@example.com"
164
+
165
+
166
@pytest.mark.asyncio
async def test_delete_object(db_connection: DBConnection):
    """Deleting an ORM object removes its row from ``userdemo``."""
    lookup_sql = "SELECT * FROM userdemo WHERE id = $1"

    person = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([person])

    # The row exists before the delete...
    rows_before = await db_connection.conn.fetch(lookup_sql, person.id)
    assert len(rows_before) == 1

    await db_connection.delete([person])

    # ...and is gone afterwards.
    rows_after = await db_connection.conn.fetch(lookup_sql, person.id)
    assert len(rows_after) == 0
182
+
183
+
184
+ #
185
+ # Select into ORM objects
186
+ #
187
+
188
+
189
@pytest.mark.asyncio
async def test_select(db_connection: DBConnection):
    """Exercise the result shapes of table, column and mixed selections."""
    person = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([person])

    # Selecting a whole table yields model instances.
    table_query = QueryBuilder().select(UserDemo)
    table_rows = await db_connection.exec(table_query)
    assert table_rows == [
        UserDemo(id=person.id, name="John Doe", email="john@example.com")
    ]

    # Selecting one column yields bare values.
    column_query = QueryBuilder().select(UserDemo.email)
    column_rows = await db_connection.exec(column_query)
    assert column_rows == ["john@example.com"]

    # Selecting several columns yields tuples of values.
    columns_query = QueryBuilder().select((UserDemo.name, UserDemo.email))
    columns_rows = await db_connection.exec(columns_query)
    assert columns_rows == [("John Doe", "john@example.com")]

    # Mixing a table with a column yields (instance, value) tuples.
    mixed_query = QueryBuilder().select((UserDemo, UserDemo.email))
    mixed_rows = await db_connection.exec(mixed_query)
    assert mixed_rows == [
        (
            UserDemo(id=person.id, name="John Doe", email="john@example.com"),
            "john@example.com",
        )
    ]
218
+
219
+
220
@pytest.mark.asyncio
async def test_is_null(db_connection: DBConnection):
    """Comparing a column against ``None`` filters on null / non-null values."""
    person = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([person])

    # The ``== None`` / ``!= None`` comparisons are deliberate: they build a
    # query condition on the column, hence the E711 suppressions.
    id_is_null = UserDemo.id == None  # noqa: E711
    null_rows = await db_connection.exec(
        QueryBuilder().select(UserDemo).where(id_is_null)
    )
    # The inserted row has an id, so nothing matches.
    assert null_rows == []

    id_is_not_null = UserDemo.id != None  # noqa: E711
    non_null_rows = await db_connection.exec(
        QueryBuilder().select(UserDemo).where(id_is_not_null)
    )
    assert non_null_rows == [
        UserDemo(id=person.id, name="John Doe", email="john@example.com")
    ]
244
+
245
+
246
@pytest.mark.asyncio
async def test_select_complex(db_connection: DBConnection):
    """
    Round-trip a model with list and JSON fields through insert and select.

    """
    stored = ComplexDemo(id=1, string_list=["a", "b", "c"], json_data={"a": "a"})
    await db_connection.insert([stored])

    # Selecting the table should rebuild an equal object.
    fetched = await db_connection.exec(QueryBuilder().select(ComplexDemo))
    expected = ComplexDemo(id=1, string_list=["a", "b", "c"], json_data={"a": "a"})
    assert fetched == [expected]
260
+
261
+
262
@pytest.mark.asyncio
async def test_select_where(db_connection: DBConnection):
    """A single equality filter returns the matching model instance."""
    person = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([person])

    name_matches = UserDemo.name == "John Doe"
    rows = await db_connection.exec(
        QueryBuilder().select(UserDemo).where(name_matches)
    )

    expected = UserDemo(id=person.id, name="John Doe", email="john@example.com")
    assert rows == [expected]
272
+
273
+
274
@pytest.mark.asyncio
async def test_select_join(db_connection: DBConnection):
    """Join two tables and select a model plus a column of the joined table."""
    owner = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([owner])
    assert owner.id is not None

    item = ArtifactDemo(title="Artifact 1", user_id=owner.id)
    await db_connection.insert([item])

    join_condition = UserDemo.id == ArtifactDemo.user_id
    owner_filter = UserDemo.name == "John Doe"
    joined_query = (
        QueryBuilder()
        .select((ArtifactDemo, UserDemo.email))
        .join(UserDemo, join_condition)
        .where(owner_filter)
    )

    rows = await db_connection.exec(joined_query)
    expected_item = ArtifactDemo(id=item.id, title="Artifact 1", user_id=owner.id)
    assert rows == [(expected_item, "john@example.com")]
296
+
297
+
298
@pytest.mark.asyncio
async def test_select_join_multiple_tables(db_connection: DBConnection):
    """Join two tables and select both as full model instances."""
    owner = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([owner])
    assert owner.id is not None

    item = ArtifactDemo(title="Artifact 1", user_id=owner.id)
    await db_connection.insert([item])

    join_condition = UserDemo.id == ArtifactDemo.user_id
    owner_filter = UserDemo.name == "John Doe"
    joined_query = (
        QueryBuilder()
        .select((ArtifactDemo, UserDemo))
        .join(UserDemo, join_condition)
        .where(owner_filter)
    )

    rows = await db_connection.exec(joined_query)
    expected_item = ArtifactDemo(id=item.id, title="Artifact 1", user_id=owner.id)
    expected_owner = UserDemo(id=owner.id, name="John Doe", email="john@example.com")
    assert rows == [(expected_item, expected_owner)]
320
+
321
+
322
@pytest.mark.asyncio
async def test_select_with_limit_and_offset(db_connection: DBConnection):
    """LIMIT 2 OFFSET 1 over five id-ordered users returns users 2 and 3."""
    people = [
        UserDemo(name=f"User {index}", email=f"user{index}@example.com")
        for index in range(1, 6)
    ]
    await db_connection.insert(people)

    ordered = QueryBuilder().select(UserDemo).order_by(UserDemo.id, "ASC")
    page_query = ordered.limit(2).offset(1)

    page = await db_connection.exec(page_query)
    assert len(page) == 2
    assert page[0].name == "User 2"
    assert page[1].name == "User 3"
340
+
341
+
342
@pytest.mark.asyncio
async def test_select_with_multiple_where_conditions(db_connection: DBConnection):
    """Passing two conditions to ``where`` requires both to match."""
    people = [
        UserDemo(name="John Doe", email="john@example.com"),
        UserDemo(name="Jane Doe", email="jane@example.com"),
        UserDemo(name="Bob Smith", email="bob@example.com"),
    ]
    await db_connection.insert(people)

    # Both "Doe" users match the LIKE; the email filter then excludes John.
    surname_filter = column(UserDemo.name).like("%Doe%")
    not_john_filter = UserDemo.email != "john@example.com"
    filtered_query = (
        QueryBuilder().select(UserDemo).where(surname_filter, not_john_filter)
    )

    matches = await db_connection.exec(filtered_query)
    assert len(matches) == 1
    assert matches[0].name == "Jane Doe"
361
+
362
+
363
@pytest.mark.asyncio
async def test_select_with_list_filter(db_connection: DBConnection):
    """``in_`` and ``not_in`` list filters both select the inserted user."""
    person = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([person])

    # The name is a member of the IN list.
    in_filter = column(UserDemo.name).in_(["John Doe"])
    in_rows = await db_connection.exec(
        QueryBuilder().select(UserDemo).where(in_filter)
    )
    assert in_rows == [
        UserDemo(id=person.id, name="John Doe", email="john@example.com")
    ]

    # The name is absent from the NOT IN list.
    not_in_filter = column(UserDemo.name).not_in(["John A"])
    not_in_rows = await db_connection.exec(
        QueryBuilder().select(UserDemo).where(not_in_filter)
    )
    assert not_in_rows == [
        UserDemo(id=person.id, name="John Doe", email="john@example.com")
    ]
385
+
386
+
387
@pytest.mark.asyncio
async def test_select_with_order_by_multiple_columns(db_connection: DBConnection):
    """Chained ``order_by`` calls sort by name first, then by email."""
    people = [
        UserDemo(name="Alice", email="alice@example.com"),
        UserDemo(name="Bob", email="bob@example.com"),
        UserDemo(name="Charlie", email="charlie@example.com"),
        UserDemo(name="Alice", email="alice2@example.com"),
    ]
    await db_connection.insert(people)

    by_name = QueryBuilder().select(UserDemo).order_by(UserDemo.name, "ASC")
    by_name_then_email = by_name.order_by(UserDemo.email, "ASC")

    ordered = await db_connection.exec(by_name_then_email)
    assert len(ordered) == 4

    # The two Alices tie on name, so the email decides their order.
    first, second, third, fourth = ordered[0], ordered[1], ordered[2], ordered[3]
    assert first.name == "Alice"
    assert first.email == "alice2@example.com"
    assert second.name == "Alice"
    assert second.email == "alice@example.com"
    assert third.name == "Bob"
    assert fourth.name == "Charlie"
409
+
410
+
411
@pytest.mark.asyncio
async def test_select_with_group_by_and_having(db_connection: DBConnection):
    """GROUP BY name with HAVING count > 1 keeps only the duplicated name."""
    people = [
        UserDemo(name="John", email="john@example.com"),
        UserDemo(name="Jane", email="jane@example.com"),
        UserDemo(name="John", email="john2@example.com"),
        UserDemo(name="Bob", email="bob@example.com"),
    ]
    await db_connection.insert(people)

    grouped = (
        QueryBuilder()
        .select((UserDemo.name, func.count(UserDemo.id)))
        .group_by(UserDemo.name)
    )
    duplicates_query = grouped.having(func.count(UserDemo.id) > 1)

    groups = await db_connection.exec(duplicates_query)
    # Only "John" was inserted more than once.
    assert len(groups) == 1
    assert groups[0] == ("John", 2)
430
+
431
+
432
@pytest.mark.asyncio
async def test_select_with_left_join(db_connection: DBConnection):
    """A LEFT join keeps users without artifacts, counting them as zero."""
    john = UserDemo(name="John", email="john@example.com")
    jane = UserDemo(name="Jane", email="jane@example.com")
    await db_connection.insert([john, jane])

    # Both artifacts belong to John; Jane has none.
    artifacts = [
        ArtifactDemo(title="John's Post", user_id=john.id),
        ArtifactDemo(title="Another Post", user_id=john.id),
    ]
    await db_connection.insert(artifacts)

    join_condition = UserDemo.id == ArtifactDemo.user_id
    counts_query = (
        QueryBuilder()
        .select((UserDemo.name, func.count(ArtifactDemo.id)))
        .join(ArtifactDemo, join_condition, "LEFT")
        .group_by(UserDemo.name)
        .order_by(UserDemo.name, "ASC")
    )

    counts = await db_connection.exec(counts_query)
    assert len(counts) == 2
    assert counts[0] == ("Jane", 0)
    assert counts[1] == ("John", 2)
457
+
458
+
459
@pytest.mark.asyncio
async def test_select_with_left_join_object(db_connection: DBConnection):
    """A LEFT join selecting both models yields ``None`` for a missing match."""
    users = [
        UserDemo(name="John", email="john@example.com"),
        UserDemo(name="Jane", email="jane@example.com"),
    ]
    await db_connection.insert(users)

    # Both artifacts belong to the first user; the second user has none.
    posts = [
        ArtifactDemo(title="John's Post", user_id=users[0].id),
        ArtifactDemo(title="Another Post", user_id=users[0].id),
    ]
    await db_connection.insert(posts)

    query = (
        QueryBuilder()
        .select((UserDemo, ArtifactDemo))
        .join(ArtifactDemo, UserDemo.id == ArtifactDemo.user_id, "LEFT")
    )
    result = await db_connection.exec(query)

    # The query has no ORDER BY, so the database may return the rows in any
    # order. Check membership rather than position to keep the test stable.
    expected_rows = [
        (users[0], posts[0]),
        (users[0], posts[1]),
        (users[1], None),
    ]
    assert len(result) == 3
    for expected_row in expected_rows:
        assert expected_row in result
483
+
484
+
485
+ # @pytest.mark.asyncio
486
+ # async def test_select_with_subquery(db_connection: DBConnection):
487
+ # users = [
488
+ # UserDemo(name="John", email="john@example.com"),
489
+ # UserDemo(name="Jane", email="jane@example.com"),
490
+ # UserDemo(name="Bob", email="bob@example.com"),
491
+ # ]
492
+ # await db_connection.insert(users)
493
+
494
+ # posts = [
495
+ # ArtifactDemo(title="John's Post", content="Hello", user_id=users[0].id),
496
+ # ArtifactDemo(title="Jane's Post", content="World", user_id=users[1].id),
497
+ # ArtifactDemo(title="John's Second Post", content="!", user_id=users[0].id),
498
+ # ]
499
+ # await db_connection.insert(posts)
500
+
501
+ # subquery = QueryBuilder().select(ArtifactDemo.user_id).where(func.count(ArtifactDemo.id) > 1).group_by(PostDemo.user_id)
502
+ # query = QueryBuilder().select(UserDemo).where(is_column(UserDemo.id).in_(subquery))
503
+ # result = await db_connection.exec(query)
504
+ # assert len(result) == 1
505
+ # assert result[0].name == "John"
506
+
507
+
508
@pytest.mark.asyncio
async def test_select_with_distinct(db_connection: DBConnection):
    """``func.distinct`` collapses a duplicated name into a single result."""
    people = [
        UserDemo(name="John", email="john@example.com"),
        UserDemo(name="Jane", email="jane@example.com"),
        UserDemo(name="John", email="john2@example.com"),
    ]
    await db_connection.insert(people)

    distinct_names = QueryBuilder().select(func.distinct(UserDemo.name))
    names_query = distinct_names.order_by(UserDemo.name, "ASC")

    names = await db_connection.exec(names_query)
    assert names == ["Jane", "John"]
524
+
525
+
526
@pytest.mark.asyncio
async def test_refresh(db_connection: DBConnection):
    """``refresh`` pulls values changed behind the object's back."""
    person = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([person])

    # Change the row with raw SQL, standing in for a write made by some
    # other process.
    rename_sql = "UPDATE userdemo SET name = 'Jane Doe' WHERE id = $1"
    await db_connection.conn.execute(rename_sql, person.id)

    # The in-memory object is stale until it is refreshed.
    assert person.name == "John Doe"

    await db_connection.refresh([person])
    assert person.name == "Jane Doe"
544
+
545
+
546
@pytest.mark.asyncio
async def test_get(db_connection: DBConnection):
    """
    Fetch a single record by primary key with ``get``; unknown keys give None.
    """
    person = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([person])
    assert person.id is not None

    # Existing primary key: the stored values come back.
    found = await db_connection.get(UserDemo, person.id)
    assert found is not None
    assert found.id == person.id
    assert found.name == "John Doe"
    assert found.email == "john@example.com"

    # Primary key that was never inserted.
    missing = await db_connection.get(UserDemo, 9999)
    assert missing is None
566
+
567
+
568
@pytest.mark.asyncio
async def test_db_connection_insert_update_enum(db_connection: DBConnection):
    """
    Test that enum-typed values are stored correctly on both insert and update.

    """

    class EnumValue(StrEnum):
        A = "a"
        B = "b"

    class EnumDemo(TableBase):
        id: int | None = Field(default=None, primary_key=True)
        value: EnumValue

    # Start from a clean slate: drop any leftover table and enum type.
    for cleanup_sql in (
        "DROP TABLE IF EXISTS enumdemo",
        "DROP TYPE IF EXISTS enumvalue",
    ):
        await db_connection.conn.execute(cleanup_sql)
    await create_all(db_connection, [EnumDemo])

    record = EnumDemo(value=EnumValue.A)
    await db_connection.insert([record])

    rows_after_insert = await db_connection.conn.fetch("SELECT * FROM enumdemo")
    assert len(rows_after_insert) == 1
    assert rows_after_insert[0]["value"] == "a"

    record.value = EnumValue.B
    await db_connection.update([record])

    rows_after_update = await db_connection.conn.fetch("SELECT * FROM enumdemo")
    assert len(rows_after_update) == 1
    assert rows_after_update[0]["value"] == "b"
601
+
602
+
603
+ #
604
+ # Upsert
605
+ #
606
+
607
+
608
@pytest.mark.asyncio
async def test_upsert_basic_insert(db_connection: DBConnection):
    """
    Test basic insert when no conflict exists

    """
    # The upsert needs a unique constraint to use as its conflict target.
    await db_connection.conn.execute(
        """
        ALTER TABLE userdemo
        ADD CONSTRAINT email_unique UNIQUE (email)
        """
    )

    user = UserDemo(name="John Doe", email="john@example.com")
    result = await db_connection.upsert(
        [user],
        conflict_fields=(UserDemo.email,),
        returning_fields=(UserDemo.id, UserDemo.name, UserDemo.email),
    )

    # Returned tuples follow the order of ``returning_fields``.
    assert result is not None
    assert len(result) == 1
    assert result[0][1] == "John Doe"
    assert result[0][2] == "john@example.com"

    # Verify in database. Read the column by name so the check does not
    # depend on the table's physical column order.
    db_result = await db_connection.conn.fetch("SELECT * FROM userdemo")
    assert len(db_result) == 1
    assert db_result[0]["name"] == "John Doe"
637
+
638
+
639
@pytest.mark.asyncio
async def test_upsert_update_on_conflict(db_connection: DBConnection):
    """
    An upsert that hits a conflict updates the fields in ``update_fields``.

    """
    add_unique_email = """
        ALTER TABLE userdemo
        ADD CONSTRAINT email_unique UNIQUE (email)
        """
    await db_connection.conn.execute(add_unique_email)

    # Seed the row the upsert will collide with.
    original = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([original])

    # Same email, different name: conflicts on email and should update name.
    replacement = UserDemo(name="Johnny Doe", email="john@example.com")
    returned = await db_connection.upsert(
        [replacement],
        conflict_fields=(UserDemo.email,),
        update_fields=(UserDemo.name,),
        returning_fields=(UserDemo.id, UserDemo.name, UserDemo.email),
    )

    assert returned is not None
    assert len(returned) == 1
    assert returned[0][1] == "Johnny Doe"

    # Still a single row, now carrying the new name.
    rows = await db_connection.conn.fetch("SELECT * FROM userdemo")
    assert len(rows) == 1
    assert rows[0]["name"] == "Johnny Doe"
673
+
674
+
675
@pytest.mark.asyncio
async def test_upsert_do_nothing_on_conflict(db_connection: DBConnection):
    """
    Test DO NOTHING behavior when no update_fields specified

    """
    # The upsert needs a unique constraint to use as its conflict target.
    await db_connection.conn.execute(
        """
        ALTER TABLE userdemo
        ADD CONSTRAINT email_unique UNIQUE (email)
        """
    )

    # First insert
    user = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([user])

    # Attempt upsert with same email but different name
    new_user = UserDemo(name="Johnny Doe", email="john@example.com")
    result = await db_connection.upsert(
        [new_user],
        conflict_fields=(UserDemo.email,),
        returning_fields=(UserDemo.id, UserDemo.name, UserDemo.email),
    )

    # Should return empty list as no update was performed
    assert result == []

    # Verify original record unchanged. Read the column by name so the check
    # does not depend on the table's physical column order.
    db_result = await db_connection.conn.fetch("SELECT * FROM userdemo")
    assert len(db_result) == 1
    assert db_result[0]["name"] == "John Doe"
707
+
708
+
709
@pytest.mark.asyncio
async def test_upsert_multiple_objects(db_connection: DBConnection):
    """
    Upserting two non-conflicting objects in one call returns both rows.

    """
    add_unique_email = """
        ALTER TABLE userdemo
        ADD CONSTRAINT email_unique UNIQUE (email)
        """
    await db_connection.conn.execute(add_unique_email)

    people = [
        UserDemo(name="John Doe", email="john@example.com"),
        UserDemo(name="Jane Doe", email="jane@example.com"),
    ]
    returned = await db_connection.upsert(
        people,
        conflict_fields=(UserDemo.email,),
        returning_fields=(UserDemo.name, UserDemo.email),
    )

    assert returned is not None
    assert len(returned) == 2

    # Index 1 is the email, per the order of ``returning_fields``.
    returned_emails = {row[1] for row in returned}
    assert returned_emails == {"john@example.com", "jane@example.com"}
735
+
736
+
737
@pytest.mark.asyncio
async def test_upsert_empty_list(db_connection: DBConnection):
    """Test upserting an empty list"""
    # The unique constraint mirrors the setup of the other upsert tests.
    await db_connection.conn.execute(
        """
        ALTER TABLE userdemo
        ADD CONSTRAINT email_unique UNIQUE (email)
        """
    )

    result = await db_connection.upsert(
        [], conflict_fields=(UserDemo.email,), returning_fields=(UserDemo.id,)
    )
    # With nothing to upsert the call returns None rather than an empty list.
    assert result is None
751
+
752
+
753
@pytest.mark.asyncio
async def test_upsert_multiple_conflict_fields(db_connection: DBConnection):
    """
    Upsert against a composite (name, email) conflict target.

    """
    add_unique_pair = """
        ALTER TABLE userdemo
        ADD CONSTRAINT email_unique UNIQUE (name, email)
        """
    await db_connection.conn.execute(add_unique_pair)

    # The first two entries are identical, so they collide on (name, email).
    people = [
        UserDemo(name="John Doe", email="john@example.com"),
        UserDemo(name="John Doe", email="john@example.com"),
        UserDemo(name="Jane Doe", email="jane@example.com"),
    ]
    returned = await db_connection.upsert(
        people,
        conflict_fields=(UserDemo.name, UserDemo.email),
        returning_fields=(UserDemo.name, UserDemo.email),
    )

    assert returned is not None
    assert len(returned) == 2

    # Index 1 is the email, per the order of ``returning_fields``.
    returned_emails = {row[1] for row in returned}
    assert returned_emails == {"john@example.com", "jane@example.com"}
780
+
781
+
782
@pytest.mark.asyncio
async def test_for_update_prevents_concurrent_modification(
    db_connection: DBConnection, docker_postgres
):
    """
    A row locked with FOR UPDATE cannot be locked from a second connection.
    """
    person = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([person])

    async with db_connection.transaction():
        # Take the row lock inside our own transaction.
        lock_query = (
            QueryBuilder().select(UserDemo).where(UserDemo.id == person.id).for_update()
        )
        [locked_user] = await db_connection.exec(lock_query)
        assert locked_user.name == "John Doe"

        # Open a second, independent connection to the same database.
        raw_connection = await asyncpg.connect(
            host=docker_postgres["host"],
            port=docker_postgres["port"],
            user=docker_postgres["user"],
            password=docker_postgres["password"],
            database=docker_postgres["database"],
        )
        other_conn = DBConnection(raw_connection)
        try:
            # With NOWAIT the competing lock request is expected to error
            # out instead of waiting for our transaction to finish.
            with pytest.raises(asyncpg.exceptions.LockNotAvailableError):
                await other_conn.exec(
                    QueryBuilder()
                    .select(UserDemo)
                    .where(UserDemo.id == person.id)
                    .for_update(nowait=True)
                )
        finally:
            await other_conn.conn.close()
822
+
823
+
824
@pytest.mark.asyncio
async def test_for_update_skip_locked(db_connection: DBConnection, docker_postgres):
    """
    FOR UPDATE SKIP LOCKED omits rows that another transaction has locked.
    """
    first = UserDemo(name="User 1", email="user1@example.com")
    second = UserDemo(name="User 2", email="user2@example.com")
    await db_connection.insert([first, second])

    async with db_connection.transaction():
        # Hold a lock on the first user only.
        lock_first_query = (
            QueryBuilder().select(UserDemo).where(UserDemo.id == first.id).for_update()
        )
        [locked_user] = await db_connection.exec(lock_first_query)
        assert locked_user.name == "User 1"

        # Open a second, independent connection to the same database.
        raw_connection = await asyncpg.connect(
            host=docker_postgres["host"],
            port=docker_postgres["port"],
            user=docker_postgres["user"],
            password=docker_postgres["password"],
            database=docker_postgres["database"],
        )
        other_conn = DBConnection(raw_connection)
        try:
            # Ask for every user; the locked "User 1" should be skipped,
            # leaving just "User 2".
            unlocked = await other_conn.exec(
                QueryBuilder()
                .select(UserDemo)
                .order_by(UserDemo.id, "ASC")
                .for_update(skip_locked=True)
            )
            assert len(unlocked) == 1
            assert unlocked[0].name == "User 2"
        finally:
            await other_conn.conn.close()
868
+
869
+
870
@pytest.mark.asyncio
async def test_for_update_of_with_join(db_connection: DBConnection, docker_postgres):
    """
    Test FOR UPDATE OF with JOINed tables.

    The fixture connection selects an artifact joined to its user while passing
    ``of=(ArtifactDemo,)`` to ``for_update``. From a second connection, locking
    the user row with NOWAIT is expected to succeed, while locking the artifact
    row with NOWAIT is expected to raise ``LockNotAvailableError``.
    """
    # Create test data
    user = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([user])

    artifact = ArtifactDemo(title="Test Artifact", user_id=user.id)
    await db_connection.insert([artifact])

    async with db_connection.transaction():
        # Lock only the artifacts table in a join query
        [(selected_artifact, selected_user)] = await db_connection.exec(
            QueryBuilder()
            .select((ArtifactDemo, UserDemo))
            .join(UserDemo, UserDemo.id == ArtifactDemo.user_id)
            .for_update(of=(ArtifactDemo,))
        )
        assert selected_artifact.title == "Test Artifact"
        assert selected_user.name == "John Doe"

        # In another connection, we should be able to lock the user
        # but not the artifact. The connection is opened while the
        # transaction above is still open.
        other_conn = DBConnection(
            await asyncpg.connect(
                host=docker_postgres["host"],
                port=docker_postgres["port"],
                user=docker_postgres["user"],
                password=docker_postgres["password"],
                database=docker_postgres["database"],
            )
        )
        try:
            # Should succeed since user table isn't locked
            [other_user] = await other_conn.exec(
                QueryBuilder()
                .select(UserDemo)
                .where(UserDemo.id == user.id)
                .for_update(nowait=True)
            )
            assert other_user.name == "John Doe"

            # Should fail since artifact table is locked. pytest.raises itself
            # fails the test when no exception is raised, so no explicit
            # pytest.fail() is needed after the query.
            with pytest.raises(asyncpg.exceptions.LockNotAvailableError):
                await other_conn.exec(
                    QueryBuilder()
                    .select(ArtifactDemo)
                    .where(ArtifactDemo.id == artifact.id)
                    .for_update(nowait=True)
                )
        finally:
            # Always release the extra connection, even when an assertion fails.
            await other_conn.conn.close()
925
+
926
+
927
@pytest.mark.asyncio
async def test_select_same_column_name_from_different_tables(
    db_connection: DBConnection,
):
    """
    Ensure identically named columns from two joined tables stay distinguishable.

    ``DemoModelA`` and ``DemoModelB`` both have a ``name`` column; the selected
    tuple must follow the order in which the columns were requested.
    """
    # Start from freshly created tables.
    for drop_statement in (
        "DROP TABLE IF EXISTS demomodela",
        "DROP TABLE IF EXISTS demomodelb",
    ):
        await db_connection.conn.execute(drop_statement)
    await create_all(db_connection, [DemoModelA, DemoModelB])

    # Both rows carry the same code so the join pairs them up.
    row_a = DemoModelA(name="Name from A", description="Description A", code="ABC123")
    row_b = DemoModelB(name="Name from B", category="Category B", code="ABC123")
    await db_connection.insert([row_a, row_b])

    # A's name requested first, B's name second.
    forward_rows = await db_connection.exec(
        QueryBuilder()
        .select((DemoModelA.name, DemoModelB.name))
        .join(DemoModelB, DemoModelA.code == DemoModelB.code)
    )
    assert len(forward_rows) == 1
    assert forward_rows[0] == ("Name from A", "Name from B")

    # Same join seen from the other side: the column order must flip with it.
    reversed_rows = await db_connection.exec(
        QueryBuilder()
        .select((DemoModelB.name, DemoModelA.name))
        .join(DemoModelA, DemoModelA.code == DemoModelB.code)
    )
    assert len(reversed_rows) == 1
    assert reversed_rows[0] == ("Name from B", "Name from A")
971
+
972
+
973
@pytest.mark.asyncio
async def test_select_with_order_by_func_count(db_connection: DBConnection):
    """
    Order grouped rows by an aggregate: users sorted by artifact count, descending.
    """
    john = UserDemo(name="John", email="john@example.com")
    jane = UserDemo(name="Jane", email="jane@example.com")
    bob = UserDemo(name="Bob", email="bob@example.com")
    await db_connection.insert([john, jane, bob])

    # John gets two artifacts, Jane one, and Bob none.
    await db_connection.insert(
        [
            ArtifactDemo(title="John's Post 1", user_id=john.id),
            ArtifactDemo(title="John's Post 2", user_id=john.id),
            ArtifactDemo(title="Jane's Post", user_id=jane.id),
        ]
    )

    # LEFT join so that the user without artifacts still shows up.
    rows = await db_connection.exec(
        QueryBuilder()
        .select((UserDemo.name, func.count(ArtifactDemo.id)))
        .join(ArtifactDemo, UserDemo.id == ArtifactDemo.user_id, "LEFT")
        .group_by(UserDemo.name)
        .order_by(func.count(ArtifactDemo.id), "DESC")
    )

    assert len(rows) == 3
    expected_rows = [("John", 2), ("Jane", 1), ("Bob", 0)]
    for position, expected_row in enumerate(expected_rows):
        assert rows[position] == expected_row
1008
+
1009
+
1010
@pytest.mark.asyncio
async def test_json_update(db_connection: DBConnection):
    """
    Check that JSON fields are serialized correctly when a row is updated.
    """
    # Recreate the table so the test starts from a clean slate.
    await db_connection.conn.execute("DROP TABLE IF EXISTS jsondemo")
    await create_all(db_connection, [JsonDemo])

    # Seed one row carrying JSON payloads.
    record = JsonDemo(
        settings={"theme": "dark", "notifications": True},
        metadata={"version": 1},
        unique_val="1",
    )
    await db_connection.insert([record])

    # Replace both JSON payloads and persist the change.
    record.settings = {"theme": "light", "notifications": False}
    record.metadata = {"version": 2, "last_updated": "2024-01-01"}
    await db_connection.update([record])

    # Read the row back with a fresh select.
    lookup = QueryBuilder().select(JsonDemo).where(JsonDemo.id == record.id)
    rows = await db_connection.exec(lookup)
    assert len(rows) == 1
    stored = rows[0]
    assert stored.settings == {"theme": "light", "notifications": False}
    assert stored.metadata == {"version": 2, "last_updated": "2024-01-01"}
1039
+
1040
+
1041
@pytest.mark.asyncio
async def test_json_upsert(db_connection: DBConnection):
    """
    Check that JSON fields are serialized correctly by upsert, both when the
    row is first inserted and when a conflict triggers an update.
    """
    # Recreate the table so the test starts from a clean slate.
    await db_connection.conn.execute("DROP TABLE IF EXISTS jsondemo")
    await create_all(db_connection, [JsonDemo])

    async def upsert_on_unique_val(row):
        # Both upserts in this test share the same conflict/update/returning setup.
        return await db_connection.upsert(
            [row],
            conflict_fields=(JsonDemo.unique_val,),
            update_fields=(JsonDemo.metadata,),
            returning_fields=(JsonDemo.unique_val, JsonDemo.metadata),
        )

    # First upsert: no existing row, so this is an insert.
    inserted = await upsert_on_unique_val(
        JsonDemo(
            settings={"theme": "dark", "notifications": True},
            metadata={"version": 1},
            unique_val="1",
        )
    )
    assert inserted is not None
    assert len(inserted) == 1
    assert inserted[0][0] == "1"
    assert inserted[0][1] == {"version": 1}

    # Second upsert: same unique_val, new metadata, so the conflict path runs.
    updated = await upsert_on_unique_val(
        JsonDemo(
            settings={"theme": "dark", "notifications": True},
            metadata={"version": 2, "last_updated": "2024-01-01"},
            unique_val="1",
        )
    )
    assert updated is not None
    assert len(updated) == 1
    assert updated[0][0] == "1"
    assert updated[0][1] == {"version": 2, "last_updated": "2024-01-01"}

    # Read everything back: still one row, carrying the new metadata.
    rows = await db_connection.exec(QueryBuilder().select(JsonDemo))
    assert len(rows) == 1
    stored = rows[0]
    assert stored.settings == {"theme": "dark", "notifications": True}
    assert stored.metadata == {"version": 2, "last_updated": "2024-01-01"}
1091
+
1092
+
1093
@pytest.mark.asyncio
async def test_db_connection_update_batched(db_connection: DBConnection):
    """
    Update many objects whose modified fields differ per group and verify that
    every change reaches the database and that modification tracking is reset.
    """

    def build_users(indices):
        return [
            UserDemo(name=f"User{i}", email=f"user{i}@example.com") for i in indices
        ]

    # Three groups of ten, each with its own update pattern.
    name_only_group = build_users(range(10))
    email_only_group = build_users(range(10, 20))
    both_fields_group = build_users(range(20, 30))
    all_users = name_only_group + email_only_group + both_fields_group
    await db_connection.insert(all_users)

    for user in name_only_group:
        user.name = f"Updated{user.name}"

    for user in email_only_group:
        user.email = f"updated_{user.email}"

    for user in both_fields_group:
        user.name = f"Updated{user.name}"
        user.email = f"updated_{user.email}"

    await db_connection.update(all_users)

    # Read the raw rows back in id order.
    rows = await db_connection.conn.fetch("SELECT * FROM userdemo ORDER BY id")
    assert len(rows) == 30

    # Group 1: name changed, email untouched.
    for index, row in enumerate(rows[:10]):
        assert row["name"] == f"UpdatedUser{index}"
        assert row["email"] == f"user{index}@example.com"

    # Group 2: email changed, name untouched.
    for index, row in enumerate(rows[10:20], start=10):
        assert row["name"] == f"User{index}"
        assert row["email"] == f"updated_user{index}@example.com"

    # Group 3: both fields changed.
    for index, row in enumerate(rows[20:30], start=20):
        assert row["name"] == f"UpdatedUser{index}"
        assert row["email"] == f"updated_user{index}@example.com"

    # After the update no object should report pending modifications.
    for user in all_users:
        assert user.get_modified_attributes() == {}
1143
+
1144
+
1145
+ #
1146
+ # Batch query construction
1147
+ #
1148
+
1149
+
1150
def assert_expected_user_fields(user: Type[UserDemo]):
    """
    Guard for the batching tests: assert that the given model class still has
    exactly the fields the parameter calculations in those tests assume.

    The checks run against the class passed in as ``user`` instead of a
    hard-coded ``UserDemo``, so the argument is actually honored.

    :param user: Model class to inspect (the tests pass ``UserDemo``).
    :return: ``True`` when every check passes, so callers can wrap the call
        in their own ``assert``.
    :raises AssertionError: If the field set, the primary key flag or the
        default of ``id`` differs from what is expected.
    """
    # If this fails, update the parameter calculations in the batching tests.
    public_fields = {
        key for key in user.model_fields if key not in INTERNAL_TABLE_FIELDS
    }
    assert public_fields == {"id", "name", "email"}

    id_field = user.model_fields["id"]
    assert id_field.primary_key
    assert id_field.default is None
    return True
1158
+
1159
+
1160
@asynccontextmanager
async def mock_transaction():
    """
    Async context manager that does nothing on enter or exit.

    The mocked-connection tests in this file assign it to
    ``mock_conn.transaction``.
    """
    yield None
1163
+
1164
+
1165
@pytest.mark.asyncio
async def test_batch_insert_exceeds_parameters():
    """
    Check that insert() splits its work into several queries once a single
    query would need more than PG_MAX_PARAMETERS parameters.
    """
    assert assert_expected_user_fields(UserDemo)

    # Stand-in connection: no database is involved in this test.
    fake_conn = AsyncMock()
    fake_conn.fetchmany = AsyncMock(return_value=[{"id": i} for i in range(1000)])
    fake_conn.executemany = AsyncMock()
    fake_conn.transaction = mock_transaction
    db = DBConnection(fake_conn)

    # Each UserDemo contributes 2 parameters (name, email), so one object more
    # than PG_MAX_PARAMETERS // 2 pushes a single query over the limit.
    objects_needed = (PG_MAX_PARAMETERS // 2) + 1
    users = []
    for i in range(objects_needed):
        users.append(UserDemo(name=f"User {i}", email=f"user{i}@example.com"))

    await db.insert(users)

    # Exceeding the limit must result in at least two fetchmany calls.
    fetch_calls = fake_conn.fetchmany.mock_calls
    assert len(fetch_calls) >= 2

    # Inspect the SQL of the first batch.
    first_sql = fetch_calls[0].args[0]
    for fragment in ("INSERT INTO", '"name"', '"email"', "RETURNING"):
        assert fragment in first_sql
1202
+
1203
+
1204
@pytest.mark.asyncio
async def test_batch_update_exceeds_parameters():
    """
    Check that update() splits its work into several queries once a single
    query would need more than PG_MAX_PARAMETERS parameters.
    """
    assert assert_expected_user_fields(UserDemo)

    # Stand-in connection: no database is involved in this test.
    fake_conn = AsyncMock()
    fake_conn.executemany = AsyncMock()
    fake_conn.transaction = mock_transaction
    db = DBConnection(fake_conn)

    def build_modified_user(i):
        # Start from a clean object, then touch both updatable fields.
        user = UserDemo(id=i, name=f"User {i}", email=f"user{i}@example.com")
        user.clear_modified_attributes()
        user.name = f"New User {i}"
        user.email = f"newuser{i}@example.com"
        return user

    # Every updated row needs 3 parameters: 1 for the WHERE clause (id) and
    # 2 for the SET clause (name, email).
    objects_needed = (PG_MAX_PARAMETERS // 3) + 1
    users: list[UserDemo] = [build_modified_user(i) for i in range(objects_needed)]

    await db.update(users)

    # Exceeding the limit must result in at least two executemany calls.
    execute_calls = fake_conn.executemany.mock_calls
    assert len(execute_calls) >= 2

    # Inspect the SQL of the first batch.
    first_sql = execute_calls[0].args[0]
    for fragment in ("UPDATE", "SET", "WHERE", '"id"'):
        assert fragment in first_sql
1250
+
1251
+
1252
@pytest.mark.asyncio
async def test_batch_upsert_exceeds_parameters():
    """
    Check that upsert() splits its work into several queries once a single
    query would need more than PG_MAX_PARAMETERS parameters, and that the
    results of all batches are returned together.
    """
    assert assert_expected_user_fields(UserDemo)

    # Each UserDemo contributes 2 parameters (name, email), so one object more
    # than PG_MAX_PARAMETERS // 2 pushes a single query over the limit.
    objects_needed = (PG_MAX_PARAMETERS // 2) + 1
    users = []
    for i in range(objects_needed):
        users.append(UserDemo(name=f"User {i}", email=f"user{i}@example.com"))

    def fake_fetchmany(query, values_list):
        # Answer with one row per submitted value set, so the number of
        # returned rows follows the size of each batch.
        return [
            {"id": i, "name": f"User {i}", "email": f"user{i}@example.com"}
            for i in range(len(values_list))
        ]

    fake_conn = AsyncMock()
    fake_conn.fetchmany = AsyncMock(side_effect=fake_fetchmany)
    fake_conn.executemany = AsyncMock()
    fake_conn.transaction = mock_transaction
    db = DBConnection(fake_conn)

    # Exercise upsert with conflict, update and returning fields all supplied.
    upserted = await db.upsert(
        users,
        conflict_fields=(UserDemo.email,),
        update_fields=(UserDemo.name,),
        returning_fields=(UserDemo.id, UserDemo.name, UserDemo.email),
    )

    # Exceeding the limit must result in at least two fetchmany calls.
    fetch_calls = fake_conn.fetchmany.mock_calls
    assert len(fetch_calls) >= 2

    # Inspect the SQL of the first batch.
    first_sql = fetch_calls[0].args[0]
    for fragment in ("INSERT INTO", "ON CONFLICT", "DO UPDATE SET", "RETURNING"):
        assert fragment in first_sql

    # One result per object, each carrying id, name and email.
    assert upserted is not None
    assert len(upserted) == objects_needed
    for returned_row in upserted:
        assert len(returned_row) == 3
1304
+
1305
+
1306
@pytest.mark.asyncio
async def test_batch_upsert_multiple_with_real_db(db_connection: DBConnection):
    """
    Integration test: upsert several objects in one call against a real
    database, mixing rows that already exist (updates) with new rows (inserts).
    """
    # The upsert below uses email as its conflict field, so make it unique.
    await db_connection.conn.execute(
        """
        ALTER TABLE userdemo
        ADD CONSTRAINT email_unique UNIQUE (email)
        """
    )

    # Rows that exist before the upsert runs.
    await db_connection.insert(
        [
            UserDemo(name="User 1", email="user1@example.com"),
            UserDemo(name="User 2", email="user2@example.com"),
        ]
    )

    # First two collide on email with the rows above; the last two are new.
    existing_emails = [
        UserDemo(name="Updated User 1", email="user1@example.com"),
        UserDemo(name="Updated User 2", email="user2@example.com"),
    ]
    new_emails = [
        UserDemo(name="User 3", email="user3@example.com"),
        UserDemo(name="User 4", email="user4@example.com"),
    ]

    returned = await db_connection.upsert(
        existing_emails + new_emails,
        conflict_fields=(UserDemo.email,),
        update_fields=(UserDemo.name,),
        returning_fields=(UserDemo.name, UserDemo.email),
    )

    # Every upserted object yields a returned row.
    assert returned is not None
    assert len(returned) == 4

    # Inspect the table itself, ordered by email.
    stored_rows = await db_connection.conn.fetch(
        "SELECT * FROM userdemo ORDER BY email"
    )
    assert len(stored_rows) == 4

    # Rows 0-1 were updated in place, rows 2-3 were inserted.
    expected_names = ["Updated User 1", "Updated User 2", "User 3", "User 4"]
    for position, expected_name in enumerate(expected_names):
        assert stored_rows[position]["name"] == expected_name
1358
+
1359
+
1360
@pytest.mark.asyncio
async def test_initialize_types_caching(docker_postgres):
    """
    Test that type introspection performed by ``initialize_types`` on one
    connection is reused by a second connection instead of being repeated.

    ``Connection._introspect_types`` is wrapped with a counting pass-through;
    the counter must be 1 after the first connection initializes its types and
    must still be 1 after the second connection does the same. Both
    connections are then used to insert and read rows with enum, list and JSON
    columns.

    Both connections are closed in ``finally`` blocks so that a failing
    assertion does not leave them open.
    """
    # Clear the global cache for isolation.
    TYPE_CACHE.clear()

    # Define a sample enum and model that require type introspection.
    class StatusEnum(StrEnum):
        ACTIVE = "active"
        INACTIVE = "inactive"
        PENDING = "pending"

    class ComplexTypeDemo(TableBase):
        id: int = Field(primary_key=True)
        status: StatusEnum
        tags: list[str]
        metadata: dict[Any, Any] = Field(is_json=True)

    # Both connections target the same database.
    connect_kwargs = {
        "host": docker_postgres["host"],
        "port": docker_postgres["port"],
        "user": docker_postgres["user"],
        "password": docker_postgres["password"],
        "database": docker_postgres["database"],
    }

    # Establish the first connection.
    conn1 = await asyncpg.connect(**connect_kwargs)
    # Opened later; tracked here so the cleanup knows whether it exists.
    conn2 = None
    try:
        db1 = DBConnection(conn1)

        # Prepare the database schema.
        await db1.conn.execute("DROP TYPE IF EXISTS statusenum CASCADE")
        await db1.conn.execute("DROP TABLE IF EXISTS complextypedemo")
        await create_all(db1, [ComplexTypeDemo])

        # Save the original method.
        original_introspect = Connection._introspect_types

        introspect_wrapper_call_count = 0

        # Define a wrapper that counts calls and then calls through.
        async def introspect_wrapper(self, types_with_missing_codecs, timeout):
            nonlocal introspect_wrapper_call_count
            introspect_wrapper_call_count += 1
            return await original_introspect(self, types_with_missing_codecs, timeout)

        # Patch the _introspect_types method on the Connection class.
        with patch.object(Connection, "_introspect_types", new=introspect_wrapper):
            # For the first connection, initialize types.
            await db1.initialize_types()
            # Verify that introspection was called.
            assert introspect_wrapper_call_count == 1

            # Insert test data via the first connection.
            demo1 = ComplexTypeDemo(
                id=1,
                status=StatusEnum.ACTIVE,
                tags=["test", "demo"],
                metadata={"version": 1},
            )
            await db1.insert([demo1])

            # Create a second connection to the same database.
            conn2 = await asyncpg.connect(**connect_kwargs)
            db2 = DBConnection(conn2)

            # For the second connection, initializing types should use the cache.
            await db2.initialize_types()

            # The call count should remain unchanged.
            assert introspect_wrapper_call_count == 1

            # Verify that we can query the inserted record via the second connection.
            results = await db2.exec(
                QueryBuilder()
                .select(ComplexTypeDemo)
                .order_by(ComplexTypeDemo.id, "ASC")
            )
            assert len(results) == 1
            assert results[0].status == StatusEnum.ACTIVE

            # Insert additional data via the second connection.
            demo2 = ComplexTypeDemo(
                id=2,
                status=StatusEnum.PENDING,
                tags=["test2", "demo2"],
                metadata={"version": 2},
            )
            await db2.insert([demo2])

            # Retrieve and verify data from both connections.
            result1 = await db1.exec(
                QueryBuilder()
                .select(ComplexTypeDemo)
                .order_by(ComplexTypeDemo.id, "ASC")
            )
            result2 = await db2.exec(
                QueryBuilder()
                .select(ComplexTypeDemo)
                .order_by(ComplexTypeDemo.id, "ASC")
            )

            assert len(result1) == 2
            assert len(result2) == 2
            assert result1[0].status == StatusEnum.ACTIVE
            assert result1[1].status == StatusEnum.PENDING
            assert result2[0].tags == ["test", "demo"]
            assert result2[1].tags == ["test2", "demo2"]
    finally:
        # Close in the original order (conn2, then conn1); the nested finally
        # makes sure conn1 is closed even if closing conn2 fails.
        try:
            if conn2 is not None:
                await conn2.close()
        finally:
            await conn1.close()
1469
+
1470
+
1471
@pytest.mark.asyncio
async def test_get_dsn(db_connection: DBConnection):
    """
    Check that get_dsn renders the connection parameters as a PostgreSQL DSN
    string containing the expected pieces.
    """
    dsn = db_connection.get_dsn()
    assert dsn.startswith("postgresql://")

    # NOTE(review): the ":" entry is meant to verify that a port is present,
    # but the "postgresql://" prefix already contains a colon, so that check
    # cannot fail on its own. Consider a stricter assertion.
    for expected_fragment in ("iceaxe", "localhost", ":", "iceaxe_test_db"):
        assert expected_fragment in dsn
1482
+
1483
+
1484
@pytest.mark.asyncio
async def test_nested_transactions(db_connection):
    """
    Opening a transaction inside another one raises a RuntimeError unless
    ``ensure=True`` is passed, in which case the nested block is accepted.
    """
    nested_error = "Cannot start a new transaction while already in a transaction"

    async with db_connection.transaction():
        # The outer transaction is active.
        assert db_connection.in_transaction is True

        # With ensure=True the nested block is accepted.
        async with db_connection.transaction(ensure=True):
            assert db_connection.in_transaction is True

        # Without ensure, opening another transaction must be rejected.
        rejects_nesting = pytest.raises(RuntimeError, match=nested_error)
        with rejects_nesting:
            async with db_connection.transaction():
                pass  # Should not reach here

    # Leaving the outer block ends the transaction.
    assert db_connection.in_transaction is False

    # A fresh transaction can now be opened without error.
    async with db_connection.transaction():
        assert db_connection.in_transaction is True