iceaxe-0.8.3-cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of iceaxe might be problematic.

Files changed (75)
  1. iceaxe/__init__.py +20 -0
  2. iceaxe/__tests__/__init__.py +0 -0
  3. iceaxe/__tests__/benchmarks/__init__.py +0 -0
  4. iceaxe/__tests__/benchmarks/test_bulk_insert.py +45 -0
  5. iceaxe/__tests__/benchmarks/test_select.py +114 -0
  6. iceaxe/__tests__/conf_models.py +133 -0
  7. iceaxe/__tests__/conftest.py +204 -0
  8. iceaxe/__tests__/docker_helpers.py +208 -0
  9. iceaxe/__tests__/helpers.py +268 -0
  10. iceaxe/__tests__/migrations/__init__.py +0 -0
  11. iceaxe/__tests__/migrations/conftest.py +36 -0
  12. iceaxe/__tests__/migrations/test_action_sorter.py +237 -0
  13. iceaxe/__tests__/migrations/test_generator.py +140 -0
  14. iceaxe/__tests__/migrations/test_generics.py +91 -0
  15. iceaxe/__tests__/mountaineer/__init__.py +0 -0
  16. iceaxe/__tests__/mountaineer/dependencies/__init__.py +0 -0
  17. iceaxe/__tests__/mountaineer/dependencies/test_core.py +76 -0
  18. iceaxe/__tests__/schemas/__init__.py +0 -0
  19. iceaxe/__tests__/schemas/test_actions.py +1265 -0
  20. iceaxe/__tests__/schemas/test_cli.py +25 -0
  21. iceaxe/__tests__/schemas/test_db_memory_serializer.py +1571 -0
  22. iceaxe/__tests__/schemas/test_db_serializer.py +435 -0
  23. iceaxe/__tests__/schemas/test_db_stubs.py +190 -0
  24. iceaxe/__tests__/test_alias.py +83 -0
  25. iceaxe/__tests__/test_base.py +52 -0
  26. iceaxe/__tests__/test_comparison.py +383 -0
  27. iceaxe/__tests__/test_field.py +11 -0
  28. iceaxe/__tests__/test_helpers.py +9 -0
  29. iceaxe/__tests__/test_modifications.py +151 -0
  30. iceaxe/__tests__/test_queries.py +764 -0
  31. iceaxe/__tests__/test_queries_str.py +173 -0
  32. iceaxe/__tests__/test_session.py +1511 -0
  33. iceaxe/__tests__/test_text_search.py +287 -0
  34. iceaxe/alias_values.py +67 -0
  35. iceaxe/base.py +351 -0
  36. iceaxe/comparison.py +560 -0
  37. iceaxe/field.py +263 -0
  38. iceaxe/functions.py +1432 -0
  39. iceaxe/generics.py +140 -0
  40. iceaxe/io.py +107 -0
  41. iceaxe/logging.py +91 -0
  42. iceaxe/migrations/__init__.py +5 -0
  43. iceaxe/migrations/action_sorter.py +98 -0
  44. iceaxe/migrations/cli.py +228 -0
  45. iceaxe/migrations/client_io.py +62 -0
  46. iceaxe/migrations/generator.py +404 -0
  47. iceaxe/migrations/migration.py +86 -0
  48. iceaxe/migrations/migrator.py +101 -0
  49. iceaxe/modifications.py +176 -0
  50. iceaxe/mountaineer/__init__.py +10 -0
  51. iceaxe/mountaineer/cli.py +74 -0
  52. iceaxe/mountaineer/config.py +46 -0
  53. iceaxe/mountaineer/dependencies/__init__.py +6 -0
  54. iceaxe/mountaineer/dependencies/core.py +67 -0
  55. iceaxe/postgres.py +133 -0
  56. iceaxe/py.typed +0 -0
  57. iceaxe/queries.py +1459 -0
  58. iceaxe/queries_str.py +294 -0
  59. iceaxe/schemas/__init__.py +0 -0
  60. iceaxe/schemas/actions.py +864 -0
  61. iceaxe/schemas/cli.py +30 -0
  62. iceaxe/schemas/db_memory_serializer.py +711 -0
  63. iceaxe/schemas/db_serializer.py +347 -0
  64. iceaxe/schemas/db_stubs.py +529 -0
  65. iceaxe/session.py +860 -0
  66. iceaxe/session_optimized.c +12207 -0
  67. iceaxe/session_optimized.cpython-313-darwin.so +0 -0
  68. iceaxe/session_optimized.pyx +212 -0
  69. iceaxe/sql_types.py +149 -0
  70. iceaxe/typing.py +73 -0
  71. iceaxe-0.8.3.dist-info/METADATA +262 -0
  72. iceaxe-0.8.3.dist-info/RECORD +75 -0
  73. iceaxe-0.8.3.dist-info/WHEEL +6 -0
  74. iceaxe-0.8.3.dist-info/licenses/LICENSE +21 -0
  75. iceaxe-0.8.3.dist-info/top_level.txt +1 -0
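
The largest addition is iceaxe/__tests__/schemas/test_db_memory_serializer.py (+1571 lines), whose full diff follows. Its tests repeatedly declare small TableBase models and run them through the in-memory schema serializer. For orientation, a minimal sketch of that model-declaration pattern is shown here; Region and Country are illustrative names only and do not appear in the package.

# Sketch of the TableBase/Field declaration style exercised by the tests below.
# Region and Country are invented example names, not package code.
from enum import Enum

from iceaxe import Field, TableBase


class Region(Enum):
    EU = "EU"
    US = "US"


class Country(TableBase):
    id: int = Field(primary_key=True)  # becomes the primary key constraint
    name: str = Field(unique=True)     # becomes a unique constraint
    region: Region                     # becomes a Postgres enum type plus column
    note: str | None                   # nullable column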
@@ -0,0 +1,1571 @@
+ import warnings
+ from datetime import date, datetime, time, timedelta
+ from enum import Enum, IntEnum, StrEnum
+ from typing import Generic, Sequence, TypeVar
+ from unittest.mock import ANY
+ from uuid import UUID
+
+ import pytest
+ from pydantic import create_model
+ from pydantic.fields import FieldInfo
+
+ from iceaxe import Field, TableBase
+ from iceaxe.base import IndexConstraint, UniqueConstraint
+ from iceaxe.field import DBFieldInfo
+ from iceaxe.postgres import PostgresDateTime, PostgresForeignKey, PostgresTime
+ from iceaxe.schemas.actions import (
+     ColumnType,
+     ConstraintType,
+     DatabaseActions,
+     DryRunAction,
+     DryRunComment,
+ )
+ from iceaxe.schemas.db_memory_serializer import (
+     CompositePrimaryKeyConstraintError,
+     DatabaseHandler,
+     DatabaseMemorySerializer,
+ )
+ from iceaxe.schemas.db_stubs import (
+     DBColumn,
+     DBConstraint,
+     DBObject,
+     DBObjectPointer,
+     DBTable,
+     DBType,
+     DBTypePointer,
+ )
+
+
+ def compare_db_objects(
+     calculated: Sequence[tuple[DBObject, Sequence[DBObject | DBObjectPointer]]],
+     expected: Sequence[tuple[DBObject, Sequence[DBObject | DBObjectPointer]]],
+ ):
+     """
+     Helper function to compare lists of DBObjects. The order doesn't actually matter
+     for downstream uses, but we can't do a simple equality check with a set because the
+     dependencies list is un-hashable.
+
+     """
+     assert sorted(calculated, key=lambda x: x[0].representation()) == sorted(
+         expected, key=lambda x: x[0].representation()
+     )
+
+
+ @pytest.mark.asyncio
+ async def test_from_scratch_migration():
+     """
+     Test a migration from scratch.
+
+     """
+
+     class OldValues(Enum):
+         A = "A"
+
+     class ModelA(TableBase):
+         id: int = Field(primary_key=True)
+         animal: OldValues
+         was_nullable: str | None
+
+     migrator = DatabaseMemorySerializer()
+
+     db_objects = list(migrator.delegate([ModelA]))
+     next_ordering = migrator.order_db_objects(db_objects)
+
+     actor = DatabaseActions()
+     actions = await migrator.build_actions(
+         actor, [], {}, [obj for obj, _ in db_objects], next_ordering
+     )
+
+     assert actions == [
+         DryRunAction(
+             fn=actor.add_type,
+             kwargs={
+                 "type_name": "oldvalues",
+                 "values": [
+                     "A",
+                 ],
+             },
+         ),
+         DryRunComment(
+             text="\nNEW TABLE: modela\n",
+             previous_line=False,
+         ),
+         DryRunAction(
+             fn=actor.add_table,
+             kwargs={
+                 "table_name": "modela",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_column,
+             kwargs={
+                 "column_name": "id",
+                 "custom_data_type": None,
+                 "explicit_data_is_list": False,
+                 "explicit_data_type": ColumnType.INTEGER,
+                 "table_name": "modela",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_not_null,
+             kwargs={
+                 "column_name": "id",
+                 "table_name": "modela",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_column,
+             kwargs={
+                 "column_name": "animal",
+                 "custom_data_type": "oldvalues",
+                 "explicit_data_is_list": False,
+                 "explicit_data_type": None,
+                 "table_name": "modela",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_not_null,
+             kwargs={
+                 "column_name": "animal",
+                 "table_name": "modela",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_column,
+             kwargs={
+                 "column_name": "was_nullable",
+                 "custom_data_type": None,
+                 "explicit_data_is_list": False,
+                 "explicit_data_type": ColumnType.VARCHAR,
+                 "table_name": "modela",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_constraint,
+             kwargs={
+                 "columns": [
+                     "id",
+                 ],
+                 "constraint": ConstraintType.PRIMARY_KEY,
+                 "constraint_args": None,
+                 "constraint_name": "modela_pkey",
+                 "table_name": "modela",
+             },
+         ),
+     ]
+
+
+ @pytest.mark.asyncio
+ async def test_diff_migration():
+     """
+     Test the diff migration between two schemas.
+
+     """
+
+     class OldValues(Enum):
+         A = "A"
+
+     class NewValues(Enum):
+         A = "A"
+         B = "B"
+
+     class ModelA(TableBase):
+         id: int = Field(primary_key=True)
+         animal: OldValues
+         was_nullable: str | None
+
+     class ModelANew(TableBase):
+         table_name = "modela"
+         id: int = Field(primary_key=True)
+         name: str
+         animal: NewValues
+         was_nullable: str
+
+     actor = DatabaseActions()
+     migrator = DatabaseMemorySerializer()
+
+     db_objects = list(migrator.delegate([ModelA]))
+     db_objects_previous = [obj for obj, _ in db_objects]
+     previous_ordering = migrator.order_db_objects(db_objects)
+
+     db_objects_new = list(migrator.delegate([ModelANew]))
+     db_objects_next = [obj for obj, _ in db_objects_new]
+     next_ordering = migrator.order_db_objects(db_objects_new)
+
+     actor = DatabaseActions()
+     actions = await migrator.build_actions(
+         actor, db_objects_previous, previous_ordering, db_objects_next, next_ordering
+     )
+     assert actions == [
+         DryRunAction(
+             fn=actor.add_type,
+             kwargs={
+                 "type_name": "newvalues",
+                 "values": [
+                     "A",
+                     "B",
+                 ],
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_column,
+             kwargs={
+                 "column_name": "name",
+                 "custom_data_type": None,
+                 "explicit_data_is_list": False,
+                 "explicit_data_type": ColumnType.VARCHAR,
+                 "table_name": "modela",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_not_null,
+             kwargs={
+                 "column_name": "name",
+                 "table_name": "modela",
+             },
+         ),
+         DryRunAction(
+             fn=actor.modify_column_type,
+             kwargs={
+                 "column_name": "animal",
+                 "custom_data_type": "newvalues",
+                 "explicit_data_is_list": False,
+                 "explicit_data_type": None,
+                 "table_name": "modela",
+                 "autocast": True,
+             },
+         ),
+         DryRunComment(
+             text="TODO: Perform a migration of values across types",
+             previous_line=True,
+         ),
+         DryRunAction(
+             fn=actor.add_not_null,
+             kwargs={
+                 "column_name": "was_nullable",
+                 "table_name": "modela",
+             },
+         ),
+         DryRunAction(
+             fn=actor.drop_type,
+             kwargs={
+                 "type_name": "oldvalues",
+             },
+         ),
+     ]
+
+
+ @pytest.mark.asyncio
+ async def test_duplicate_enum_migration():
+     """
+     Test that the shared reference to an enum across multiple tables results in only
+     one migration action to define the type.
+
+     """
+
+     class EnumValues(Enum):
+         A = "A"
+         B = "B"
+
+     class Model1(TableBase):
+         id: int = Field(primary_key=True)
+         value: EnumValues
+
+     class Model2(TableBase):
+         id: int = Field(primary_key=True)
+         value: EnumValues
+
+     migrator = DatabaseMemorySerializer()
+
+     db_objects = list(migrator.delegate([Model1, Model2]))
+     next_ordering = migrator.order_db_objects(db_objects)
+
+     actor = DatabaseActions()
+     actions = await migrator.build_actions(
+         actor, [], {}, [obj for obj, _ in db_objects], next_ordering
+     )
+
+     assert actions == [
+         DryRunAction(
+             fn=actor.add_type,
+             kwargs={
+                 "type_name": "enumvalues",
+                 "values": [
+                     "A",
+                     "B",
+                 ],
+             },
+         ),
+         DryRunComment(
+             text="\nNEW TABLE: model1\n",
+             previous_line=False,
+         ),
+         DryRunAction(
+             fn=actor.add_table,
+             kwargs={
+                 "table_name": "model1",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_column,
+             kwargs={
+                 "column_name": "id",
+                 "custom_data_type": None,
+                 "explicit_data_is_list": False,
+                 "explicit_data_type": ColumnType.INTEGER,
+                 "table_name": "model1",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_not_null,
+             kwargs={
+                 "column_name": "id",
+                 "table_name": "model1",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_column,
+             kwargs={
+                 "column_name": "value",
+                 "custom_data_type": "enumvalues",
+                 "explicit_data_is_list": False,
+                 "explicit_data_type": None,
+                 "table_name": "model1",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_not_null,
+             kwargs={
+                 "column_name": "value",
+                 "table_name": "model1",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_constraint,
+             kwargs={
+                 "columns": [
+                     "id",
+                 ],
+                 "constraint": ConstraintType.PRIMARY_KEY,
+                 "constraint_args": None,
+                 "constraint_name": "model1_pkey",
+                 "table_name": "model1",
+             },
+         ),
+         DryRunComment(
+             text="\nNEW TABLE: model2\n",
+             previous_line=False,
+         ),
+         DryRunAction(
+             fn=actor.add_table,
+             kwargs={
+                 "table_name": "model2",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_column,
+             kwargs={
+                 "column_name": "id",
+                 "custom_data_type": None,
+                 "explicit_data_is_list": False,
+                 "explicit_data_type": ColumnType.INTEGER,
+                 "table_name": "model2",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_not_null,
+             kwargs={
+                 "column_name": "id",
+                 "table_name": "model2",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_column,
+             kwargs={
+                 "column_name": "value",
+                 "custom_data_type": "enumvalues",
+                 "explicit_data_is_list": False,
+                 "explicit_data_type": None,
+                 "table_name": "model2",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_not_null,
+             kwargs={
+                 "column_name": "value",
+                 "table_name": "model2",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_constraint,
+             kwargs={
+                 "columns": [
+                     "id",
+                 ],
+                 "constraint": ConstraintType.PRIMARY_KEY,
+                 "constraint_args": None,
+                 "constraint_name": "model2_pkey",
+                 "table_name": "model2",
+             },
+         ),
+     ]
+
+
+ @pytest.mark.asyncio
+ async def test_required_db_default():
+     """
+     Even if we have a default value in Python, we should still force the content
+     to have a value at the db level.
+
+     """
+
+     class Model1(TableBase):
+         id: int = Field(primary_key=True)
+         value: str = "ABC"
+         value2: str = Field(default="ABC")
+
+     migrator = DatabaseMemorySerializer()
+
+     db_objects = list(migrator.delegate([Model1]))
+     next_ordering = migrator.order_db_objects(db_objects)
+
+     actor = DatabaseActions()
+     actions = await migrator.build_actions(
+         actor, [], {}, [obj for obj, _ in db_objects], next_ordering
+     )
+
+     assert actions == [
+         DryRunComment(text="\nNEW TABLE: model1\n"),
+         DryRunAction(fn=actor.add_table, kwargs={"table_name": "model1"}),
+         DryRunAction(
+             fn=actor.add_column,
+             kwargs={
+                 "column_name": "id",
+                 "custom_data_type": None,
+                 "explicit_data_is_list": False,
+                 "explicit_data_type": ColumnType.INTEGER,
+                 "table_name": "model1",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_not_null, kwargs={"column_name": "id", "table_name": "model1"}
+         ),
+         DryRunAction(
+             fn=actor.add_column,
+             kwargs={
+                 "column_name": "value",
+                 "custom_data_type": None,
+                 "explicit_data_is_list": False,
+                 "explicit_data_type": ColumnType.VARCHAR,
+                 "table_name": "model1",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_not_null,
+             kwargs={"column_name": "value", "table_name": "model1"},
+         ),
+         DryRunAction(
+             fn=actor.add_column,
+             kwargs={
+                 "column_name": "value2",
+                 "custom_data_type": None,
+                 "explicit_data_is_list": False,
+                 "explicit_data_type": ColumnType.VARCHAR,
+                 "table_name": "model1",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_not_null,
+             kwargs={"column_name": "value2", "table_name": "model1"},
+         ),
+         DryRunAction(
+             fn=actor.add_constraint,
+             kwargs={
+                 "columns": ["id"],
+                 "constraint": ConstraintType.PRIMARY_KEY,
+                 "constraint_args": None,
+                 "constraint_name": "model1_pkey",
+                 "table_name": "model1",
+             },
+         ),
+     ]
+
+
+ def test_multiple_primary_keys(clear_all_database_objects):
+     """
+     Support models defined with multiple primary keys. This should
+     result in a composite constraint, which has different handling internally
+     than most other field-constraints that are isolated to the field itself.
+
+     """
+
+     class ExampleModel(TableBase):
+         value_a: UUID = Field(primary_key=True)
+         value_b: UUID = Field(primary_key=True)
+
+     migrator = DatabaseMemorySerializer()
+     db_objects = list(migrator.delegate([ExampleModel]))
+     assert db_objects == [
+         (
+             DBTable(table_name="examplemodel"),
+             [],
+         ),
+         (
+             DBColumn(
+                 table_name="examplemodel",
+                 column_name="value_a",
+                 column_type=ColumnType.UUID,
+                 column_is_list=False,
+                 nullable=False,
+             ),
+             [
+                 DBTable(table_name="examplemodel"),
+             ],
+         ),
+         (
+             DBColumn(
+                 table_name="examplemodel",
+                 column_name="value_b",
+                 column_type=ColumnType.UUID,
+                 column_is_list=False,
+                 nullable=False,
+             ),
+             [
+                 DBTable(table_name="examplemodel"),
+             ],
+         ),
+         (
+             DBConstraint(
+                 table_name="examplemodel",
+                 constraint_name="examplemodel_pkey",
+                 columns=frozenset({"value_a", "value_b"}),
+                 constraint_type=ConstraintType.PRIMARY_KEY,
+                 foreign_key_constraint=None,
+             ),
+             [
+                 DBTable(table_name="examplemodel"),
+                 DBColumn(
+                     table_name="examplemodel",
+                     column_name="value_a",
+                     column_type=ColumnType.UUID,
+                     column_is_list=False,
+                     nullable=False,
+                 ),
+                 DBColumn(
+                     table_name="examplemodel",
+                     column_name="value_b",
+                     column_type=ColumnType.UUID,
+                     column_is_list=False,
+                     nullable=False,
+                 ),
+             ],
+         ),
+     ]
+
+
+ def test_enum_column_assignment(clear_all_database_objects):
+     """
+     Enum values will just yield the current column that they are assigned to even if they
+     are assigned to multiple columns. It's up to the full memory serializer to combine them
+     so we can properly track how we can migrate existing enum/column pairs to the
+     new values.
+
+     """
+
+     class CommonEnum(Enum):
+         A = "a"
+         B = "b"
+
+     class ExampleModel1(TableBase):
+         id: UUID = Field(primary_key=True)
+         value: CommonEnum
+
+     class ExampleModel2(TableBase):
+         id: UUID = Field(primary_key=True)
+         value: CommonEnum
+
+     migrator = DatabaseMemorySerializer()
+     db_objects = list(migrator.delegate([ExampleModel1, ExampleModel2]))
+     assert db_objects == [
+         (
+             DBTable(table_name="examplemodel1"),
+             [],
+         ),
+         (
+             DBColumn(
+                 table_name="examplemodel1",
+                 column_name="id",
+                 column_type=ColumnType.UUID,
+                 column_is_list=False,
+                 nullable=False,
+             ),
+             [
+                 DBTable(table_name="examplemodel1"),
+             ],
+         ),
+         (
+             DBType(
+                 name="commonenum",
+                 values=frozenset({"b", "a"}),
+                 reference_columns=frozenset({("examplemodel1", "value")}),
+             ),
+             [],
+         ),
+         (
+             DBColumn(
+                 table_name="examplemodel1",
+                 column_name="value",
+                 column_type=DBTypePointer(name="commonenum"),
+                 column_is_list=False,
+                 nullable=False,
+             ),
+             [
+                 DBType(
+                     name="commonenum",
+                     values=frozenset({"b", "a"}),
+                     reference_columns=frozenset({("examplemodel1", "value")}),
+                 ),
+                 DBTable(table_name="examplemodel1"),
+             ],
+         ),
+         (
+             DBConstraint(
+                 table_name="examplemodel1",
+                 constraint_name="examplemodel1_pkey",
+                 columns=frozenset({"id"}),
+                 constraint_type=ConstraintType.PRIMARY_KEY,
+                 foreign_key_constraint=None,
+                 check_constraint=None,
+             ),
+             [
+                 DBType(
+                     name="commonenum",
+                     values=frozenset({"b", "a"}),
+                     reference_columns=frozenset({("examplemodel1", "value")}),
+                 ),
+                 DBTable(table_name="examplemodel1"),
+                 DBColumn(
+                     table_name="examplemodel1",
+                     column_name="id",
+                     column_type=ColumnType.UUID,
+                     column_is_list=False,
+                     nullable=False,
+                 ),
+                 DBColumn(
+                     table_name="examplemodel1",
+                     column_name="value",
+                     column_type=DBTypePointer(name="commonenum"),
+                     column_is_list=False,
+                     nullable=False,
+                 ),
+             ],
+         ),
+         (
+             DBTable(table_name="examplemodel2"),
+             [],
+         ),
+         (
+             DBColumn(
+                 table_name="examplemodel2",
+                 column_name="id",
+                 column_type=ColumnType.UUID,
+                 column_is_list=False,
+                 nullable=False,
+             ),
+             [
+                 DBTable(table_name="examplemodel2"),
+             ],
+         ),
+         (
+             DBType(
+                 name="commonenum",
+                 values=frozenset({"b", "a"}),
+                 reference_columns=frozenset({("examplemodel2", "value")}),
+             ),
+             [],
+         ),
+         (
+             DBColumn(
+                 table_name="examplemodel2",
+                 column_name="value",
+                 column_type=DBTypePointer(name="commonenum"),
+                 column_is_list=False,
+                 nullable=False,
+             ),
+             [
+                 DBType(
+                     name="commonenum",
+                     values=frozenset({"b", "a"}),
+                     reference_columns=frozenset({("examplemodel2", "value")}),
+                 ),
+                 DBTable(table_name="examplemodel2"),
+             ],
+         ),
+         (
+             DBConstraint(
+                 table_name="examplemodel2",
+                 constraint_name="examplemodel2_pkey",
+                 columns=frozenset({"id"}),
+                 constraint_type=ConstraintType.PRIMARY_KEY,
+                 foreign_key_constraint=None,
+                 check_constraint=None,
+             ),
+             [
+                 DBType(
+                     name="commonenum",
+                     values=frozenset({"b", "a"}),
+                     reference_columns=frozenset({("examplemodel2", "value")}),
+                 ),
+                 DBTable(table_name="examplemodel2"),
+                 DBColumn(
+                     table_name="examplemodel2",
+                     column_name="id",
+                     column_type=ColumnType.UUID,
+                     column_is_list=False,
+                     nullable=False,
+                 ),
+                 DBColumn(
+                     table_name="examplemodel2",
+                     column_name="value",
+                     column_type=DBTypePointer(name="commonenum"),
+                     column_is_list=False,
+                     nullable=False,
+                 ),
+             ],
+         ),
+     ]
+
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+     "field_name, annotation, field_info, expected_db_objects",
+     [
+         # datetime, default no typehinting
+         (
+             "standard_datetime",
+             datetime,
+             Field(),
+             [
+                 (
+                     DBColumn(
+                         table_name="exampledbmodel",
+                         column_name="standard_datetime",
+                         column_type=ColumnType.TIMESTAMP_WITHOUT_TIME_ZONE,
+                         column_is_list=False,
+                         nullable=False,
+                     ),
+                     [
+                         DBTable(table_name="exampledbmodel"),
+                     ],
+                 ),
+             ],
+         ),
+         # datetime, specified with field arguments
+         (
+             "standard_datetime",
+             datetime,
+             Field(postgres_config=PostgresDateTime(timezone=True)),
+             [
+                 (
+                     DBColumn(
+                         table_name="exampledbmodel",
+                         column_name="standard_datetime",
+                         column_type=ColumnType.TIMESTAMP_WITH_TIME_ZONE,
+                         column_is_list=False,
+                         nullable=False,
+                     ),
+                     [
+                         DBTable(table_name="exampledbmodel"),
+                     ],
+                 ),
+             ],
+         ),
+         # date
+         (
+             "standard_date",
+             date,
+             Field(),
+             [
+                 (
+                     DBColumn(
+                         table_name="exampledbmodel",
+                         column_name="standard_date",
+                         column_type=ColumnType.DATE,
+                         column_is_list=False,
+                         nullable=False,
+                     ),
+                     [
+                         DBTable(table_name="exampledbmodel"),
+                     ],
+                 ),
+             ],
+         ),
+         # time, no typehinting
+         (
+             "standard_time",
+             time,
+             Field(),
+             [
+                 (
+                     DBColumn(
+                         table_name="exampledbmodel",
+                         column_name="standard_time",
+                         column_type=ColumnType.TIME_WITHOUT_TIME_ZONE,
+                         column_is_list=False,
+                         nullable=False,
+                     ),
+                     [
+                         DBTable(table_name="exampledbmodel"),
+                     ],
+                 ),
+             ],
+         ),
+         # time, specified with field arguments
+         (
+             "standard_time",
+             time,
+             Field(postgres_config=PostgresTime(timezone=True)),
+             [
+                 (
+                     DBColumn(
+                         table_name="exampledbmodel",
+                         column_name="standard_time",
+                         column_type=ColumnType.TIME_WITH_TIME_ZONE,
+                         column_is_list=False,
+                         nullable=False,
+                     ),
+                     [
+                         DBTable(table_name="exampledbmodel"),
+                     ],
+                 ),
+             ],
+         ),
+         # timedelta
+         (
+             "standard_timedelta",
+             timedelta,
+             Field(),
+             [
+                 (
+                     DBColumn(
+                         table_name="exampledbmodel",
+                         column_name="standard_timedelta",
+                         column_type=ColumnType.INTERVAL,
+                         column_is_list=False,
+                         nullable=False,
+                     ),
+                     [
+                         DBTable(table_name="exampledbmodel"),
+                     ],
+                 ),
+             ],
+         ),
+     ],
+ )
+ async def test_datetimes(
+     field_name: str,
+     annotation: type,
+     field_info: FieldInfo,
+     expected_db_objects: list[tuple[DBObject, list[DBObject | DBObjectPointer]]],
+ ):
+     ExampleDBModel = create_model(  # type: ignore
+         "ExampleDBModel",
+         __base__=TableBase,
+         **{  # type: ignore
+             # Requires the ID to be specified for the model to be constructed correctly
+             "id": (int, Field(primary_key=True)),
+             field_name: (annotation, field_info),
+         },
+     )
+
+     migrator = DatabaseMemorySerializer()
+     db_objects = list(migrator.delegate([ExampleDBModel]))
+
+     # Table and primary key are created for each model
+     base_db_objects: list[tuple[DBObject, list[DBObject | DBObjectPointer]]] = [
+         (
+             DBTable(table_name="exampledbmodel"),
+             [],
+         ),
+         (
+             DBColumn(
+                 table_name="exampledbmodel",
+                 column_name="id",
+                 column_type=ColumnType.INTEGER,
+                 column_is_list=False,
+                 nullable=False,
+             ),
+             [
+                 DBTable(table_name="exampledbmodel"),
+             ],
+         ),
+         (
+             DBConstraint(
+                 table_name="exampledbmodel",
+                 constraint_name="exampledbmodel_pkey",
+                 columns=frozenset({"id"}),
+                 constraint_type=ConstraintType.PRIMARY_KEY,
+                 foreign_key_constraint=None,
+             ),
+             [
+                 DBTable(table_name="exampledbmodel"),
+                 DBColumn(
+                     table_name="exampledbmodel",
+                     column_name="id",
+                     column_type=ColumnType.INTEGER,
+                     column_is_list=False,
+                     nullable=False,
+                 ),
+                 DBColumn.model_construct(
+                     table_name="exampledbmodel",
+                     column_name=field_name,
+                     column_type=ANY,
+                     column_is_list=False,
+                     nullable=False,
+                 ),
+             ],
+         ),
+     ]
+
+     compare_db_objects(db_objects, base_db_objects + expected_db_objects)
+
+
+ def test_order_db_objects_sorts_by_table():
+     """
+     Unless there are some explicit cross-table dependencies, we should group
+     table operations together in one code block.
+
+     """
+
+     class OldValues(Enum):
+         A = "A"
+
+     class ModelA(TableBase):
+         id: int = Field(primary_key=True)
+         animal: OldValues
+         was_nullable: str | None
+
+     class ModelB(TableBase):
+         id: int = Field(primary_key=True)
+         animal: OldValues
+         was_nullable: str | None
+
+     migrator = DatabaseMemorySerializer()
+
+     db_objects = list(migrator.delegate([ModelA, ModelB]))
+     next_ordering = migrator.order_db_objects(db_objects)
+
+     sorted_actions = sorted(next_ordering.items(), key=lambda x: x[1])
+
+     table_order = [
+         action.table_name
+         for action, _ in sorted_actions
+         if isinstance(action, (DBTable, DBColumn, DBConstraint))
+     ]
+
+     # Table, 3 columns, 1 primary key constraint
+     assert table_order == ["modela"] * 5 + ["modelb"] * 5
+
+
+ @pytest.mark.asyncio
+ async def test_generic_field_subclass():
+     class OldValues(Enum):
+         A = "A"
+
+     T = TypeVar("T")
+
+     class GenericSuperclass(Generic[T]):
+         value: T
+
+     class ModelA(TableBase, GenericSuperclass[OldValues]):
+         id: int = Field(primary_key=True)
+
+     migrator = DatabaseMemorySerializer()
+
+     db_objects = list(migrator.delegate([ModelA]))
+     next_ordering = migrator.order_db_objects(db_objects)
+
+     actor = DatabaseActions()
+     actions = await migrator.build_actions(
+         actor, [], {}, [obj for obj, _ in db_objects], next_ordering
+     )
+
+     assert actions == [
+         DryRunAction(
+             fn=actor.add_type,
+             kwargs={
+                 "type_name": "oldvalues",
+                 "values": [
+                     "A",
+                 ],
+             },
+         ),
+         DryRunComment(
+             text="\nNEW TABLE: modela\n",
+             previous_line=False,
+         ),
+         DryRunAction(
+             fn=actor.add_table,
+             kwargs={
+                 "table_name": "modela",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_column,
+             kwargs={
+                 "column_name": "value",
+                 "custom_data_type": "oldvalues",
+                 "explicit_data_is_list": False,
+                 "explicit_data_type": None,
+                 "table_name": "modela",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_not_null,
+             kwargs={
+                 "column_name": "value",
+                 "table_name": "modela",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_column,
+             kwargs={
+                 "column_name": "id",
+                 "custom_data_type": None,
+                 "explicit_data_is_list": False,
+                 "explicit_data_type": ColumnType.INTEGER,
+                 "table_name": "modela",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_not_null,
+             kwargs={
+                 "column_name": "id",
+                 "table_name": "modela",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_constraint,
+             kwargs={
+                 "columns": [
+                     "id",
+                 ],
+                 "constraint": ConstraintType.PRIMARY_KEY,
+                 "constraint_args": None,
+                 "constraint_name": "modela_pkey",
+                 "table_name": "modela",
+             },
+         ),
+     ]
+
+
+ @pytest.mark.asyncio
+ async def test_serial_only_on_create():
+     """
+     SERIAL types should only be used during table creation. Test a synthetic
+     migration where we both create an initial SERIAL and migrate from a "db" table
+     schema (that won't have autoincrement set) to a "new" table schema (that will).
+     Nothing should happen to the id column in this case.
+
+     """
+
+     class ModelA(TableBase):
+         id: int | None = Field(default=None, primary_key=True)
+         value: int
+
+     class ModelADB(TableBase):
+         table_name = "modela"
+         id: int | None = Field(primary_key=True)
+         value_b: int
+
+     # Because "default" is omitted, this should be detected as a regular INTEGER
+     # column and not a SERIAL column.
+     id_definition = [field for field in ModelADB.model_fields.values()]
+     assert id_definition[0].autoincrement is False
+
+     migrator = DatabaseMemorySerializer()
+
+     memory_objects = list(migrator.delegate([ModelA]))
+     memory_ordering = migrator.order_db_objects(memory_objects)
+
+     db_objects = list(migrator.delegate([ModelADB]))
+     db_ordering = migrator.order_db_objects(db_objects)
+
+     # At the DBColumn level, these should both be integer objects
+     id_columns = [
+         column
+         for column, _ in memory_objects + db_objects
+         if isinstance(column, DBColumn) and column.column_name == "id"
+     ]
+     assert [column.column_type for column in id_columns] == [
+         ColumnType.INTEGER,
+         ColumnType.INTEGER,
+     ]
+
+     # First, test the creation logic. We expect to see a SERIAL column here.
+     actor = DatabaseActions()
+     actions = await migrator.build_actions(
+         actor, [], {}, [obj for obj, _ in memory_objects], memory_ordering
+     )
+
+     assert [
+         action
+         for action in actions
+         if isinstance(action, DryRunAction) and action.kwargs.get("column_name") == "id"
+     ] == [
+         DryRunAction(
+             fn=actor.add_column,
+             kwargs={
+                 "column_name": "id",
+                 "custom_data_type": None,
+                 "explicit_data_is_list": False,
+                 "explicit_data_type": ColumnType.SERIAL,
+                 "table_name": "modela",
+             },
+         ),
+         DryRunAction(
+             fn=actor.add_not_null, kwargs={"table_name": "modela", "column_name": "id"}
+         ),
+     ]
+
+     # Now, test the migration logic. We expect to see no changes to the id
+     # column here because integers should logically equal serials for the purposes
+     # of migration differences.
+     actor = DatabaseActions()
+     actions = await migrator.build_actions(
+         actor,
+         [obj for obj, _ in db_objects],
+         db_ordering,
+         [obj for obj, _ in memory_objects],
+         memory_ordering,
+     )
+     assert [
+         action
+         for action in actions
+         if isinstance(action, DryRunAction) and action.kwargs.get("column_name") == "id"
+     ] == []
+
+
+ #
+ # Column type parsing
+ #
+
+
+ def test_parse_enums():
+     class ModelA(TableBase):
+         id: int = Field(primary_key=True)
+
+     database_handler = DatabaseHandler()
+
+     class StrEnumDemo(StrEnum):
+         A = "a"
+         B = "b"
+
+     type_declaration = database_handler.handle_column_type(
+         "test_key",
+         DBFieldInfo(annotation=StrEnumDemo),
+         ModelA,
+     )
+     assert isinstance(type_declaration.custom_type, DBType)
+     assert type_declaration.custom_type.name == "strenumdemo"
+     assert type_declaration.custom_type.values == frozenset(["a", "b"])
+
+     class IntEnumDemo(IntEnum):
+         A = 1
+         B = 2
+
+     with pytest.raises(ValueError, match="string values are supported for enums"):
+         database_handler.handle_column_type(
+             "test_key",
+             DBFieldInfo(annotation=IntEnumDemo),
+             ModelA,
+         )
+
+     class StandardEnumDemo(Enum):
+         A = "a"
+         B = "b"
+
+     type_declaration = database_handler.handle_column_type(
+         "test_key",
+         DBFieldInfo(annotation=StandardEnumDemo),
+         ModelA,
+     )
+     assert isinstance(type_declaration.custom_type, DBType)
+     assert type_declaration.custom_type.name == "standardenumdemo"
+     assert type_declaration.custom_type.values == frozenset(["a", "b"])
+
+
+ def test_all_constraint_types(clear_all_database_objects):
+     """
+     Test that all types of constraints (foreign keys, unique constraints, indexes,
+     and primary keys) are correctly serialized from TableBase schemas.
+     """
+
+     class ParentModel(TableBase):
+         id: int = Field(primary_key=True)
+         name: str = Field(unique=True)
+
+     class ChildModel(TableBase):
+         id: int = Field(primary_key=True)
+         parent_id: int = Field(foreign_key="parentmodel.id")
+         name: str
+         email: str
+         status: str
+
+         table_args = [
+             UniqueConstraint(columns=["name", "email"]),
+             IndexConstraint(columns=["status"]),
+         ]
+
+     migrator = DatabaseMemorySerializer()
+     db_objects = list(migrator.delegate([ParentModel, ChildModel]))
+
+     # Extract all constraints for verification
+     constraints = [obj for obj, _ in db_objects if isinstance(obj, DBConstraint)]
+
+     # Verify ParentModel constraints
+     parent_constraints = [c for c in constraints if c.table_name == "parentmodel"]
+     assert len(parent_constraints) == 2
+
+     # Primary key constraint
+     pk_constraint = next(
+         c for c in parent_constraints if c.constraint_type == ConstraintType.PRIMARY_KEY
+     )
+     assert pk_constraint.columns == frozenset({"id"})
+     assert pk_constraint.constraint_name == "parentmodel_pkey"
+
+     # Unique constraint on name
+     unique_constraint = next(
+         c for c in parent_constraints if c.constraint_type == ConstraintType.UNIQUE
+     )
+     assert unique_constraint.columns == frozenset({"name"})
+     assert unique_constraint.constraint_name == "parentmodel_name_unique"
+
+     # Verify ChildModel constraints
+     child_constraints = [c for c in constraints if c.table_name == "childmodel"]
+     assert len(child_constraints) == 4  # PK, FK, Unique, Index
+
+     # Primary key constraint
+     child_pk = next(
+         c for c in child_constraints if c.constraint_type == ConstraintType.PRIMARY_KEY
+     )
+     assert child_pk.columns == frozenset({"id"})
+     assert child_pk.constraint_name == "childmodel_pkey"
+
+     # Foreign key constraint
+     fk_constraint = next(
+         c for c in child_constraints if c.constraint_type == ConstraintType.FOREIGN_KEY
+     )
+     assert fk_constraint.columns == frozenset({"parent_id"})
+     assert fk_constraint.constraint_name == "childmodel_parent_id_fkey"
+     assert fk_constraint.foreign_key_constraint is not None
+     assert fk_constraint.foreign_key_constraint.target_table == "parentmodel"
+     assert fk_constraint.foreign_key_constraint.target_columns == frozenset({"id"})
+
+     # Composite unique constraint
+     composite_unique = next(
+         c for c in child_constraints if c.constraint_type == ConstraintType.UNIQUE
+     )
+     assert composite_unique.columns == frozenset({"name", "email"})
+     # The order of columns in the constraint name doesn't matter for functionality
+     assert composite_unique.constraint_name in [
+         "childmodel_name_email_unique",
+         "childmodel_email_name_unique",
+     ]
+
+     # Index constraint
+     index_constraint = next(
+         c for c in child_constraints if c.constraint_type == ConstraintType.INDEX
+     )
+     assert index_constraint.columns == frozenset({"status"})
+     assert index_constraint.constraint_name == "childmodel_status_idx"
+
+
+ def test_primary_key_not_null(clear_all_database_objects):
+     """
+     Test that primary key fields are automatically marked as not-null in their
+     intermediary representation, since primary keys cannot be null.
+
+     This includes both explicitly set primary keys and auto-assigned ones.
+     """
+
+     class ExplicitModel(TableBase):
+         id: int = Field(primary_key=True)
+         name: str
+
+     class AutoAssignedModel(TableBase):
+         id: int | None = Field(default=None, primary_key=True)
+         name: str
+
+     migrator = DatabaseMemorySerializer()
+     db_objects = list(migrator.delegate([ExplicitModel, AutoAssignedModel]))
+
+     # Extract the column definitions
+     columns = [obj for obj, _ in db_objects if isinstance(obj, DBColumn)]
+
+     # Find the explicit primary key column
+     explicit_id_column = next(
+         c for c in columns if c.column_name == "id" and c.table_name == "explicitmodel"
+     )
+     assert not explicit_id_column.nullable
+
+     # Find the auto-assigned primary key column
+     auto_id_column = next(
+         c
+         for c in columns
+         if c.column_name == "id" and c.table_name == "autoassignedmodel"
+     )
+     assert not auto_id_column.nullable
+     assert auto_id_column.autoincrement
+
+
+ @pytest.mark.asyncio
+ async def test_foreign_key_table_dependency():
+     """
+     Test that foreign key constraints properly depend on the referenced table being created first.
+     This test verifies that the foreign key constraint is ordered after both tables are created.
+     """
+
+     class TargetModel(TableBase):
+         id: int = Field(primary_key=True)
+         value: str
+
+     class SourceModel(TableBase):
+         id: int = Field(primary_key=True)
+         target_id: int = Field(foreign_key="targetmodel.id")
+
+     migrator = DatabaseMemorySerializer()
+
+     # Make sure Source is parsed before Target so we can make sure our foreign-key
+     # constraint actually re-orders the final objects.
+     db_objects = list(migrator.delegate([SourceModel, TargetModel]))
+     ordering = migrator.order_db_objects(db_objects)
+
+     # Get all objects in their sorted order
+     sorted_objects = sorted(
+         [obj for obj, _ in db_objects], key=lambda obj: ordering[obj]
+     )
+
+     # Find the positions of key objects
+     target_table_pos = next(
+         i
+         for i, obj in enumerate(sorted_objects)
+         if isinstance(obj, DBTable) and obj.table_name == "targetmodel"
+     )
+     source_table_pos = next(
+         i
+         for i, obj in enumerate(sorted_objects)
+         if isinstance(obj, DBTable) and obj.table_name == "sourcemodel"
+     )
+     target_column_pos = next(
+         i
+         for i, obj in enumerate(sorted_objects)
+         if isinstance(obj, DBColumn)
+         and obj.table_name == "targetmodel"
+         and obj.column_name == "id"
+     )
+     target_pk_pos = next(
+         i
+         for i, obj in enumerate(sorted_objects)
+         if isinstance(obj, DBConstraint)
+         and obj.constraint_type == ConstraintType.PRIMARY_KEY
+         and obj.table_name == "targetmodel"
+     )
+     fk_constraint_pos = next(
+         i
+         for i, obj in enumerate(sorted_objects)
+         if isinstance(obj, DBConstraint)
+         and obj.constraint_type == ConstraintType.FOREIGN_KEY
+         and obj.table_name == "sourcemodel"
+     )
+
+     # The foreign key constraint should come after both tables and the target column are created
+     assert target_table_pos < fk_constraint_pos, (
+         "Foreign key constraint should be created after target table"
+     )
+     assert source_table_pos < fk_constraint_pos, (
+         "Foreign key constraint should be created after source table"
+     )
+     assert target_column_pos < fk_constraint_pos, (
+         "Foreign key constraint should be created after target column"
+     )
+     assert target_pk_pos < fk_constraint_pos, (
+         "Foreign key constraint should be created after target primary key"
+     )
+
+     # Verify the actual migration actions
+     actor = DatabaseActions()
+     actions = await migrator.build_actions(
+         actor, [], {}, [obj for obj, _ in db_objects], ordering
+     )
+
+     # Extract the table creation and foreign key constraint actions
+     table_creations = [
+         action
+         for action in actions
+         if isinstance(action, DryRunAction) and action.fn == actor.add_table
+     ]
+     fk_constraints = [
+         action
+         for action in actions
+         if isinstance(action, DryRunAction)
+         and action.fn == actor.add_constraint
+         and action.kwargs.get("constraint") == ConstraintType.FOREIGN_KEY
+     ]
+
+     # Verify that table creations come before foreign key constraints
+     assert len(table_creations) == 2
+     assert len(fk_constraints) == 1
+
+     table_creation_indices = [
+         i for i, action in enumerate(actions) if action in table_creations
+     ]
+     fk_constraint_indices = [
+         i for i, action in enumerate(actions) if action in fk_constraints
+     ]
+
+     assert all(
+         table_idx < fk_idx
+         for table_idx in table_creation_indices
+         for fk_idx in fk_constraint_indices
+     )
+
+
+ def test_foreign_key_actions():
+     """
+     Test that foreign key ON UPDATE/ON DELETE actions are correctly serialized from TableBase schemas.
+     """
+
+     class ParentModel(TableBase):
+         id: int = Field(primary_key=True)
+
+     class ChildModel(TableBase):
+         id: int = Field(primary_key=True)
+         parent_id: int = Field(
+             foreign_key="parentmodel.id",
+             postgres_config=PostgresForeignKey(
+                 on_delete="CASCADE",
+                 on_update="CASCADE",
+             ),
+         )
+
+     migrator = DatabaseMemorySerializer()
+     db_objects = list(migrator.delegate([ParentModel, ChildModel]))
+
+     # Extract all constraints for verification
+     constraints = [obj for obj, _ in db_objects if isinstance(obj, DBConstraint)]
+
+     # Find the foreign key constraint
+     fk_constraint = next(
+         c for c in constraints if c.constraint_type == ConstraintType.FOREIGN_KEY
+     )
+     assert fk_constraint.foreign_key_constraint is not None
+     assert fk_constraint.foreign_key_constraint.target_table == "parentmodel"
+     assert fk_constraint.foreign_key_constraint.target_columns == frozenset({"id"})
+     assert fk_constraint.foreign_key_constraint.on_delete == "CASCADE"
+     assert fk_constraint.foreign_key_constraint.on_update == "CASCADE"
+
+
+ def test_multiple_primary_keys_foreign_key_error():
+     """
+     Test that when a model has multiple primary keys and foreign key constraints,
+     we get a helpful error message explaining the issue.
+     """
+
+     class User(TableBase):
+         id: int = Field(primary_key=True)
+         tenant_id: int = Field(primary_key=True)  # Composite primary key
+         name: str
+
+     class Topic(TableBase):
+         id: str = Field(primary_key=True)
+         tenant_id: int = Field(primary_key=True)  # Composite primary key
+         title: str
+
+     class Rec(TableBase):
+         id: int = Field(primary_key=True, default=None)
+         creator_id: int = Field(
+             foreign_key="user.id"
+         )  # This will fail because user is leveraging our synthetic primary key
+         topic_id: str = Field(
+             foreign_key="topic.id"
+         )  # This will fail because topic is leveraging our synthetic primary key
+
+     migrator = DatabaseMemorySerializer()
+
+     with pytest.raises(CompositePrimaryKeyConstraintError) as exc_info:
+         db_objects = list(migrator.delegate([User, Topic, Rec]))
+         migrator.order_db_objects(db_objects)
+
+     # Check that the exception has the expected attributes
+     assert exc_info.value.missing_constraints == [("user", "id")]
+
+
+ def test_multiple_primary_keys_warning():
+     """
+     Test that when a model has multiple primary keys, we get a warning.
+     """
+
+     class ExampleModel(TableBase):
+         value_a: int = Field(primary_key=True)
+         value_b: int = Field(primary_key=True)
+
+     migrator = DatabaseMemorySerializer()
+
+     with warnings.catch_warnings(record=True) as w:
+         warnings.simplefilter("always")
+         list(migrator.delegate([ExampleModel]))
+
+     # Check that a warning was issued
+     assert len(w) == 1
+     assert issubclass(w[0].category, UserWarning)
+     warning_message = str(w[0].message)
+     assert "multiple fields marked as primary_key=True" in warning_message
+     assert "composite primary key constraint" in warning_message
+     assert "Consider using only one primary key field" in warning_message
+
+
+ def test_explicit_type_override(clear_all_database_objects):
+     """
+     Test that explicit_type parameter overrides automatic type inference.
+     """
+
+     class TestModel(TableBase):
+         id: int = Field(primary_key=True)
+         # This should be BIGINT instead of INTEGER due to explicit_type
+         big_number: int = Field(explicit_type=ColumnType.BIGINT)
+         # This should be TEXT instead of VARCHAR due to explicit_type
+         long_text: str = Field(explicit_type=ColumnType.TEXT)
+         # This should be JSONB instead of JSON due to explicit_type
+         data: dict = Field(is_json=True, explicit_type=ColumnType.JSONB)
+         # Normal field without explicit_type for comparison
+         normal_field: str = Field()
+
+     migrator = DatabaseMemorySerializer()
+     db_objects = list(migrator.delegate([TestModel]))
+
+     # Extract column definitions
+     columns = [obj for obj, _ in db_objects if isinstance(obj, DBColumn)]
+
+     # Find each column and verify the type
+     big_number_column = next(c for c in columns if c.column_name == "big_number")
+     assert big_number_column.column_type == ColumnType.BIGINT
+     assert not big_number_column.nullable
+
+     long_text_column = next(c for c in columns if c.column_name == "long_text")
+     assert long_text_column.column_type == ColumnType.TEXT
+     assert not long_text_column.nullable
+
+     data_column = next(c for c in columns if c.column_name == "data")
+     assert data_column.column_type == ColumnType.JSONB
+     assert not data_column.nullable
+
+     # Verify normal field still uses automatic inference
+     normal_field_column = next(c for c in columns if c.column_name == "normal_field")
+     assert normal_field_column.column_type == ColumnType.VARCHAR
+     assert not normal_field_column.nullable
+
+     # Verify the id field uses automatic inference (INTEGER)
+     id_column = next(c for c in columns if c.column_name == "id")
+     assert id_column.column_type == ColumnType.INTEGER
+     assert not id_column.nullable
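
Taken together, these tests exercise a single dry-run workflow: serialize the in-memory models into DB stubs with delegate, topologically order them with order_db_objects, and then render DryRunAction/DryRunComment objects with build_actions. The condensed sketch below strings those calls together exactly as the tests do; the Item model and preview_migration wrapper are illustrative and not part of the package.

# Condensed sketch of the serializer workflow used throughout the tests above.
# Item and preview_migration are invented names; the API calls mirror the diff.
import asyncio

from iceaxe import Field, TableBase
from iceaxe.schemas.actions import DatabaseActions
from iceaxe.schemas.db_memory_serializer import DatabaseMemorySerializer


class Item(TableBase):
    id: int = Field(primary_key=True)
    name: str


async def preview_migration():
    migrator = DatabaseMemorySerializer()

    # Serialize models into DB stub objects paired with their dependencies
    db_objects = list(migrator.delegate([Item]))
    ordering = migrator.order_db_objects(db_objects)

    # Diff from an empty "previous" state to the new state, yielding dry-run actions
    actor = DatabaseActions()
    actions = await migrator.build_actions(
        actor, [], {}, [obj for obj, _ in db_objects], ordering
    )
    for action in actions:
        print(action)


asyncio.run(preview_migration())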