iceaxe 0.7.1__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (75)
  1. iceaxe/__init__.py +20 -0
  2. iceaxe/__tests__/__init__.py +0 -0
  3. iceaxe/__tests__/benchmarks/__init__.py +0 -0
  4. iceaxe/__tests__/benchmarks/test_bulk_insert.py +45 -0
  5. iceaxe/__tests__/benchmarks/test_select.py +114 -0
  6. iceaxe/__tests__/conf_models.py +133 -0
  7. iceaxe/__tests__/conftest.py +204 -0
  8. iceaxe/__tests__/docker_helpers.py +208 -0
  9. iceaxe/__tests__/helpers.py +268 -0
  10. iceaxe/__tests__/migrations/__init__.py +0 -0
  11. iceaxe/__tests__/migrations/conftest.py +36 -0
  12. iceaxe/__tests__/migrations/test_action_sorter.py +237 -0
  13. iceaxe/__tests__/migrations/test_generator.py +140 -0
  14. iceaxe/__tests__/migrations/test_generics.py +91 -0
  15. iceaxe/__tests__/mountaineer/__init__.py +0 -0
  16. iceaxe/__tests__/mountaineer/dependencies/__init__.py +0 -0
  17. iceaxe/__tests__/mountaineer/dependencies/test_core.py +76 -0
  18. iceaxe/__tests__/schemas/__init__.py +0 -0
  19. iceaxe/__tests__/schemas/test_actions.py +1264 -0
  20. iceaxe/__tests__/schemas/test_cli.py +25 -0
  21. iceaxe/__tests__/schemas/test_db_memory_serializer.py +1525 -0
  22. iceaxe/__tests__/schemas/test_db_serializer.py +398 -0
  23. iceaxe/__tests__/schemas/test_db_stubs.py +190 -0
  24. iceaxe/__tests__/test_alias.py +83 -0
  25. iceaxe/__tests__/test_base.py +52 -0
  26. iceaxe/__tests__/test_comparison.py +383 -0
  27. iceaxe/__tests__/test_field.py +11 -0
  28. iceaxe/__tests__/test_helpers.py +9 -0
  29. iceaxe/__tests__/test_modifications.py +151 -0
  30. iceaxe/__tests__/test_queries.py +605 -0
  31. iceaxe/__tests__/test_queries_str.py +173 -0
  32. iceaxe/__tests__/test_session.py +1511 -0
  33. iceaxe/__tests__/test_text_search.py +287 -0
  34. iceaxe/alias_values.py +67 -0
  35. iceaxe/base.py +350 -0
  36. iceaxe/comparison.py +560 -0
  37. iceaxe/field.py +250 -0
  38. iceaxe/functions.py +906 -0
  39. iceaxe/generics.py +140 -0
  40. iceaxe/io.py +107 -0
  41. iceaxe/logging.py +91 -0
  42. iceaxe/migrations/__init__.py +5 -0
  43. iceaxe/migrations/action_sorter.py +98 -0
  44. iceaxe/migrations/cli.py +228 -0
  45. iceaxe/migrations/client_io.py +62 -0
  46. iceaxe/migrations/generator.py +404 -0
  47. iceaxe/migrations/migration.py +86 -0
  48. iceaxe/migrations/migrator.py +101 -0
  49. iceaxe/modifications.py +176 -0
  50. iceaxe/mountaineer/__init__.py +10 -0
  51. iceaxe/mountaineer/cli.py +74 -0
  52. iceaxe/mountaineer/config.py +46 -0
  53. iceaxe/mountaineer/dependencies/__init__.py +6 -0
  54. iceaxe/mountaineer/dependencies/core.py +67 -0
  55. iceaxe/postgres.py +133 -0
  56. iceaxe/py.typed +0 -0
  57. iceaxe/queries.py +1455 -0
  58. iceaxe/queries_str.py +294 -0
  59. iceaxe/schemas/__init__.py +0 -0
  60. iceaxe/schemas/actions.py +864 -0
  61. iceaxe/schemas/cli.py +30 -0
  62. iceaxe/schemas/db_memory_serializer.py +705 -0
  63. iceaxe/schemas/db_serializer.py +346 -0
  64. iceaxe/schemas/db_stubs.py +525 -0
  65. iceaxe/session.py +860 -0
  66. iceaxe/session_optimized.c +12035 -0
  67. iceaxe/session_optimized.cpython-313-darwin.so +0 -0
  68. iceaxe/session_optimized.pyx +212 -0
  69. iceaxe/sql_types.py +148 -0
  70. iceaxe/typing.py +73 -0
  71. iceaxe-0.7.1.dist-info/METADATA +261 -0
  72. iceaxe-0.7.1.dist-info/RECORD +75 -0
  73. iceaxe-0.7.1.dist-info/WHEEL +6 -0
  74. iceaxe-0.7.1.dist-info/licenses/LICENSE +21 -0
  75. iceaxe-0.7.1.dist-info/top_level.txt +1 -0
iceaxe/schemas/actions.py
@@ -0,0 +1,864 @@
+ from dataclasses import dataclass
+ from inspect import Parameter, signature
+ from re import fullmatch as re_fullmatch
+ from typing import Any, Callable, Literal, overload
+
+ from pydantic import BaseModel
+
+ from iceaxe.logging import LOGGER
+ from iceaxe.postgres import ForeignKeyModifications
+ from iceaxe.queries_str import QueryIdentifier
+ from iceaxe.session import DBConnection
+ from iceaxe.sql_types import ColumnType, ConstraintType
+
+
+ class ForeignKeyConstraint(BaseModel):
+     target_table: str
+     target_columns: frozenset[str]
+     on_delete: ForeignKeyModifications = "NO ACTION"
+     on_update: ForeignKeyModifications = "NO ACTION"
+
+     model_config = {
+         "frozen": True,
+     }
+
+
+ class CheckConstraint(BaseModel):
+     check_condition: str
+
+     model_config = {
+         "frozen": True,
+     }
+
+
+ class ExcludeConstraint(BaseModel):
+     exclude_operator: str
+
+     model_config = {
+         "frozen": True,
+     }
+
+
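The constraint models above set `model_config = {"frozen": True}` and use `frozenset` for column collections, so instances are immutable and hashable and can be compared by value during migration diffing. A minimal sketch of that property, assuming pydantic v2 semantics (frozen models gain a `__hash__`):

```python
from iceaxe.schemas.actions import ForeignKeyConstraint

# Two structurally identical constraints compare and hash equal, so they
# collapse to one entry inside a set.
a = ForeignKeyConstraint(target_table="users", target_columns=frozenset({"id"}))
b = ForeignKeyConstraint(target_table="users", target_columns=frozenset({"id"}))

assert a == b
assert len({a, b}) == 1  # hashable because the model is frozen
```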
+ @dataclass
+ class DryRunAction:
+     fn: Callable
+     kwargs: dict[str, Any]
+
+
+ @dataclass
+ class DryRunComment:
+     text: str
+     previous_line: bool = False
+
+
+ def assert_is_safe_sql_identifier(identifier: str):
+     """
+     Check if the provided identifier is a safe SQL identifier. Since our code
+     pulls these directly from the definitions, there shouldn't
+     be any issues with SQL injection, but it's good to be safe.
+
+     """
+     is_valid = re_fullmatch(r"^[A-Za-z_][A-Za-z0-9_]*$", identifier) is not None
+     if not is_valid:
+         raise ValueError(f"{identifier} is not a valid SQL identifier.")
+
+
+ def format_sql_values(values: list[str]):
+     """
+     Safely formats string values for SQL insertion by escaping single quotes.
+
+     """
+     escaped_values = [
+         value.replace("'", "''") for value in values
+     ]  # Escaping single quotes in SQL
+     formatted_values = ", ".join(f"'{value}'" for value in escaped_values)
+     return formatted_values
+
+
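Both helpers are easy to exercise directly. A short sketch of their behavior as defined above:

```python
from iceaxe.schemas.actions import (
    assert_is_safe_sql_identifier,
    format_sql_values,
)

# Identifiers must match ^[A-Za-z_][A-Za-z0-9_]*$ in full.
assert_is_safe_sql_identifier("user_accounts")  # passes silently

try:
    assert_is_safe_sql_identifier("users; DROP TABLE users")
except ValueError as exc:
    print(exc)  # users; DROP TABLE users is not a valid SQL identifier.

# Single quotes are doubled, per SQL string-literal escaping rules.
print(format_sql_values(["active", "o'brien"]))  # 'active', 'o''brien'
```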
+ class DatabaseActions:
+     """
+     Track the actions that need to be executed against the database. Provides
+     a shallow, typed ORM on top of the raw SQL commands that we'll execute
+     through asyncpg.
+
+     This class manually builds up the SQL strings that will be executed against
+     postgres. We intentionally avoid using the ORM or variable-insertion modes
+     here because most table-schema operations don't permit parameters to
+     specify top-level SQL syntax. To keep things consistent, we'll use the
+     same SQL string interpolation for all operations.
+
+     """
+
+     dry_run: bool
+     """
+     If True, the actions will be recorded but not executed. This is used
+     internally within Iceaxe to generate a typehinted list of actions that will
+     be inserted into the migration files without actually running the logic.
+
+     """
+
+     dry_run_actions: list[DryRunAction | DryRunComment]
+     """
+     A list of actions that will be executed. Each arg/kwarg passed to our action
+     functions during the dry run will be recorded here.
+
+     """
+
+     prod_sqls: list[str]
+     """
+     A list of SQL strings that will be executed against the database. This is
+     only populated when dry_run is False.
+
+     """
+
+     def __init__(
+         self,
+         dry_run: bool = True,
+         db_connection: DBConnection | None = None,
+     ):
+         self.dry_run = dry_run
+
+         if not dry_run:
+             if db_connection is None:
+                 raise ValueError(
+                     "Must provide a db_connection when not in dry run mode."
+                 )
+
+         self.dry_run_actions: list[DryRunAction | DryRunComment] = []
+         self.db_connection = db_connection
+         self.prod_sqls: list[str] = []
+
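In dry-run mode (the default), no connection is needed and every action method records a `DryRunAction` instead of executing SQL. A sketch using `add_table`, defined below:

```python
import asyncio

from iceaxe.schemas.actions import DatabaseActions, DryRunAction


async def main():
    actor = DatabaseActions()  # dry_run=True by default, no connection needed

    await actor.add_table("users")

    # The call is captured, not executed: the method plus the kwargs it was given.
    action = actor.dry_run_actions[0]
    assert isinstance(action, DryRunAction)
    assert action.kwargs == {"table_name": "users"}


asyncio.run(main())
```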
+     async def add_table(self, table_name: str):
+         """
+         Create a new table in the database.
+
+         """
+         assert_is_safe_sql_identifier(table_name)
+         table = QueryIdentifier(table_name)
+
+         await self._record_signature(
+             self.add_table,
+             dict(table_name=table_name),
+             f"""
+             CREATE TABLE {table} ();
+             """,
+         )
+
+     async def drop_table(self, table_name: str):
+         """
+         Delete a table and all its contents from the database. This is
+         a destructive action; all data in the table will be lost.
+
+         """
+         assert_is_safe_sql_identifier(table_name)
+         table = QueryIdentifier(table_name)
+
+         await self._record_signature(
+             self.drop_table,
+             dict(table_name=table_name),
+             f"""
+             DROP TABLE {table}
+             """,
+         )
+
+     async def add_column(
+         self,
+         table_name: str,
+         column_name: str,
+         explicit_data_type: ColumnType | None = None,
+         explicit_data_is_list: bool = False,
+         custom_data_type: str | None = None,
+     ):
+         """
+         Add a new column to a table.
+
+         :param table_name: The name of the table to add the column to.
+         :param column_name: The name of the column to add.
+         :param explicit_data_type: The explicit data type of the column.
+         :param explicit_data_is_list: Whether the explicit data type is a list.
+         :param custom_data_type: A custom data type for the column, like an enum
+             that's registered in Postgres.
+
+         """
+
+         if not explicit_data_type and not custom_data_type:
+             raise ValueError(
+                 "Must provide either an explicit data type or a custom data type."
+             )
+         if explicit_data_type and custom_data_type:
+             raise ValueError(
+                 "Cannot provide both an explicit data type and a custom data type."
+             )
+
+         assert_is_safe_sql_identifier(table_name)
+         assert_is_safe_sql_identifier(column_name)
+
+         # We only need to check the custom data type, since we know
+         # the explicit data types come from the enum and are safe.
+         if custom_data_type:
+             assert_is_safe_sql_identifier(custom_data_type)
+
+         table = QueryIdentifier(table_name)
+         column = QueryIdentifier(column_name)
+
+         column_type = self._get_column_type(
+             explicit_data_type=explicit_data_type,
+             explicit_data_is_list=explicit_data_is_list,
+             custom_data_type=custom_data_type,
+         )
+
+         await self._record_signature(
+             self.add_column,
+             dict(
+                 table_name=table_name,
+                 column_name=column_name,
+                 explicit_data_type=explicit_data_type,
+                 explicit_data_is_list=explicit_data_is_list,
+                 custom_data_type=custom_data_type,
+             ),
+             f"""
+             ALTER TABLE {table}
+             ADD COLUMN {column} {column_type}
+             """,
+         )
+
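As a sketch, a dry-run `add_column` call for an array column looks like the following; in execution mode the same call would emit an `ALTER TABLE ... ADD COLUMN` statement whose type gains a `[]` suffix via `_get_column_type`. This assumes a `ColumnType.VARCHAR` member whose string form is the SQL type name:

```python
from iceaxe.schemas.actions import DatabaseActions
from iceaxe.sql_types import ColumnType


async def add_nicknames(actor: DatabaseActions):
    # Roughly: ALTER TABLE "users" ADD COLUMN "nicknames" VARCHAR[]
    await actor.add_column(
        "users",
        "nicknames",
        explicit_data_type=ColumnType.VARCHAR,
        explicit_data_is_list=True,
    )
```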
+     async def drop_column(self, table_name: str, column_name: str):
+         """
+         Remove a column. This is a destructive action; all data in the column
+         will be lost.
+
+         """
+         assert_is_safe_sql_identifier(table_name)
+         assert_is_safe_sql_identifier(column_name)
+
+         table = QueryIdentifier(table_name)
+         column = QueryIdentifier(column_name)
+
+         await self._record_signature(
+             self.drop_column,
+             dict(table_name=table_name, column_name=column_name),
+             f"""
+             ALTER TABLE {table}
+             DROP COLUMN {column}
+             """,
+         )
+
+     async def rename_column(
+         self, table_name: str, old_column_name: str, new_column_name: str
+     ):
+         """
+         Rename a column in a table.
+
+         """
+         assert_is_safe_sql_identifier(table_name)
+         assert_is_safe_sql_identifier(old_column_name)
+         assert_is_safe_sql_identifier(new_column_name)
+
+         table = QueryIdentifier(table_name)
+         old_column = QueryIdentifier(old_column_name)
+         new_column = QueryIdentifier(new_column_name)
+
+         await self._record_signature(
+             self.rename_column,
+             dict(
+                 table_name=table_name,
+                 old_column_name=old_column_name,
+                 new_column_name=new_column_name,
+             ),
+             f"""
+             ALTER TABLE {table}
+             RENAME COLUMN {old_column} TO {new_column}
+             """,
+         )
+
+     async def modify_column_type(
+         self,
+         table_name: str,
+         column_name: str,
+         explicit_data_type: ColumnType | None = None,
+         explicit_data_is_list: bool = False,
+         custom_data_type: str | None = None,
+         autocast: bool = False,
+     ):
+         """
+         Modify the data type of a column. This does not inherently perform any
+         migration of the column's data; it simply alters the table schema.
+
+         :param table_name: The name of the table containing the column
+         :param column_name: The name of the column to modify
+         :param explicit_data_type: The new data type for the column
+         :param explicit_data_is_list: Whether the column should be an array type
+         :param custom_data_type: A custom SQL type string (mutually exclusive with explicit_data_type)
+         :param autocast: If True, automatically add a USING clause to cast existing data to the new type.
+             Auto-generated migrations set this to True by default. Supports most common
+             PostgreSQL type conversions, including:
+             - String to numeric (VARCHAR/TEXT → INTEGER/BIGINT/SMALLINT/REAL)
+             - String to boolean (VARCHAR/TEXT → BOOLEAN)
+             - String to date/time (VARCHAR/TEXT → DATE/TIMESTAMP/TIME)
+             - String to specialized types (VARCHAR/TEXT → UUID/JSON/JSONB)
+             - Scalar to array types (INTEGER → INTEGER[])
+             - Custom enum conversions (VARCHAR/TEXT → custom enum)
+             - Compatible numeric conversions (INTEGER → BIGINT)
+
+             When autocast=False, PostgreSQL will only allow the type change if it's
+             compatible without explicit casting, which may fail for many conversions.
+
+         Example:
+             # Auto-generated migration (autocast=True by default)
+             await actor.modify_column_type(
+                 "products", "price", ColumnType.INTEGER, autocast=True
+             )
+
+             # Manual migration with custom control
+             await actor.modify_column_type(
+                 "products", "price", ColumnType.INTEGER, autocast=False
+             )
+             # Then handle data conversion manually if needed
+
+         """
+         if not explicit_data_type and not custom_data_type:
+             raise ValueError(
+                 "Must provide either an explicit data type or a custom data type."
+             )
+         if explicit_data_type and custom_data_type:
+             raise ValueError(
+                 "Cannot provide both an explicit data type and a custom data type."
+             )
+
+         assert_is_safe_sql_identifier(table_name)
+         assert_is_safe_sql_identifier(column_name)
+
+         # We only need to check the custom data type, since we know
+         # the explicit data types come from the enum and are safe.
+         if custom_data_type:
+             assert_is_safe_sql_identifier(custom_data_type)
+
+         table = QueryIdentifier(table_name)
+         column = QueryIdentifier(column_name)
+
+         column_type = self._get_column_type(
+             explicit_data_type=explicit_data_type,
+             explicit_data_is_list=explicit_data_is_list,
+             custom_data_type=custom_data_type,
+         )
+
+         # Build the SQL with an optional USING clause for autocast
+         sql = f"ALTER TABLE {table}\nALTER COLUMN {column} TYPE {column_type}"
+
+         if autocast:
+             # Add a USING clause to cast the column to the new type
+             cast_expression = self._get_autocast_expression(
+                 column_name=str(column),
+                 target_type=column_type,
+                 explicit_data_type=explicit_data_type,
+                 explicit_data_is_list=explicit_data_is_list,
+                 custom_data_type=custom_data_type,
+             )
+             sql += f"\nUSING {cast_expression}"
+
+         await self._record_signature(
+             self.modify_column_type,
+             dict(
+                 table_name=table_name,
+                 column_name=column_name,
+                 explicit_data_type=explicit_data_type,
+                 explicit_data_is_list=explicit_data_is_list,
+                 custom_data_type=custom_data_type,
+                 autocast=autocast,
+             ),
+             sql,
+         )
+
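A sketch of the autocast path: converting a text-ish `price` column to an integer produces an `ALTER ... TYPE` statement with a `USING` cast appended (assuming `ColumnType.INTEGER` stringifies to `INTEGER`):

```python
from iceaxe.schemas.actions import DatabaseActions
from iceaxe.sql_types import ColumnType


async def tighten_price(actor: DatabaseActions):
    # Roughly:
    #   ALTER TABLE "products"
    #   ALTER COLUMN "price" TYPE INTEGER
    #   USING "price"::INTEGER
    await actor.modify_column_type(
        "products",
        "price",
        explicit_data_type=ColumnType.INTEGER,
        autocast=True,
    )
```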
+     def _get_autocast_expression(
+         self,
+         column_name: str,
+         target_type: str,
+         explicit_data_type: ColumnType | None = None,
+         explicit_data_is_list: bool = False,
+         custom_data_type: str | None = None,
+     ) -> str:
+         """
+         Generate an appropriate USING expression for casting a column to a new type.
+         This handles common type conversions that PostgreSQL can perform.
+         """
+         # For array types, we need to handle them specially
+         if explicit_data_is_list:
+             # For converting scalar to array, we need to wrap the value in an array
+             base_type = (
+                 explicit_data_type.value if explicit_data_type else custom_data_type
+             )
+             return f"ARRAY[{column_name}::{base_type}]"
+
+         # For custom types (like enums), use text as intermediate
+         if custom_data_type:
+             return f"{column_name}::text::{custom_data_type}"
+
+         # For explicit data types, handle special cases
+         if explicit_data_type:
+             # Handle common conversions that might need special treatment
+             if explicit_data_type in [
+                 ColumnType.INTEGER,
+                 ColumnType.BIGINT,
+                 ColumnType.SMALLINT,
+             ]:
+                 # For numeric types, try direct cast first, but this will fail if source is non-numeric string
+                 return f"{column_name}::{explicit_data_type.value}"
+             elif explicit_data_type == ColumnType.BOOLEAN:
+                 # Boolean conversion can be tricky, use a more flexible approach
+                 return f"{column_name}::boolean"
+             elif explicit_data_type in [
+                 ColumnType.DATE,
+                 ColumnType.TIMESTAMP_WITHOUT_TIME_ZONE,
+                 ColumnType.TIME_WITHOUT_TIME_ZONE,
+             ]:
+                 # Date/time conversions
+                 return f"{column_name}::{explicit_data_type.value}"
+             elif explicit_data_type in [ColumnType.JSON, ColumnType.JSONB]:
+                 # JSON conversions - usually from text
+                 return f"{column_name}::{explicit_data_type.value}"
+             else:
+                 # For most other types, a direct cast should work
+                 return f"{column_name}::{explicit_data_type.value}"
+
+         # Fallback to direct cast
+         return f"{column_name}::{target_type}"
+
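The branches above reduce to three shapes of `USING` expression. The calls below are exposition only (the method is private) and assume the enum members' `.value` attributes are the plain SQL type names:

```python
from iceaxe.schemas.actions import DatabaseActions
from iceaxe.sql_types import ColumnType

actor = DatabaseActions(dry_run=True)

# Scalar cast: '"price"::INTEGER'
actor._get_autocast_expression('"price"', "INTEGER", ColumnType.INTEGER)

# Scalar -> array wraps the cast in an ARRAY constructor: 'ARRAY["tag"::VARCHAR]'
actor._get_autocast_expression('"tag"', "VARCHAR[]", ColumnType.VARCHAR, True)

# Custom enums route through text as an intermediate: '"status"::text::order_status'
actor._get_autocast_expression('"status"', "order_status", custom_data_type="order_status")
```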
+     @overload
+     async def add_constraint(
+         self,
+         table_name: str,
+         columns: list[str],
+         constraint: Literal[ConstraintType.FOREIGN_KEY],
+         constraint_name: str,
+         constraint_args: ForeignKeyConstraint,
+     ): ...
+
+     @overload
+     async def add_constraint(
+         self,
+         table_name: str,
+         columns: list[str],
+         constraint: Literal[ConstraintType.PRIMARY_KEY]
+         | Literal[ConstraintType.UNIQUE],
+         constraint_name: str,
+         constraint_args: None = None,
+     ): ...
+
+     @overload
+     async def add_constraint(
+         self,
+         table_name: str,
+         columns: list[str],
+         constraint: Literal[ConstraintType.CHECK],
+         constraint_name: str,
+         constraint_args: CheckConstraint,
+     ): ...
+
+     async def add_constraint(
+         self,
+         table_name: str,
+         columns: list[str],
+         constraint: ConstraintType,
+         constraint_name: str,
+         constraint_args: BaseModel | None = None,
+     ):
+         """
+         Adds a constraint to a table. This is the main entrypoint used
+         for all constraint types.
+
+         :param table_name: The name of the table to add the constraint to.
+         :param columns: The columns to link as part of the constraint.
+         :param constraint: The type of constraint to add.
+         :param constraint_name: The name of the constraint.
+         :param constraint_args: The configuration parameters for the particular constraint
+             type, if relevant.
+
+         """
+         assert_is_safe_sql_identifier(table_name)
+         for column_name in columns:
+             assert_is_safe_sql_identifier(column_name)
+
+         table = QueryIdentifier(table_name)
+         columns_formatted = ", ".join(str(QueryIdentifier(col)) for col in columns)
+         sql = f"ALTER TABLE {table} ADD CONSTRAINT {constraint_name} "
+
+         if constraint == ConstraintType.PRIMARY_KEY:
+             sql += f"PRIMARY KEY ({columns_formatted})"
+         elif constraint == ConstraintType.FOREIGN_KEY:
+             if not isinstance(constraint_args, ForeignKeyConstraint):
+                 raise ValueError(
+                     f"Constraint type FOREIGN_KEY must have ForeignKeyConstraint args, received: {constraint_args}"
+                 )
+
+             assert_is_safe_sql_identifier(constraint_args.target_table)
+             for column_name in constraint_args.target_columns:
+                 assert_is_safe_sql_identifier(column_name)
+
+             target_table = QueryIdentifier(constraint_args.target_table)
+             ref_cols_formatted = ", ".join(
+                 str(QueryIdentifier(col)) for col in constraint_args.target_columns
+             )
+             sql += f"FOREIGN KEY ({columns_formatted}) REFERENCES {target_table} ({ref_cols_formatted})"
+             if constraint_args.on_delete != "NO ACTION":
+                 sql += f" ON DELETE {constraint_args.on_delete}"
+             if constraint_args.on_update != "NO ACTION":
+                 sql += f" ON UPDATE {constraint_args.on_update}"
+         elif constraint == ConstraintType.UNIQUE:
+             sql += f"UNIQUE ({columns_formatted})"
+         elif constraint == ConstraintType.CHECK:
+             if not isinstance(constraint_args, CheckConstraint):
+                 raise ValueError(
+                     f"Constraint type CHECK must have CheckConstraint args, received: {constraint_args}"
+                 )
+             sql += f"CHECK ({constraint_args.check_condition})"
+         else:
+             raise ValueError("Unsupported constraint type")
+
+         sql += ";"
+         await self._record_signature(
+             self.add_constraint,
+             dict(
+                 table_name=table_name,
+                 columns=columns,
+                 constraint=constraint,
+                 constraint_name=constraint_name,
+                 constraint_args=constraint_args,
+             ),
+             sql,
+         )
+
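For example, a foreign key with a `CASCADE` delete rule (a hypothetical `orders.user_id -> users.id` relationship) would be declared like this; note that only non-default `ON DELETE` / `ON UPDATE` rules are appended to the statement:

```python
from iceaxe.schemas.actions import DatabaseActions, ForeignKeyConstraint
from iceaxe.sql_types import ConstraintType


async def link_orders(actor: DatabaseActions):
    # Roughly: ALTER TABLE "orders" ADD CONSTRAINT orders_user_id_fkey
    #          FOREIGN KEY ("user_id") REFERENCES "users" ("id") ON DELETE CASCADE;
    await actor.add_constraint(
        table_name="orders",
        columns=["user_id"],
        constraint=ConstraintType.FOREIGN_KEY,
        constraint_name="orders_user_id_fkey",
        constraint_args=ForeignKeyConstraint(
            target_table="users",
            target_columns=frozenset({"id"}),
            on_delete="CASCADE",
        ),
    )
```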
+     async def drop_constraint(
+         self,
+         table_name: str,
+         constraint_name: str,
+     ):
+         """
+         Deletes a constraint from a table.
+
+         """
+         assert_is_safe_sql_identifier(table_name)
+         assert_is_safe_sql_identifier(constraint_name)
+
+         table = QueryIdentifier(table_name)
+         constraint = QueryIdentifier(constraint_name)
+
+         await self._record_signature(
+             self.drop_constraint,
+             dict(
+                 table_name=table_name,
+                 constraint_name=constraint_name,
+             ),
+             f"""
+             ALTER TABLE {table}
+             DROP CONSTRAINT {constraint}
+             """,
+         )
+
+     async def add_index(
+         self,
+         table_name: str,
+         columns: list[str],
+         index_name: str,
+     ):
+         """
+         Adds a new index to a table. Since this requires building up auxiliary
+         data structures for more efficient search operations, this migration
+         action can take some time on large tables.
+
+         """
+         assert_is_safe_sql_identifier(table_name)
+         for column_name in columns:
+             assert_is_safe_sql_identifier(column_name)
+
+         table = QueryIdentifier(table_name)
+         columns_formatted = ", ".join(str(QueryIdentifier(col)) for col in columns)
+         sql = f"CREATE INDEX {index_name} ON {table} ({columns_formatted});"
+         await self._record_signature(
+             self.add_index,
+             dict(
+                 table_name=table_name,
+                 columns=columns,
+                 index_name=index_name,
+             ),
+             sql,
+         )
+
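A sketch of the index pair (`drop_index` is defined just below); note that, unlike the table and column names, the index name is interpolated into the `CREATE INDEX` statement as given:

```python
from iceaxe.schemas.actions import DatabaseActions


async def index_emails(actor: DatabaseActions):
    # Roughly: CREATE INDEX idx_users_email ON "users" ("email");
    await actor.add_index("users", ["email"], index_name="idx_users_email")

    # Roughly: DROP INDEX "idx_users_email";
    await actor.drop_index("users", index_name="idx_users_email")
```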
+     async def drop_index(
+         self,
+         table_name: str,
+         index_name: str,
+     ):
+         """
+         Deletes an index from a table.
+
+         """
+         assert_is_safe_sql_identifier(table_name)
+         assert_is_safe_sql_identifier(index_name)
+
+         index = QueryIdentifier(index_name)
+
+         sql = f"DROP INDEX {index};"
+         await self._record_signature(
+             self.drop_index,
+             dict(
+                 table_name=table_name,
+                 index_name=index_name,
+             ),
+             sql,
+         )
+
+     async def add_not_null(self, table_name: str, column_name: str):
+         """
+         Requires data inserted into a column to be non-null.
+
+         """
+         assert_is_safe_sql_identifier(table_name)
+         assert_is_safe_sql_identifier(column_name)
+
+         table = QueryIdentifier(table_name)
+         column = QueryIdentifier(column_name)
+
+         await self._record_signature(
+             self.add_not_null,
+             dict(table_name=table_name, column_name=column_name),
+             f"""
+             ALTER TABLE {table}
+             ALTER COLUMN {column}
+             SET NOT NULL
+             """,
+         )
+
+     async def drop_not_null(self, table_name: str, column_name: str):
+         """
+         Removes the non-null constraint from a column, which allows new values
+         to be inserted as NULL.
+
+         """
+         assert_is_safe_sql_identifier(table_name)
+         assert_is_safe_sql_identifier(column_name)
+
+         table = QueryIdentifier(table_name)
+         column = QueryIdentifier(column_name)
+
+         await self._record_signature(
+             self.drop_not_null,
+             dict(table_name=table_name, column_name=column_name),
+             f"""
+             ALTER TABLE {table}
+             ALTER COLUMN {column}
+             DROP NOT NULL
+             """,
+         )
+
+     async def add_type(self, type_name: str, values: list[str]):
+         """
+         Create a new enum type with the given initial values.
+
+         """
+         assert_is_safe_sql_identifier(type_name)
+
+         type_identifier = QueryIdentifier(type_name)
+         formatted_values = format_sql_values(values)
+         await self._record_signature(
+             self.add_type,
+             dict(type_name=type_name, values=values),
+             f"""
+             CREATE TYPE {type_identifier} AS ENUM ({formatted_values})
+             """,
+         )
+
+     async def add_type_values(self, type_name: str, values: list[str]):
+         """
+         Modifies the enum members of an existing type to add new values.
+
+         """
+         assert_is_safe_sql_identifier(type_name)
+         type_identifier = QueryIdentifier(type_name)
+
+         sql_commands: list[str] = []
+         for value in values:
+             # Use the same escape functionality as we use for lists; since
+             # there's only one value, it won't add any commas
+             formatted_value = format_sql_values([value])
+             sql_commands.append(
+                 f"""
+                 ALTER TYPE {type_identifier} ADD VALUE {formatted_value};
+                 """
+             )
+
+         await self._record_signature(
+             self.add_type_values,
+             dict(type_name=type_name, values=values),
+             sql_commands,
+         )
+
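Enum creation takes all values in a single `CREATE TYPE`, while extension issues one `ALTER TYPE ... ADD VALUE` per new member, since Postgres accepts only a single value per statement. A sketch:

```python
from iceaxe.schemas.actions import DatabaseActions


async def evolve_status(actor: DatabaseActions):
    # Roughly: CREATE TYPE "order_status" AS ENUM ('pending', 'shipped')
    await actor.add_type("order_status", ["pending", "shipped"])

    # Roughly: ALTER TYPE "order_status" ADD VALUE 'delivered';
    await actor.add_type_values("order_status", ["delivered"])
```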
+     async def drop_type_values(
+         self,
+         type_name: str,
+         values: list[str],
+         target_columns: list[tuple[str, str]],
+     ):
+         """
+         Deletes enum members from an existing type.
+
+         This will only succeed at runtime if no table rows still
+         reference the outdated enum values.
+
+         Note that dropping values from an existing type isn't natively supported by Postgres. We work
+         around this limitation by specifying the "target_columns" that reference the
+         enum type that we want to drop, so we can effectively create a new type.
+
+         :param type_name: The name of the enum type to drop values from.
+         :param values: The values to drop from the enum type.
+         :param target_columns: Tuples of (table_name, column_name) pairs that
+             should be migrated to the new enum type.
+
+         """
+         assert_is_safe_sql_identifier(type_name)
+         for table_name, column_name in target_columns:
+             assert_is_safe_sql_identifier(table_name)
+             assert_is_safe_sql_identifier(column_name)
+
+         type_identifier = QueryIdentifier(type_name)
+         old_type_identifier = QueryIdentifier(f"{type_name}_old")
+         values_to_remove = format_sql_values(values)
+
+         column_modifications = ";\n".join(
+             [
+                 (
+                     # The "USING" clause is required for enum migration
+                     f"EXECUTE 'ALTER TABLE {QueryIdentifier(table_name)} ALTER COLUMN {QueryIdentifier(column_name)} TYPE {type_identifier}"
+                     f" USING {QueryIdentifier(column_name)}::text::{type_identifier}'"
+                 )
+                 for table_name, column_name in target_columns
+             ]
+         )
+         if column_modifications:
+             column_modifications += ";"
+
+         await self._record_signature(
+             self.drop_type_values,
+             dict(type_name=type_name, values=values, target_columns=target_columns),
+             f"""
+             DO $$
+             DECLARE
+                 vals text;
+             BEGIN
+                 -- Move the current enum to a temporary type
+                 EXECUTE 'ALTER TYPE {type_identifier} RENAME TO {old_type_identifier}';
+
+                 -- Retrieve all current enum values except those to be excluded
+                 SELECT string_agg('''' || unnest || '''', ', ' ORDER BY unnest) INTO vals
+                 FROM unnest(enum_range(NULL::{old_type_identifier})) AS unnest
+                 WHERE unnest NOT IN ({values_to_remove});
+
+                 -- Create and populate our new type with the desired changes
+                 EXECUTE format('CREATE TYPE {type_identifier} AS ENUM (%s)', vals);
+
+                 -- Switch over affected columns to the new type
+                 {column_modifications}
+
+                 -- Drop the old type
+                 EXECUTE 'DROP TYPE {old_type_identifier}';
+             END $$;
+             """,
+         )
+
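Because the rename-rebuild-repoint sequence happens inside a single `DO $$` block, the caller only has to enumerate the columns that store the enum. A sketch, assuming a hypothetical `orders.status` column of type `order_status`:

```python
from iceaxe.schemas.actions import DatabaseActions


async def retire_shipped(actor: DatabaseActions):
    # Fails at runtime if any row still holds 'shipped'; migrate data first.
    await actor.drop_type_values(
        "order_status",
        values=["shipped"],
        target_columns=[("orders", "status")],
    )
```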
+     async def drop_type(self, type_name: str):
+         """
+         Deletes an enum type from the database.
+
+         """
+         assert_is_safe_sql_identifier(type_name)
+         type_identifier = QueryIdentifier(type_name)
+
+         await self._record_signature(
+             self.drop_type,
+             dict(type_name=type_name),
+             f"""
+             DROP TYPE {type_identifier}
+             """,
+         )
+
+     def _get_column_type(
+         self,
+         explicit_data_type: ColumnType | None = None,
+         explicit_data_is_list: bool = False,
+         custom_data_type: str | None = None,
+     ) -> str:
+         if explicit_data_type:
+             return f"{explicit_data_type}{'[]' if explicit_data_is_list else ''}"
+         elif custom_data_type:
+             return custom_data_type
+         else:
+             raise ValueError(
+                 "Must provide either an explicit data type or a custom data type."
+             )
+
+     async def _record_signature(
+         self,
+         action: Callable,
+         kwargs: dict[str, Any],
+         sql: str | list[str],
+     ):
+         """
+         If we are doing a dry run through the migration, only record the method
+         signature that was provided. Otherwise, if we're actually executing the
+         migration, record the SQL that was generated.
+
+         """
+         # Validate that the kwargs can populate all of the action signature arguments
+         # that are not optional, and that we don't provide any kwargs that aren't specified
+         # in the action signature
+         # Get the signature of the action
+         sig = signature(action)
+         parameters = sig.parameters
+
+         # Check for required arguments not supplied
+         missing_args = [
+             name
+             for name, param in parameters.items()
+             if param.default is Parameter.empty and name not in kwargs
+         ]
+         if missing_args:
+             raise ValueError(f"Missing required arguments: {missing_args}")
+
+         # Check for extraneous arguments in kwargs
+         extraneous_args = [key for key in kwargs if key not in parameters]
+         if extraneous_args:
+             raise ValueError(f"Extraneous arguments provided: {extraneous_args}")
+
+         if self.dry_run:
+             self.dry_run_actions.append(
+                 DryRunAction(
+                     fn=action,
+                     kwargs=kwargs,
+                 )
+             )
+         else:
+             if self.db_connection is None:
+                 raise ValueError("Cannot execute migration without a database session")
+
+             sql_list = [sql] if isinstance(sql, str) else sql
+             for sql_query in sql_list:
+                 LOGGER.debug(f"Executing migration SQL: {sql_query}")
+
+                 self.prod_sqls.append(sql_query)
+
+                 try:
+                     await self.db_connection.conn.execute(sql_query)
+                 except Exception as e:
+                     # Default errors typically don't include context on the failing SQL
+                     LOGGER.error(f"Error executing migration SQL: {sql_query}")
+                     raise e
+
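The signature check means a dry run fails fast on misspelled or missing kwargs, before anything reaches the database. Illustrative only, since `_record_signature` is private:

```python
import asyncio

from iceaxe.schemas.actions import DatabaseActions


async def main():
    actor = DatabaseActions(dry_run=True)
    try:
        # add_table requires table_name; "table" is an extraneous misspelling.
        await actor._record_signature(actor.add_table, {"table": "users"}, "SQL")
    except ValueError as exc:
        print(exc)  # Missing required arguments: ['table_name']


asyncio.run(main())
```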
+     def add_comment(self, text: str, previous_line: bool = False):
+         """
+         Only used in dry-run mode to record a code-based comment that should
+         be added to the migration file.
+
+         """
+         if self.dry_run:
+             self.dry_run_actions.append(
+                 DryRunComment(text=text, previous_line=previous_line)
+             )