iceaxe 0.7.1__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of iceaxe has been flagged by the registry; consult the registry's advisory page for more details.
- iceaxe/__init__.py +20 -0
- iceaxe/__tests__/__init__.py +0 -0
- iceaxe/__tests__/benchmarks/__init__.py +0 -0
- iceaxe/__tests__/benchmarks/test_bulk_insert.py +45 -0
- iceaxe/__tests__/benchmarks/test_select.py +114 -0
- iceaxe/__tests__/conf_models.py +133 -0
- iceaxe/__tests__/conftest.py +204 -0
- iceaxe/__tests__/docker_helpers.py +208 -0
- iceaxe/__tests__/helpers.py +268 -0
- iceaxe/__tests__/migrations/__init__.py +0 -0
- iceaxe/__tests__/migrations/conftest.py +36 -0
- iceaxe/__tests__/migrations/test_action_sorter.py +237 -0
- iceaxe/__tests__/migrations/test_generator.py +140 -0
- iceaxe/__tests__/migrations/test_generics.py +91 -0
- iceaxe/__tests__/mountaineer/__init__.py +0 -0
- iceaxe/__tests__/mountaineer/dependencies/__init__.py +0 -0
- iceaxe/__tests__/mountaineer/dependencies/test_core.py +76 -0
- iceaxe/__tests__/schemas/__init__.py +0 -0
- iceaxe/__tests__/schemas/test_actions.py +1264 -0
- iceaxe/__tests__/schemas/test_cli.py +25 -0
- iceaxe/__tests__/schemas/test_db_memory_serializer.py +1525 -0
- iceaxe/__tests__/schemas/test_db_serializer.py +398 -0
- iceaxe/__tests__/schemas/test_db_stubs.py +190 -0
- iceaxe/__tests__/test_alias.py +83 -0
- iceaxe/__tests__/test_base.py +52 -0
- iceaxe/__tests__/test_comparison.py +383 -0
- iceaxe/__tests__/test_field.py +11 -0
- iceaxe/__tests__/test_helpers.py +9 -0
- iceaxe/__tests__/test_modifications.py +151 -0
- iceaxe/__tests__/test_queries.py +605 -0
- iceaxe/__tests__/test_queries_str.py +173 -0
- iceaxe/__tests__/test_session.py +1511 -0
- iceaxe/__tests__/test_text_search.py +287 -0
- iceaxe/alias_values.py +67 -0
- iceaxe/base.py +350 -0
- iceaxe/comparison.py +560 -0
- iceaxe/field.py +250 -0
- iceaxe/functions.py +906 -0
- iceaxe/generics.py +140 -0
- iceaxe/io.py +107 -0
- iceaxe/logging.py +91 -0
- iceaxe/migrations/__init__.py +5 -0
- iceaxe/migrations/action_sorter.py +98 -0
- iceaxe/migrations/cli.py +228 -0
- iceaxe/migrations/client_io.py +62 -0
- iceaxe/migrations/generator.py +404 -0
- iceaxe/migrations/migration.py +86 -0
- iceaxe/migrations/migrator.py +101 -0
- iceaxe/modifications.py +176 -0
- iceaxe/mountaineer/__init__.py +10 -0
- iceaxe/mountaineer/cli.py +74 -0
- iceaxe/mountaineer/config.py +46 -0
- iceaxe/mountaineer/dependencies/__init__.py +6 -0
- iceaxe/mountaineer/dependencies/core.py +67 -0
- iceaxe/postgres.py +133 -0
- iceaxe/py.typed +0 -0
- iceaxe/queries.py +1455 -0
- iceaxe/queries_str.py +294 -0
- iceaxe/schemas/__init__.py +0 -0
- iceaxe/schemas/actions.py +864 -0
- iceaxe/schemas/cli.py +30 -0
- iceaxe/schemas/db_memory_serializer.py +705 -0
- iceaxe/schemas/db_serializer.py +346 -0
- iceaxe/schemas/db_stubs.py +525 -0
- iceaxe/session.py +860 -0
- iceaxe/session_optimized.c +12035 -0
- iceaxe/session_optimized.cpython-313-darwin.so +0 -0
- iceaxe/session_optimized.pyx +212 -0
- iceaxe/sql_types.py +148 -0
- iceaxe/typing.py +73 -0
- iceaxe-0.7.1.dist-info/METADATA +261 -0
- iceaxe-0.7.1.dist-info/RECORD +75 -0
- iceaxe-0.7.1.dist-info/WHEEL +6 -0
- iceaxe-0.7.1.dist-info/licenses/LICENSE +21 -0
- iceaxe-0.7.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,525 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from abc import abstractmethod
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Generic, Self, TypeVar, Union, cast
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, Field, model_validator
|
|
7
|
+
|
|
8
|
+
from iceaxe.schemas.actions import (
|
|
9
|
+
CheckConstraint,
|
|
10
|
+
ColumnType,
|
|
11
|
+
ConstraintType,
|
|
12
|
+
DatabaseActions,
|
|
13
|
+
ForeignKeyConstraint,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class ConstraintPointerInfo:
    """Information parsed from a constraint pointer representation."""

    # Table the constraint is attached to; may include a schema prefix
    # (e.g. "schema.table"), since pointer parsing allows dots here.
    table_name: str
    # Column names covered by the constraint, in the order they were parsed.
    column_names: list[str]
    # Raw constraint type string, e.g. "PRIMARY KEY" or "UNIQUE".
    constraint_type: str
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Bound TypeVar so DBObject subclasses can typehint migrate()/merge() against
# their own concrete type rather than the base class.
T = TypeVar("T", bound="DBObject")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class DBObject(BaseModel, Generic[T]):
    """
    A subclass for all models that are intended to store an in-memory representation
    of a database object that we can perform diff support against.

    Our Generic[T] here is a bit of a hack to allow us to properly typehint the expected
    API contract of child implementations. `Self` in pyright results in fixing the API
    contract to the base class DBObject whereas we want it to adjust to the child class.

    """

    # Frozen so instances are immutable and safe to compare/deduplicate.
    model_config = {
        "frozen": True,
    }

    @abstractmethod
    def representation(self) -> str:
        """
        The representation should be unique in global namespace, used to de-duplicate
        objects across multiple migration revisions.

        """
        pass

    @abstractmethod
    async def create(self, actor: DatabaseActions):
        """Create this object in the database via the given actor."""
        pass

    @abstractmethod
    async def migrate(self, previous: T, actor: DatabaseActions):
        """Transform `previous` into this object's definition via the actor."""
        pass

    @abstractmethod
    async def destroy(self, actor: DatabaseActions):
        """Remove this object from the database via the given actor."""
        pass

    def merge(self, other: T) -> T:
        """
        If there is another object with the same .representation() as this object
        this function is in charge of merging the two objects. By default
        we will just use an equality check to ensure that the objects are the
        same and return the current object.

        If clients override this function, ensure that the result is the same regardless
        of the order that the merge is called in. Callers make no guarantee about the
        resolution order.

        """
        if self != other:
            raise ValueError(
                f"Conflicting definitions for {self.representation()}\n{self} != {other}"
            )
        return cast(T, self)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class DBObjectPointer(BaseModel):
    """
    A pointer to an object that was already created elsewhere. Used only for DAG comparisons. Make sure
    the representation mirrors the root object string - otherwise comparison
    won't work properly.

    We typically use pointers in cases where we want to reference an object that should
    already be created, and the change in the child value shouldn't auto-update the parent.
    Since by default we use direct model-equality to determine whether we create a migration
    stage, nesting a full DBObject within a parent object would otherwise cause the parent
    to update.

    """

    model_config = {
        "frozen": True,
    }

    @abstractmethod
    def representation(self) -> str:
        pass

    def parse_constraint_pointer(self) -> ConstraintPointerInfo | None:
        """
        Parse a constraint pointer representation into its components.

        Returns:
            ConstraintPointerInfo | None: Parsed constraint information or None if not a constraint pointer

        Examples:
            "table.['column'].PRIMARY KEY" -> ConstraintPointerInfo("table", ["column"], "PRIMARY KEY")
            "table.['col1', 'col2'].UNIQUE" -> ConstraintPointerInfo("table", ["col1", "col2"], "UNIQUE")
        """
        raw = self.representation()

        # A constraint pointer looks like: table_name.[column_list].constraint_type
        # where the column list is ['col'] or ['col1', 'col2', ...]. Because the
        # table portion may itself contain dots (schema.table), we anchor the
        # parse on the bracketed .[...] segment in the middle.
        parsed = re.match(r"^(.+)\.(\[.*?\])\.(.+)$", raw)
        if parsed is None:
            return None

        table, bracketed, kind = parsed.groups()

        # Strip the surrounding brackets and validate the contents.
        inner = bracketed.strip("[]")
        if not inner:
            # An empty column list is still a valid constraint pointer.
            return ConstraintPointerInfo(table, [], kind)

        # Each comma-separated entry must be quoted (single or double) to count
        # as a column name; anything else means this isn't a constraint pointer.
        names: list[str] = []
        for token in inner.split(","):
            token = token.strip()
            single_quoted = token.startswith("'") and token.endswith("'")
            double_quoted = token.startswith('"') and token.endswith('"')
            if not (single_quoted or double_quoted):
                return None
            unquoted = token[1:-1]
            # Skip empty column names produced by bare quote pairs.
            if unquoted:
                names.append(unquoted)

        return ConstraintPointerInfo(table, names, kind)

    def get_table_name(self) -> str | None:
        """
        Extract the table name from the pointer representation.

        Returns:
            str | None: The table name if it can be parsed, None otherwise
        """
        # Constraint pointers carry an explicit table component.
        constraint_info = self.parse_constraint_pointer()
        if constraint_info is not None:
            return constraint_info.table_name

        # Otherwise fall back to the simple table.column format.
        raw = self.representation()
        if not raw:
            return None

        pieces = raw.split(".")
        if len(pieces) == 1:
            # Bare table name, no column component.
            return pieces[0]
        # For schema.table.column, everything but the final piece is the table.
        return ".".join(pieces[:-1])

    def get_column_names(self) -> list[str]:
        """
        Extract column names from the pointer representation.

        Returns:
            list[str]: List of column names if they can be parsed, empty list otherwise
        """
        # Constraint pointers carry their column list explicitly.
        constraint_info = self.parse_constraint_pointer()
        if constraint_info is not None:
            return constraint_info.column_names

        # Otherwise fall back to the simple table.column format.
        raw = self.representation()
        if not raw:
            return []

        pieces = raw.split(".")
        if len(pieces) < 2:
            # Just a table name: there is no column component.
            return []
        return [pieces[-1]]
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
class DBTable(DBObject["DBTable"]):
    """In-memory stub for a single database table, identified by its name."""

    table_name: str

    def representation(self):
        # Tables are unique by name in the global namespace.
        return self.table_name

    async def create(self, actor: DatabaseActions):
        # Call out new tables prominently in the generated migration.
        actor.add_comment(f"\nNEW TABLE: {self.table_name}\n")
        await actor.add_table(self.table_name)

    async def migrate(self, previous: Self, actor: DatabaseActions):
        # A table stub only carries its name; there is nothing to migrate
        # in place.
        raise NotImplementedError

    async def destroy(self, actor: DatabaseActions):
        await actor.drop_table(self.table_name)
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
class DBColumnBase(BaseModel):
    """Shared shape for column objects and column pointers (table + column)."""

    table_name: str
    column_name: str

    def representation(self):
        # Columns are namespaced by the table they belong to.
        return self.table_name + "." + self.column_name
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
class DBColumnPointer(DBColumnBase, DBObjectPointer):
    """
    Pointer to a column created elsewhere, identified as "table.column" via
    DBColumnBase.representation().
    """

    pass
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
class DBColumn(DBColumnBase, DBObject["DBColumn"]):
    """
    In-memory stub for a single table column: its type, whether it holds a
    list of values, nullability, and autoincrement behavior.
    """

    # Use a type pointer here to avoid full equality checks
    # of the values; if the pointer is the same, we can avoid
    # updating the column type during a migration.
    column_type: Union["DBTypePointer", ColumnType]
    column_is_list: bool

    nullable: bool

    autoincrement: bool = False

    async def create(self, actor: DatabaseActions):
        """Add the column and (if required) its not-null constraint."""
        # The only time SERIAL types are allowed is during creation for autoincrementing
        # integer columns
        explicit_data_type: ColumnType | None = None
        if isinstance(self.column_type, ColumnType):
            if self.column_type == ColumnType.INTEGER and self.autoincrement:
                explicit_data_type = ColumnType.SERIAL
            else:
                explicit_data_type = self.column_type

        await actor.add_column(
            self.table_name,
            self.column_name,
            explicit_data_type=explicit_data_type,
            explicit_data_is_list=self.column_is_list,
            # Custom (non-builtin) types are referenced by their global name.
            custom_data_type=(
                self.column_type.representation()
                if isinstance(self.column_type, DBTypePointer)
                else None
            ),
        )

        # Nullability is tracked as a separate constraint from the column itself.
        if not self.nullable:
            await actor.add_not_null(self.table_name, self.column_name)

    async def destroy(self, actor: DatabaseActions):
        # Destroying the column means we'll also drop constraints associated
        # with it, like not-null.
        await actor.drop_column(self.table_name, self.column_name)

    async def migrate(self, previous: Self, actor: DatabaseActions):
        """
        Queue the type and nullability changes needed to transform `previous`
        into this column definition.
        """
        if (
            self.column_type != previous.column_type
            or self.column_is_list != previous.column_is_list
        ):
            await actor.modify_column_type(
                self.table_name,
                self.column_name,
                explicit_data_type=(
                    self.column_type
                    if isinstance(self.column_type, ColumnType)
                    else None
                ),
                explicit_data_is_list=self.column_is_list,
                # Use representation() for consistency with create(); for a
                # DBTypePointer this resolves to the same global type name.
                custom_data_type=(
                    self.column_type.representation()
                    if isinstance(self.column_type, DBTypePointer)
                    else None
                ),
                autocast=True,
            )
            actor.add_comment(
                "TODO: Perform a migration of values across types", previous_line=True
            )

        if not self.nullable and previous.nullable:
            await actor.add_not_null(self.table_name, self.column_name)
        if self.nullable and not previous.nullable:
            await actor.drop_not_null(self.table_name, self.column_name)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
class DBConstraint(DBObject["DBConstraint"]):
    """
    In-memory stub for a table constraint (primary key, foreign key, unique,
    check) or index.
    """

    table_name: str
    # Excluded from serialization: the generated name depends on column
    # ordering at construction time, so two logically-equal constraints may
    # carry different names.
    constraint_name: str = Field(exclude=True)
    columns: frozenset[str]

    constraint_type: ConstraintType

    # Required when constraint_type is FOREIGN_KEY, disallowed otherwise;
    # check_constraint is consumed only for CHECK constraints.
    foreign_key_constraint: ForeignKeyConstraint | None = None
    check_constraint: CheckConstraint | None = None

    @model_validator(mode="after")
    def validate_constraint_type(self):
        """Ensure foreign key metadata is present exactly when required."""
        if (
            self.constraint_type == ConstraintType.FOREIGN_KEY
            and self.foreign_key_constraint is None
        ):
            raise ValueError("Foreign key constraints require a ForeignKeyConstraint")
        if (
            self.constraint_type != ConstraintType.FOREIGN_KEY
            and self.foreign_key_constraint is not None
        ):
            raise ValueError(
                "Only foreign key constraints require a ForeignKeyConstraint"
            )
        return self

    def representation(self) -> str:
        # Different construction methods sort the constraint parameters in different ways
        # We rely on sorting these parameters to ensure that the representation matches
        # across these different construction methods
        return f"{self.table_name}.{sorted(self.columns)}.{self.constraint_type}"

    @classmethod
    def new_constraint_name(
        cls,
        table_name: str,
        columns: list[str],
        constraint_type: ConstraintType,
    ):
        """
        Build a deterministic constraint name from the table, its (sorted)
        columns, and a suffix keyed to the constraint type.
        """
        elements = [table_name]
        if constraint_type == ConstraintType.PRIMARY_KEY:
            # Primary keys omit the column list; a table has at most one.
            elements.append("pkey")
        elif constraint_type == ConstraintType.FOREIGN_KEY:
            elements += sorted(columns)
            elements.append("fkey")
        elif constraint_type == ConstraintType.UNIQUE:
            elements += sorted(columns)
            elements.append("unique")
        elif constraint_type == ConstraintType.INDEX:
            elements += sorted(columns)
            elements.append("idx")
        else:
            elements += sorted(columns)
            elements.append("key")
        return "_".join(elements)

    async def create(self, actor: DatabaseActions):
        """Add this constraint (or index) through the actor."""
        if self.constraint_type == ConstraintType.FOREIGN_KEY:
            assert self.foreign_key_constraint is not None
            await actor.add_constraint(
                self.table_name,
                constraint=self.constraint_type,
                constraint_name=self.constraint_name,
                constraint_args=self.foreign_key_constraint,
                columns=list(self.columns),
            )
        elif self.constraint_type == ConstraintType.CHECK:
            assert self.check_constraint is not None
            await actor.add_constraint(
                self.table_name,
                constraint=self.constraint_type,
                constraint_name=self.constraint_name,
                constraint_args=self.check_constraint,
                columns=list(self.columns),
            )
        elif self.constraint_type == ConstraintType.INDEX:
            # Indexes use a dedicated creation path rather than add_constraint.
            await actor.add_index(
                self.table_name,
                columns=list(self.columns),
                index_name=self.constraint_name,
            )
        else:
            await actor.add_constraint(
                self.table_name,
                constraint=self.constraint_type,
                constraint_name=self.constraint_name,
                columns=list(self.columns),
            )

    async def destroy(self, actor: DatabaseActions):
        if self.constraint_type == ConstraintType.INDEX:
            await actor.drop_index(
                self.table_name,
                index_name=self.constraint_name,
            )
        else:
            await actor.drop_constraint(
                self.table_name,
                constraint_name=self.constraint_name,
            )

    async def migrate(self, previous: Self, actor: DatabaseActions):
        """Recreate the constraint if its serialized definition changed."""
        if self.constraint_type != previous.constraint_type:
            raise NotImplementedError

        # Since we allow some flexibility in column ordering, and that affects
        # the actual constraint name, it's possible that this function is being called
        # with a previous example that is actually the same - but fails the equality check.
        # We re-do a proper comparison on the serialized payloads (which exclude
        # the order-sensitive constraint_name) to ensure that we don't do
        # unnecessary work.
        if self.model_dump() != previous.model_dump():
            # Constraints can't be altered in place: drop and recreate.
            # NOTE(review): this drops self.constraint_name rather than
            # previous.constraint_name - confirm the generator guarantees these
            # match when migrating an existing constraint.
            await self.destroy(actor)
            await self.create(actor)
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
class DBTypeBase(BaseModel):
    """Shared shape for custom type objects and type pointers."""

    # Name of the custom type; unique in the global namespace.
    name: str

    def representation(self):
        # Type definitions are global by nature
        return self.name
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
class DBTypePointer(DBTypeBase, DBObjectPointer):
    """
    Pointer to a custom type created elsewhere, identified globally by its
    name via DBTypeBase.representation().
    """

    pass
|
|
446
|
+
|
|
447
|
+
|
|
448
|
+
class DBType(DBTypeBase, DBObject["DBType"]):
    """
    In-memory stub for a custom database type (an enum-like value set) and the
    columns that reference it.
    """

    values: frozenset[str]

    # Captures the columns that use this type value, (table_name, column_name)
    # so we can migrate them properly to new types. Type dropping in Postgres
    # isn't supported.
    reference_columns: frozenset[tuple[str, str]]

    async def create(self, actor: DatabaseActions):
        # Sort values for deterministic migration output.
        await actor.add_type(self.name, sorted(self.values))

    async def destroy(self, actor: DatabaseActions):
        await actor.drop_type(self.name)

    async def migrate(self, previous: Self, actor: DatabaseActions):
        """Add/remove value members so the type matches this definition."""
        previous_values = set(previous.values)
        next_values = set(self.values)

        # We need to update the enum with the new values
        new_values = next_values - previous_values
        deleted_values = previous_values - next_values

        if new_values:
            await actor.add_type_values(
                self.name,
                sorted(new_values),
            )

        if deleted_values:
            # Dropping values requires rewriting the columns that use the type.
            await actor.drop_type_values(
                self.name,
                sorted(deleted_values),
                list(self.reference_columns),
            )

    def merge(self, other: "DBType") -> "DBType":
        # We should only be merged with other types that are basically the same
        # but might have different reference columns since they might be produced by
        # different parts of the pipeline.
        if self.name != other.name or self.values != other.values:
            # Bug fix: this message was previously a plain string, so the
            # {self.name}/{self.values} placeholders never interpolated.
            raise ValueError(
                f"Cannot merge types with different core values: {self.name}({self.values}) != {other.name}({other.values})"
            )

        return DBType(
            name=self.name,
            values=self.values,
            reference_columns=self.reference_columns | other.reference_columns,
        )
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
class DBConstraintPointer(DBObjectPointer):
    """
    A pointer to a constraint that will be created. Used for dependency tracking
    without needing to know the full constraint definition.
    """

    table_name: str
    columns: frozenset[str]
    constraint_type: ConstraintType

    def representation(self) -> str:
        # Must mirror DBConstraint.representation so DAG comparisons line up.
        ordered_columns = sorted(self.columns)
        return f"{self.table_name}.{ordered_columns}.{self.constraint_type}"
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
class DBPointerOr(DBObjectPointer):
    """
    A pointer that represents an OR relationship between multiple pointers.
    When resolving dependencies, any of the provided pointers being present
    will satisfy the dependency.
    """

    pointers: tuple[DBObjectPointer, ...]

    def representation(self) -> str:
        # Sorting keeps the representation stable regardless of pointer order.
        parts = sorted(child.representation() for child in self.pointers)
        return "OR(" + ",".join(parts) + ")"
|