iceaxe-0.8.3-cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of iceaxe might be problematic.

Files changed (75)
  1. iceaxe/__init__.py +20 -0
  2. iceaxe/__tests__/__init__.py +0 -0
  3. iceaxe/__tests__/benchmarks/__init__.py +0 -0
  4. iceaxe/__tests__/benchmarks/test_bulk_insert.py +45 -0
  5. iceaxe/__tests__/benchmarks/test_select.py +114 -0
  6. iceaxe/__tests__/conf_models.py +133 -0
  7. iceaxe/__tests__/conftest.py +204 -0
  8. iceaxe/__tests__/docker_helpers.py +208 -0
  9. iceaxe/__tests__/helpers.py +268 -0
  10. iceaxe/__tests__/migrations/__init__.py +0 -0
  11. iceaxe/__tests__/migrations/conftest.py +36 -0
  12. iceaxe/__tests__/migrations/test_action_sorter.py +237 -0
  13. iceaxe/__tests__/migrations/test_generator.py +140 -0
  14. iceaxe/__tests__/migrations/test_generics.py +91 -0
  15. iceaxe/__tests__/mountaineer/__init__.py +0 -0
  16. iceaxe/__tests__/mountaineer/dependencies/__init__.py +0 -0
  17. iceaxe/__tests__/mountaineer/dependencies/test_core.py +76 -0
  18. iceaxe/__tests__/schemas/__init__.py +0 -0
  19. iceaxe/__tests__/schemas/test_actions.py +1265 -0
  20. iceaxe/__tests__/schemas/test_cli.py +25 -0
  21. iceaxe/__tests__/schemas/test_db_memory_serializer.py +1571 -0
  22. iceaxe/__tests__/schemas/test_db_serializer.py +435 -0
  23. iceaxe/__tests__/schemas/test_db_stubs.py +190 -0
  24. iceaxe/__tests__/test_alias.py +83 -0
  25. iceaxe/__tests__/test_base.py +52 -0
  26. iceaxe/__tests__/test_comparison.py +383 -0
  27. iceaxe/__tests__/test_field.py +11 -0
  28. iceaxe/__tests__/test_helpers.py +9 -0
  29. iceaxe/__tests__/test_modifications.py +151 -0
  30. iceaxe/__tests__/test_queries.py +764 -0
  31. iceaxe/__tests__/test_queries_str.py +173 -0
  32. iceaxe/__tests__/test_session.py +1511 -0
  33. iceaxe/__tests__/test_text_search.py +287 -0
  34. iceaxe/alias_values.py +67 -0
  35. iceaxe/base.py +351 -0
  36. iceaxe/comparison.py +560 -0
  37. iceaxe/field.py +263 -0
  38. iceaxe/functions.py +1432 -0
  39. iceaxe/generics.py +140 -0
  40. iceaxe/io.py +107 -0
  41. iceaxe/logging.py +91 -0
  42. iceaxe/migrations/__init__.py +5 -0
  43. iceaxe/migrations/action_sorter.py +98 -0
  44. iceaxe/migrations/cli.py +228 -0
  45. iceaxe/migrations/client_io.py +62 -0
  46. iceaxe/migrations/generator.py +404 -0
  47. iceaxe/migrations/migration.py +86 -0
  48. iceaxe/migrations/migrator.py +101 -0
  49. iceaxe/modifications.py +176 -0
  50. iceaxe/mountaineer/__init__.py +10 -0
  51. iceaxe/mountaineer/cli.py +74 -0
  52. iceaxe/mountaineer/config.py +46 -0
  53. iceaxe/mountaineer/dependencies/__init__.py +6 -0
  54. iceaxe/mountaineer/dependencies/core.py +67 -0
  55. iceaxe/postgres.py +133 -0
  56. iceaxe/py.typed +0 -0
  57. iceaxe/queries.py +1459 -0
  58. iceaxe/queries_str.py +294 -0
  59. iceaxe/schemas/__init__.py +0 -0
  60. iceaxe/schemas/actions.py +864 -0
  61. iceaxe/schemas/cli.py +30 -0
  62. iceaxe/schemas/db_memory_serializer.py +711 -0
  63. iceaxe/schemas/db_serializer.py +347 -0
  64. iceaxe/schemas/db_stubs.py +529 -0
  65. iceaxe/session.py +860 -0
  66. iceaxe/session_optimized.c +12207 -0
  67. iceaxe/session_optimized.cpython-313-darwin.so +0 -0
  68. iceaxe/session_optimized.pyx +212 -0
  69. iceaxe/sql_types.py +149 -0
  70. iceaxe/typing.py +73 -0
  71. iceaxe-0.8.3.dist-info/METADATA +262 -0
  72. iceaxe-0.8.3.dist-info/RECORD +75 -0
  73. iceaxe-0.8.3.dist-info/WHEEL +6 -0
  74. iceaxe-0.8.3.dist-info/licenses/LICENSE +21 -0
  75. iceaxe-0.8.3.dist-info/top_level.txt +1 -0
iceaxe/schemas/db_memory_serializer.py
@@ -0,0 +1,711 @@
+import warnings
+from dataclasses import dataclass
+from datetime import date, datetime, time, timedelta
+from inspect import isgenerator
+from typing import Any, Generator, Sequence, Type, TypeVar, Union
+from uuid import UUID
+
+from pydantic_core import PydanticUndefined
+
+from iceaxe.base import (
+    DBFieldInfo,
+    IndexConstraint,
+    TableBase,
+    UniqueConstraint,
+)
+from iceaxe.generics import (
+    get_typevar_mapping,
+    has_null_type,
+    is_type_compatible,
+    remove_null_type,
+)
+from iceaxe.migrations.action_sorter import ActionTopologicalSorter
+from iceaxe.postgres import (
+    PostgresDateTime,
+    PostgresForeignKey,
+    PostgresTime,
+)
+from iceaxe.schemas.actions import (
+    CheckConstraint,
+    ColumnType,
+    ConstraintType,
+    DatabaseActions,
+    ForeignKeyConstraint,
+)
+from iceaxe.schemas.db_stubs import (
+    DBColumn,
+    DBColumnPointer,
+    DBConstraint,
+    DBConstraintPointer,
+    DBObject,
+    DBObjectPointer,
+    DBPointerOr,
+    DBTable,
+    DBType,
+    DBTypePointer,
+)
+from iceaxe.sql_types import enum_to_name
+from iceaxe.typing import (
+    ALL_ENUM_TYPES,
+    DATE_TYPES,
+    JSON_WRAPPER_FALLBACK,
+    PRIMITIVE_WRAPPER_TYPES,
+)
+
+NodeYieldType = Union[DBObject, DBObjectPointer, "NodeDefinition"]
+
+
+class CompositePrimaryKeyConstraintError(ValueError):
+    """
+    Raised when foreign key constraints cannot be resolved due to composite primary keys.
+
+    This occurs when a table has multiple fields marked as primary_key=True, creating
+    a composite primary key constraint, but foreign key constraints expect individual
+    primary key constraints on the target columns.
+
+    """
+
+    def __init__(self, missing_constraints: list[tuple[str, str]], base_message: str):
+        self.missing_constraints = missing_constraints
+        self.base_message = base_message
+
+        # Construct the detailed error message
+        error_msg = base_message
+
+        if missing_constraints:
+            error_msg += "\n\nThis error commonly occurs when you have multiple fields marked as primary_key=True in your model."
+            error_msg += "\nIceaxe creates a single composite primary key constraint, but foreign key constraints"
+            error_msg += (
+                "\nexpect individual primary key constraints on the target columns."
+            )
+            error_msg += "\n\nFor a detailed explanation of why this happens and how to fix it, see:"
+            error_msg += "\nhttps://mountaineer.sh/iceaxe/guides/relationships#composite-primary-keys-and-foreign-key-constraints"
+            error_msg += "\n\nTo fix this issue, choose one of these approaches:"
+            error_msg += "\n\nRecommended: Modify the current table"
+            error_msg += (
+                "\n - Keep only one field as primary_key=True (e.g., just 'id')"
+            )
+            error_msg += "\n - Add a UniqueConstraint if you need uniqueness across multiple fields"
+            error_msg += "\n - This is usually the better design pattern"
+
+            # Show specific table/column combinations that are missing
+            error_msg += "\n\nCurrently missing individual primary key constraints:"
+            for table_name, column_name in missing_constraints:
+                error_msg += f"\n - Table '{table_name}' needs a primary key on column '{column_name}'"
+
+        super().__init__(error_msg)
+
+
+@dataclass
+class NodeDefinition:
+    node: DBObject
+    dependencies: list[DBObject | DBObjectPointer]
+    force_no_dependencies: bool
+
+
+class DatabaseMemorySerializer:
+    """
+    Serialize the in-memory database representations into a format that can be
+    compared to the database definitions on disk.
+
+    """
+
+    def __init__(self):
+        # Construct the directed acyclic graph of the in-memory database objects
+        # that indicate what order items should be fulfilled in
+        self.db_dag = []
+
+        self.database_handler = DatabaseHandler()
+
+    def delegate(self, tables: list[Type[TableBase]]):
+        """
+        Find the most specific relevant handler. For instance, if a subclass
+        is a registered handler, we should use that instead of the superclass
+        If multiple are found we throw, since we can't determine which one to use
+        for the resolution.

+        """
+        yield from self.database_handler.convert(tables)
+
+    def order_db_objects(
+        self,
+        db_objects: Sequence[tuple[DBObject, Sequence[DBObject | DBObjectPointer]]],
+    ):
+        """
+        Resolve the order that the database objects should be created or modified
+        by normalizing pointers/full objects and performing a sort of their defined
+        DAG dependencies in the migration graph.
+
+        """
+        # First, go through and create a representative object for each of
+        # the representation names
+        db_objects_by_name: dict[str, DBObject] = {}
+        for db_object, _ in db_objects:
+            # Only perform this mapping for objects that are not pointers
+            if isinstance(db_object, DBObjectPointer):
+                continue
+
+            # If the object is already in the dictionary, try to merge the two
+            # different values. Otherwise this indicates that there is a conflicting
+            # name with a different definition which we don't allow
+            if db_object.representation() in db_objects_by_name:
+                current_obj = db_objects_by_name[db_object.representation()]
+                db_objects_by_name[db_object.representation()] = current_obj.merge(
+                    db_object
+                )
+            else:
+                db_objects_by_name[db_object.representation()] = db_object
+
+        # Make sure all the pointers can be resolved by full objects
+        # Otherwise we want a verbose error that gives more context
+        for _, dependencies in db_objects:
+            for dep in dependencies:
+                if isinstance(dep, DBObjectPointer):
+                    if isinstance(dep, DBPointerOr):
+                        # For OR pointers, at least one of the pointers must be resolvable
+                        if not any(
+                            pointer.representation() in db_objects_by_name
+                            for pointer in dep.pointers
+                        ):
+                            # Create a more helpful error message for common cases
+                            missing_pointers = [
+                                p.representation() for p in dep.pointers
+                            ]
+                            error_msg = f"None of the OR pointers {missing_pointers} found in the defined database objects"
+
+                            # Check if this is the common case of multiple primary keys causing foreign key issues
+                            primary_key_pointers = []
+                            for p in dep.pointers:
+                                parsed = p.parse_constraint_pointer()
+                                if parsed and parsed.constraint_type == "PRIMARY KEY":
+                                    primary_key_pointers.append(p)
+
+                            if primary_key_pointers:
+                                # Extract table and column info from the primary key pointers
+                                primary_key_info: list[tuple[str, str]] = []
+                                for pointer in primary_key_pointers:
+                                    table_name = pointer.get_table_name()
+                                    column_names = pointer.get_column_names()
+
+                                    if table_name and column_names:
+                                        for column_name in column_names:
+                                            primary_key_info.append(
+                                                (table_name, column_name)
+                                            )
+
+                                if primary_key_info:
+                                    raise CompositePrimaryKeyConstraintError(
+                                        primary_key_info, error_msg
+                                    )
+                            raise ValueError(error_msg)
+                    elif dep.representation() not in db_objects_by_name:
+                        raise ValueError(
+                            f"Pointer {dep.representation()} not found in the defined database objects"
+                        )
+
+        # Map the potentially different objects to the same object
+        graph_edges = {}
+        for obj, dependencies in db_objects:
+            resolved_deps = []
+            for dep in dependencies:
+                if isinstance(dep, DBObjectPointer):
+                    if isinstance(dep, DBPointerOr):
+                        # Add all resolvable pointers as dependencies
+                        resolved_deps.extend(
+                            db_objects_by_name[pointer.representation()]
+                            for pointer in dep.pointers
+                            if pointer.representation() in db_objects_by_name
+                        )
+                    else:
+                        resolved_deps.append(db_objects_by_name[dep.representation()])
+                else:
+                    resolved_deps.append(dep)
+
+            if isinstance(obj, DBObjectPointer):
+                continue
+
+            graph_edges[db_objects_by_name[obj.representation()]] = resolved_deps
+
+        # Construct the directed acyclic graph
+        ts = ActionTopologicalSorter(graph_edges)
+        return {obj: i for i, obj in enumerate(ts.sort())}
+
+    async def build_actions(
+        self,
+        actor: DatabaseActions,
+        previous: list[DBObject],
+        previous_ordering: dict[DBObject, int],
+        next: list[DBObject],
+        next_ordering: dict[DBObject, int],
+    ):
+        # Arrange each object by their representation so we can determine
+        # the state of each
+        previous_by_name = {obj.representation(): obj for obj in previous}
+        next_by_name = {obj.representation(): obj for obj in next}
+
+        previous_ordering_by_name = {
+            obj.representation(): order for obj, order in previous_ordering.items()
+        }
+        next_ordering_by_name = {
+            obj.representation(): order for obj, order in next_ordering.items()
+        }
+
+        # Verification that the ordering dictionaries align with the objects
+        for ordering, objects in [
+            (previous_ordering_by_name, previous_by_name),
+            (next_ordering_by_name, next_by_name),
+        ]:
+            if set(ordering.keys()) != set(objects.keys()):
+                unique_keys = (set(ordering.keys()) - set(objects.keys())) | (
+                    set(objects.keys()) - set(ordering.keys())
+                )
+                raise ValueError(
+                    f"Ordering dictionary keys must be the same as the objects in the list: {unique_keys}"
+                )
+
+        # Sort the objects by the order that they should be created in. Only create one object
+        # for each representation value, in case we were passed duplicate objects.
+        previous = sorted(
+            previous_by_name.values(),
+            key=lambda obj: previous_ordering_by_name[obj.representation()],
+        )
+        next = sorted(
+            next_by_name.values(),
+            key=lambda obj: next_ordering_by_name[obj.representation()],
+        )
+
+        for next_obj in next:
+            previous_obj = previous_by_name.get(next_obj.representation())
+
+            if previous_obj is None and next_obj is not None:
+                await next_obj.create(actor)
+            elif previous_obj is not None and next_obj is not None:
+                # Only migrate if they're actually different
+                if previous_obj != next_obj:
+                    await next_obj.migrate(previous_obj, actor)
+
+        # For all of the items that were in the previous state but not in the
+        # next state, we should delete them
+        to_delete = [
+            previous_obj
+            for previous_obj in previous
+            if previous_obj.representation() not in next_by_name
+        ]
+        # We use the reversed representation to destroy objects with more dependencies
+        # before the dependencies themselves
+        to_delete.reverse()
+        for previous_obj in to_delete:
+            await previous_obj.destroy(actor)
+
+        return actor.dry_run_actions
+
+
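The serializer above is driven in three steps: delegate() walks the registered models and yields (DBObject, dependencies) pairs, order_db_objects() resolves pointers and topologically sorts those pairs into a creation order, and build_actions() diffs a previous state against a next state to emit create/migrate/destroy actions. A minimal sketch of that flow follows; it is editorial rather than part of the packaged file, and it assumes an already-configured DatabaseActions actor plus a hypothetical UserModel table:

    async def preview_create_actions(actor: DatabaseActions) -> list:
        serializer = DatabaseMemorySerializer()

        # "Next" state: every DBObject (plus its dependencies) derived from the models.
        next_objects = list(serializer.delegate([UserModel]))
        next_ordering = serializer.order_db_objects(next_objects)

        # With an empty "previous" state every object is new, so only CREATE
        # actions are recorded on the actor.
        return await serializer.build_actions(
            actor=actor,
            previous=[],
            previous_ordering={},
            next=[obj for obj, _ in next_objects],
            next_ordering=next_ordering,
        )
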
304
+ # Not really a db object, but we need to fulfill the yield contract
305
+ # They'll be filtered out later
306
+ primitive_type: ColumnType | None = None
307
+ custom_type: DBType | None = None
308
+ is_list: bool = False
309
+
310
+ def representation(self) -> str:
311
+ raise NotImplementedError()
312
+
313
+ def create(self, actor: DatabaseActions):
314
+ raise NotImplementedError()
315
+
316
+ def destroy(self, actor: DatabaseActions):
317
+ raise NotImplementedError()
318
+
319
+ def migrate(self, previous, actor: DatabaseActions):
320
+ raise NotImplementedError()
321
+
322
+
323
+ class DatabaseHandler:
324
+ def __init__(self):
325
+ self.python_to_sql = {
326
+ int: ColumnType.INTEGER,
327
+ float: ColumnType.DOUBLE_PRECISION,
328
+ str: ColumnType.VARCHAR,
329
+ bool: ColumnType.BOOLEAN,
330
+ bytes: ColumnType.BYTEA,
331
+ UUID: ColumnType.UUID,
332
+ Any: ColumnType.JSON,
333
+ }
334
+
335
+ def convert(self, tables: list[Type[TableBase]]):
336
+ for model in sorted(tables, key=lambda model: model.get_table_name()):
337
+ for node in self.convert_table(model):
338
+ yield (node.node, node.dependencies)
339
+
340
+ def convert_table(self, table: Type[TableBase]):
341
+ # Handle the table itself
342
+ table_nodes = self._yield_nodes(DBTable(table_name=table.get_table_name()))
343
+ yield from table_nodes
344
+
345
+ # Handle the columns
346
+ all_column_nodes: list[NodeDefinition] = []
347
+ for field_name, field in table.get_client_fields().items():
348
+ column_nodes = self._yield_nodes(
349
+ self.convert_column(field_name, field, table), dependencies=table_nodes
350
+ )
351
+ yield from column_nodes
352
+ all_column_nodes += column_nodes
353
+
354
+ # Handle field-level constraints
355
+ yield from self._yield_nodes(
356
+ self.handle_single_constraints(field_name, field, table),
357
+ dependencies=column_nodes,
358
+ )
359
+
360
+ # Primary keys must be handled after the columns are created, since multiple
361
+ # columns can be primary keys but only one constraint can be created
362
+ primary_keys = [
363
+ (key, info) for key, info in table.model_fields.items() if info.primary_key
364
+ ]
365
+ yield from self._yield_nodes(
366
+ self.handle_primary_keys(primary_keys, table), dependencies=all_column_nodes
367
+ )
368
+
369
+ if table.table_args != PydanticUndefined:
370
+ for constraint in table.table_args:
371
+ yield from self._yield_nodes(
372
+ self.handle_multiple_constraints(constraint, table),
373
+ dependencies=all_column_nodes,
374
+ )
375
+
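To see what the traversal above yields for a concrete model, the handler can be driven directly. The loop below is an editorial sketch (UserModel again stands in for any TableBase subclass): per convert_table(), it prints the DBTable first, then each DBColumn, then the DBConstraint entries (field-level constraints, the primary key, and any table_args constraints), each alongside the representations it depends on:

    handler = DatabaseHandler()
    for db_object, dependencies in handler.convert([UserModel]):
        # Each item is a (DBObject, dependencies) pair, exactly the shape consumed
        # by DatabaseMemorySerializer.order_db_objects above.
        print(db_object.representation(), [dep.representation() for dep in dependencies])
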
+    def convert_column(self, key: str, info: DBFieldInfo, table: Type[TableBase]):
+        if info.annotation is None:
+            raise ValueError(f"Annotation must be provided for {table.__name__}.{key}")
+
+        # Primary keys should never be nullable, regardless of their type annotation
+        is_nullable = not info.primary_key and has_null_type(info.annotation)
+
+        # If we need to create enums or other db-backed types, we need to do that before
+        # the column itself
+        db_annotation = self.handle_column_type(key, info, table)
+        column_type: DBTypePointer | ColumnType
+        column_dependencies: list[NodeDefinition] = []
+        if db_annotation.custom_type:
+            dependencies = self._yield_nodes(
+                db_annotation.custom_type, force_no_dependencies=True
+            )
+            column_dependencies += dependencies
+            yield from dependencies
+
+            column_type = DBTypePointer(name=db_annotation.custom_type.name)
+        elif db_annotation.primitive_type:
+            column_type = db_annotation.primitive_type
+        else:
+            raise ValueError("Column type must be provided")
+
+        # We need to create the column itself once types have been created
+        yield from self._yield_nodes(
+            DBColumn(
+                table_name=table.get_table_name(),
+                column_name=key,
+                column_type=column_type,
+                column_is_list=db_annotation.is_list,
+                nullable=is_nullable,
+                autoincrement=info.autoincrement,
+            ),
+            dependencies=column_dependencies,
+        )
+
+    def handle_column_type(self, key: str, info: DBFieldInfo, table: Type[TableBase]):
+        if info.annotation is None:
+            raise ValueError(f"Annotation must be provided for {table.__name__}.{key}")
+
+        # If explicit_type is provided, use it directly as the preferred type
+        if info.explicit_type is not None:
+            return TypeDeclarationResponse(
+                primitive_type=info.explicit_type,
+            )
+
+        annotation = remove_null_type(info.annotation)
+
+        # Resolve the type of the column, if generic
+        if isinstance(annotation, TypeVar):
+            typevar_map = get_typevar_mapping(table)
+            annotation = typevar_map[annotation]
+
+        # Should be prioritized in terms of MRO; StrEnums should be processed
+        # before the str types
+        if is_type_compatible(annotation, ALL_ENUM_TYPES):
+            # We only support string values for enums because postgres enums are defined
+            # as name-based types
+            for value in annotation:  # type: ignore
+                if not isinstance(value.value, str):
+                    raise ValueError(
+                        f"Only string values are supported for enums, received: {value.value} (enum: {annotation})"
+                    )
+
+            return TypeDeclarationResponse(
+                custom_type=DBType(
+                    name=enum_to_name(annotation),  # type: ignore
+                    values=frozenset([value.value for value in annotation]),  # type: ignore
+                    reference_columns=frozenset({(table.get_table_name(), key)}),
+                ),
+            )
+        elif is_type_compatible(annotation, PRIMITIVE_WRAPPER_TYPES):
+            for primitive, json_type in self.python_to_sql.items():
+                if annotation == primitive or annotation == list[primitive]:  # type: ignore
+                    return TypeDeclarationResponse(
+                        primitive_type=json_type,
+                        is_list=(annotation == list[primitive]),  # type: ignore
+                    )
+        elif is_type_compatible(annotation, DATE_TYPES):
+            if is_type_compatible(annotation, datetime):  # type: ignore
+                if isinstance(info.postgres_config, PostgresDateTime):
+                    return TypeDeclarationResponse(
+                        primitive_type=(
+                            ColumnType.TIMESTAMP_WITH_TIME_ZONE
+                            if info.postgres_config.timezone
+                            else ColumnType.TIMESTAMP_WITHOUT_TIME_ZONE
+                        )
+                    )
+                # Assume no timezone if not specified
+                return TypeDeclarationResponse(
+                    primitive_type=ColumnType.TIMESTAMP_WITHOUT_TIME_ZONE,
+                )
+            elif is_type_compatible(annotation, date):  # type: ignore
+                return TypeDeclarationResponse(
+                    primitive_type=ColumnType.DATE,
+                )
+            elif is_type_compatible(annotation, time):  # type: ignore
+                if isinstance(info.postgres_config, PostgresTime):
+                    return TypeDeclarationResponse(
+                        primitive_type=(
+                            ColumnType.TIME_WITH_TIME_ZONE
+                            if info.postgres_config.timezone
+                            else ColumnType.TIME_WITHOUT_TIME_ZONE
+                        ),
+                    )
+                return TypeDeclarationResponse(
+                    primitive_type=ColumnType.TIME_WITHOUT_TIME_ZONE,
+                )
+            elif is_type_compatible(annotation, timedelta):  # type: ignore
+                return TypeDeclarationResponse(
+                    primitive_type=ColumnType.INTERVAL,
+                )
+            else:
+                raise ValueError(f"Unsupported date type: {annotation}")
+        elif is_type_compatible(annotation, JSON_WRAPPER_FALLBACK):
+            if info.is_json:
+                return TypeDeclarationResponse(
+                    primitive_type=ColumnType.JSON,
+                )
+            else:
+                raise ValueError(
+                    f"JSON fields must have Field(is_json=True) specified: {annotation}\n"
+                    f"Column: {table.__name__}.{key}"
+                )
+
+        raise ValueError(f"Unsupported column type: {annotation}")
+
+    def handle_single_constraints(
+        self, key: str, info: DBFieldInfo, table: Type[TableBase]
+    ):
+        def _build_constraint(
+            constraint_type: ConstraintType,
+            *,
+            foreign_key_constraint: ForeignKeyConstraint | None = None,
+            check_constraint: CheckConstraint | None = None,
+        ):
+            return DBConstraint(
+                table_name=table.get_table_name(),
+                constraint_type=constraint_type,
+                columns=frozenset([key]),
+                constraint_name=DBConstraint.new_constraint_name(
+                    table.get_table_name(),
+                    [key],
+                    constraint_type,
+                ),
+                foreign_key_constraint=foreign_key_constraint,
+                check_constraint=check_constraint,
+            )
+
+        if info.unique:
+            yield from self._yield_nodes(_build_constraint(ConstraintType.UNIQUE))
+
+        if info.foreign_key:
+            target_table, target_column = info.foreign_key.rsplit(".", 1)
+            # Extract PostgreSQL-specific foreign key options if configured
+            on_delete = "NO ACTION"
+            on_update = "NO ACTION"
+            if isinstance(info.postgres_config, PostgresForeignKey):
+                on_delete = info.postgres_config.on_delete
+                on_update = info.postgres_config.on_update
+
+            yield from self._yield_nodes(
+                _build_constraint(
+                    ConstraintType.FOREIGN_KEY,
+                    foreign_key_constraint=ForeignKeyConstraint(
+                        target_table=target_table,
+                        target_columns=frozenset({target_column}),
+                        on_delete=on_delete,
+                        on_update=on_update,
+                    ),
+                ),
+                dependencies=[
+                    # Additional dependencies to ensure the target table/column is created first
+                    DBTable(table_name=target_table),
+                    DBColumnPointer(
+                        table_name=target_table,
+                        column_name=target_column,
+                    ),
+                    # Ensure the primary key constraint exists before the foreign key
+                    # constraint. Postgres also accepts a unique constraint on the same.
+                    DBPointerOr(
+                        pointers=tuple(
+                            [
+                                DBConstraintPointer(
+                                    table_name=target_table,
+                                    columns=frozenset([target_column]),
+                                    constraint_type=constraint_type,
+                                )
+                                for constraint_type in [
+                                    ConstraintType.PRIMARY_KEY,
+                                    ConstraintType.UNIQUE,
+                                ]
+                            ]
+                        ),
+                    ),
+                ],
+            )
+
+        if info.index:
+            yield from self._yield_nodes(_build_constraint(ConstraintType.INDEX))
+
+        if info.check_expression:
+            yield from self._yield_nodes(
+                _build_constraint(
+                    ConstraintType.CHECK,
+                    check_constraint=CheckConstraint(
+                        check_condition=info.check_expression,
+                    ),
+                )
+            )
+
+    def handle_multiple_constraints(
+        self, constraint: UniqueConstraint | IndexConstraint, table: Type[TableBase]
+    ):
+        columns: list[str]
+        constraint_type: ConstraintType
+
+        if isinstance(constraint, UniqueConstraint):
+            constraint_type = ConstraintType.UNIQUE
+            columns = constraint.columns
+        elif isinstance(constraint, IndexConstraint):
+            constraint_type = ConstraintType.INDEX
+            columns = constraint.columns
+        else:
+            raise ValueError(f"Unsupported constraint type: {constraint}")
+
+        yield from self._yield_nodes(
+            DBConstraint(
+                table_name=table.get_table_name(),
+                constraint_type=constraint_type,
+                columns=frozenset(columns),
+                constraint_name=DBConstraint.new_constraint_name(
+                    table.get_table_name(),
+                    columns,
+                    constraint_type,
+                ),
+            )
+        )
+
+    def handle_primary_keys(
+        self, keys: list[tuple[str, DBFieldInfo]], table: Type[TableBase]
+    ):
+        if not keys:
+            return
+
+        # Warn users about potential issues with multiple primary keys
+        if len(keys) > 1:
+            column_names = [key for key, _ in keys]
+            warnings.warn(
+                f"Table '{table.get_table_name()}' has multiple fields marked as primary_key=True: {column_names}. "
+                f"This creates a composite primary key constraint, which may cause issues with foreign key "
+                f"constraints that expect individual primary keys on target columns. "
+                f"Consider using only one primary key field and adding UniqueConstraint for uniqueness instead.",
+                UserWarning,
+                stacklevel=3,
+            )
+
+        columns = [key for key, _ in keys]
+        yield from self._yield_nodes(
+            DBConstraint(
+                table_name=table.get_table_name(),
+                constraint_type=ConstraintType.PRIMARY_KEY,
+                columns=frozenset(columns),
+                constraint_name=DBConstraint.new_constraint_name(
+                    table.get_table_name(),
+                    columns,
+                    ConstraintType.PRIMARY_KEY,
+                ),
+            )
+        )
+
+    def _yield_nodes(
+        self,
+        child: NodeYieldType | Generator[NodeYieldType, None, None],
+        dependencies: Sequence[NodeYieldType] | None = None,
+        force_no_dependencies: bool = False,
+    ) -> list[NodeDefinition]:
+        """
+        Given potentially nested nodes, merge them into a flat list of nodes
+        with dependencies.
+
+        :param force_no_dependencies: If specified, we will never merge this node
+            with any upstream dependencies.
+        """
+
+        def _format_dependencies(dependencies: Sequence[NodeYieldType]):
+            all_dependencies: list[DBObject | DBObjectPointer] = []
+
+            for value in dependencies:
+                if isinstance(value, (DBObject, DBObjectPointer)):
+                    all_dependencies.append(value)
+                elif isinstance(value, NodeDefinition):
+                    all_dependencies.append(value.node)
+                    all_dependencies += value.dependencies
+                else:
+                    raise ValueError(f"Unsupported dependency type: {value}")
+
+            # Sorting isn't required for the DAG but is useful for testing determinism
+            return sorted(
+                set(all_dependencies),
+                key=lambda x: x.representation(),
+            )
+
+        results: list[NodeDefinition] = []
+
+        if isinstance(child, DBObject):
+            # No dependencies list is provided, let's yield a new one
+            results.append(
+                NodeDefinition(
+                    node=child,
+                    dependencies=_format_dependencies(dependencies or []),
+                    force_no_dependencies=force_no_dependencies,
+                )
+            )
+        elif isinstance(child, NodeDefinition):
+            all_dependencies: list[NodeYieldType] = []
+            if not child.force_no_dependencies:
+                all_dependencies += dependencies or []
+            all_dependencies += child.dependencies
+
+            results.append(
+                NodeDefinition(
+                    node=child.node,
+                    dependencies=_format_dependencies(all_dependencies),
+                    force_no_dependencies=force_no_dependencies,
+                )
+            )
+        elif isgenerator(child):
+            for node in child:
+                results += self._yield_nodes(node, dependencies)
+        else:
+            raise ValueError(f"Unsupported node type: {child}")
+
+        return results
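Both composite-primary-key paths above (the UserWarning emitted by handle_primary_keys and the CompositePrimaryKeyConstraintError raised during pointer resolution in order_db_objects) point at the same remedy: keep a single field marked primary_key=True and express multi-column uniqueness with a UniqueConstraint. The sketch below illustrates that pattern; the Field and table_args declaration syntax is assumed from iceaxe's public modelling API rather than taken from this diff, so treat it as illustrative only:

    from iceaxe import Field, TableBase, UniqueConstraint

    class Membership(TableBase):
        # One real primary key; uniqueness across (user_id, team_id) is expressed
        # as a UniqueConstraint, so foreign keys pointing at this table still find
        # an individual primary key constraint on "id".
        table_args = [UniqueConstraint(columns=["user_id", "team_id"])]

        id: int = Field(primary_key=True)
        user_id: int = Field(foreign_key="user.id")
        team_id: int = Field(foreign_key="team.id")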