sqlspec 0.17.1__py3-none-any.whl → 0.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +1 -1
- sqlspec/_sql.py +54 -159
- sqlspec/adapters/adbc/config.py +24 -30
- sqlspec/adapters/adbc/driver.py +42 -61
- sqlspec/adapters/aiosqlite/config.py +5 -10
- sqlspec/adapters/aiosqlite/driver.py +9 -25
- sqlspec/adapters/aiosqlite/pool.py +43 -35
- sqlspec/adapters/asyncmy/config.py +10 -7
- sqlspec/adapters/asyncmy/driver.py +18 -39
- sqlspec/adapters/asyncpg/config.py +4 -0
- sqlspec/adapters/asyncpg/driver.py +32 -79
- sqlspec/adapters/bigquery/config.py +12 -65
- sqlspec/adapters/bigquery/driver.py +39 -133
- sqlspec/adapters/duckdb/config.py +11 -15
- sqlspec/adapters/duckdb/driver.py +61 -85
- sqlspec/adapters/duckdb/pool.py +2 -5
- sqlspec/adapters/oracledb/_types.py +8 -1
- sqlspec/adapters/oracledb/config.py +55 -38
- sqlspec/adapters/oracledb/driver.py +35 -92
- sqlspec/adapters/oracledb/migrations.py +257 -0
- sqlspec/adapters/psqlpy/config.py +13 -9
- sqlspec/adapters/psqlpy/driver.py +28 -103
- sqlspec/adapters/psycopg/config.py +9 -5
- sqlspec/adapters/psycopg/driver.py +107 -175
- sqlspec/adapters/sqlite/config.py +7 -5
- sqlspec/adapters/sqlite/driver.py +37 -73
- sqlspec/adapters/sqlite/pool.py +3 -12
- sqlspec/base.py +1 -8
- sqlspec/builder/__init__.py +1 -1
- sqlspec/builder/_base.py +34 -20
- sqlspec/builder/_ddl.py +407 -183
- sqlspec/builder/_insert.py +1 -1
- sqlspec/builder/mixins/_insert_operations.py +26 -6
- sqlspec/builder/mixins/_merge_operations.py +1 -1
- sqlspec/builder/mixins/_select_operations.py +1 -5
- sqlspec/config.py +32 -13
- sqlspec/core/__init__.py +89 -14
- sqlspec/core/cache.py +57 -104
- sqlspec/core/compiler.py +57 -112
- sqlspec/core/filters.py +1 -21
- sqlspec/core/hashing.py +13 -47
- sqlspec/core/parameters.py +272 -261
- sqlspec/core/result.py +12 -27
- sqlspec/core/splitter.py +17 -21
- sqlspec/core/statement.py +150 -159
- sqlspec/driver/_async.py +2 -15
- sqlspec/driver/_common.py +16 -95
- sqlspec/driver/_sync.py +2 -15
- sqlspec/driver/mixins/_result_tools.py +8 -29
- sqlspec/driver/mixins/_sql_translator.py +6 -8
- sqlspec/exceptions.py +1 -2
- sqlspec/loader.py +43 -115
- sqlspec/migrations/__init__.py +1 -1
- sqlspec/migrations/base.py +34 -45
- sqlspec/migrations/commands.py +34 -15
- sqlspec/migrations/loaders.py +1 -1
- sqlspec/migrations/runner.py +104 -19
- sqlspec/migrations/tracker.py +49 -2
- sqlspec/protocols.py +3 -6
- sqlspec/storage/__init__.py +4 -4
- sqlspec/storage/backends/fsspec.py +5 -6
- sqlspec/storage/backends/obstore.py +7 -8
- sqlspec/storage/registry.py +3 -3
- sqlspec/utils/__init__.py +2 -2
- sqlspec/utils/logging.py +6 -10
- sqlspec/utils/sync_tools.py +27 -4
- sqlspec/utils/text.py +6 -1
- {sqlspec-0.17.1.dist-info → sqlspec-0.18.0.dist-info}/METADATA +1 -1
- sqlspec-0.18.0.dist-info/RECORD +138 -0
- sqlspec/builder/_ddl_utils.py +0 -103
- sqlspec-0.17.1.dist-info/RECORD +0 -138
- {sqlspec-0.17.1.dist-info → sqlspec-0.18.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.17.1.dist-info → sqlspec-0.18.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.17.1.dist-info → sqlspec-0.18.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.17.1.dist-info → sqlspec-0.18.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/migrations/__init__.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""SQLSpec Migration Tool.
|
|
2
2
|
|
|
3
3
|
A native migration system for SQLSpec that leverages the SQLFileLoader
|
|
4
|
-
and driver
|
|
4
|
+
and driver system for database versioning.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
from sqlspec.migrations.commands import AsyncMigrationCommands, MigrationCommands, SyncMigrationCommands
|
sqlspec/migrations/base.py
CHANGED
|
@@ -3,18 +3,19 @@
|
|
|
3
3
|
This module provides abstract base classes for migration components.
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
|
+
import hashlib
|
|
6
7
|
import operator
|
|
7
8
|
from abc import ABC, abstractmethod
|
|
8
9
|
from pathlib import Path
|
|
9
|
-
from typing import Any, Generic, Optional, TypeVar
|
|
10
|
+
from typing import Any, Generic, Optional, TypeVar, cast
|
|
10
11
|
|
|
11
12
|
from sqlspec._sql import sql
|
|
13
|
+
from sqlspec.builder import Delete, Insert, Select
|
|
12
14
|
from sqlspec.builder._ddl import CreateTable
|
|
13
|
-
from sqlspec.core.statement import SQL
|
|
14
15
|
from sqlspec.loader import SQLFileLoader
|
|
15
16
|
from sqlspec.migrations.loaders import get_migration_loader
|
|
16
17
|
from sqlspec.utils.logging import get_logger
|
|
17
|
-
from sqlspec.utils.sync_tools import
|
|
18
|
+
from sqlspec.utils.sync_tools import await_
|
|
18
19
|
|
|
19
20
|
__all__ = ("BaseMigrationCommands", "BaseMigrationRunner", "BaseMigrationTracker")
|
|
20
21
|
|
|
@@ -28,6 +29,8 @@ ConfigT = TypeVar("ConfigT")
|
|
|
28
29
|
class BaseMigrationTracker(ABC, Generic[DriverT]):
|
|
29
30
|
"""Base class for migration version tracking."""
|
|
30
31
|
|
|
32
|
+
__slots__ = ("version_table",)
|
|
33
|
+
|
|
31
34
|
def __init__(self, version_table_name: str = "ddl_migrations") -> None:
|
|
32
35
|
"""Initialize the migration tracker.
|
|
33
36
|
|
|
@@ -36,54 +39,43 @@ class BaseMigrationTracker(ABC, Generic[DriverT]):
|
|
|
36
39
|
"""
|
|
37
40
|
self.version_table = version_table_name
|
|
38
41
|
|
|
39
|
-
def _get_create_table_sql(self) ->
|
|
40
|
-
"""Get SQL for creating the tracking table.
|
|
42
|
+
def _get_create_table_sql(self) -> CreateTable:
|
|
43
|
+
"""Get SQL builder for creating the tracking table.
|
|
41
44
|
|
|
42
45
|
Returns:
|
|
43
|
-
SQL object for table creation.
|
|
46
|
+
SQL builder object for table creation.
|
|
44
47
|
"""
|
|
45
|
-
builder = CreateTable(self.version_table)
|
|
46
|
-
if not hasattr(builder, "_columns"):
|
|
47
|
-
builder._columns = []
|
|
48
|
-
if not hasattr(builder, "_constraints"):
|
|
49
|
-
builder._constraints = []
|
|
50
|
-
if not hasattr(builder, "_table_options"):
|
|
51
|
-
builder._table_options = {}
|
|
52
|
-
|
|
53
48
|
return (
|
|
54
|
-
|
|
49
|
+
sql.create_table(self.version_table)
|
|
50
|
+
.if_not_exists()
|
|
55
51
|
.column("version_num", "VARCHAR(32)", primary_key=True)
|
|
56
52
|
.column("description", "TEXT")
|
|
57
|
-
.column("applied_at", "TIMESTAMP",
|
|
53
|
+
.column("applied_at", "TIMESTAMP", default="CURRENT_TIMESTAMP", not_null=True)
|
|
58
54
|
.column("execution_time_ms", "INTEGER")
|
|
59
55
|
.column("checksum", "VARCHAR(64)")
|
|
60
56
|
.column("applied_by", "VARCHAR(255)")
|
|
61
|
-
)
|
|
57
|
+
)
|
|
62
58
|
|
|
63
|
-
def _get_current_version_sql(self) ->
|
|
64
|
-
"""Get SQL for retrieving current version.
|
|
59
|
+
def _get_current_version_sql(self) -> Select:
|
|
60
|
+
"""Get SQL builder for retrieving current version.
|
|
65
61
|
|
|
66
62
|
Returns:
|
|
67
|
-
SQL object for version query.
|
|
63
|
+
SQL builder object for version query.
|
|
68
64
|
"""
|
|
65
|
+
return sql.select("version_num").from_(self.version_table).order_by("version_num DESC").limit(1)
|
|
69
66
|
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
).to_statement()
|
|
73
|
-
|
|
74
|
-
def _get_applied_migrations_sql(self) -> SQL:
|
|
75
|
-
"""Get SQL for retrieving all applied migrations.
|
|
67
|
+
def _get_applied_migrations_sql(self) -> Select:
|
|
68
|
+
"""Get SQL builder for retrieving all applied migrations.
|
|
76
69
|
|
|
77
70
|
Returns:
|
|
78
|
-
SQL object for migrations query.
|
|
71
|
+
SQL builder object for migrations query.
|
|
79
72
|
"""
|
|
80
|
-
|
|
81
|
-
return (sql.select("*").from_(self.version_table).order_by("version_num")).to_statement()
|
|
73
|
+
return sql.select("*").from_(self.version_table).order_by("version_num")
|
|
82
74
|
|
|
83
75
|
def _get_record_migration_sql(
|
|
84
76
|
self, version: str, description: str, execution_time_ms: int, checksum: str, applied_by: str
|
|
85
|
-
) ->
|
|
86
|
-
"""Get SQL for recording a migration.
|
|
77
|
+
) -> Insert:
|
|
78
|
+
"""Get SQL builder for recording a migration.
|
|
87
79
|
|
|
88
80
|
Args:
|
|
89
81
|
version: Version number of the migration.
|
|
@@ -93,26 +85,24 @@ class BaseMigrationTracker(ABC, Generic[DriverT]):
|
|
|
93
85
|
applied_by: User who applied the migration.
|
|
94
86
|
|
|
95
87
|
Returns:
|
|
96
|
-
SQL object for insert.
|
|
88
|
+
SQL builder object for insert.
|
|
97
89
|
"""
|
|
98
|
-
|
|
99
90
|
return (
|
|
100
91
|
sql.insert(self.version_table)
|
|
101
92
|
.columns("version_num", "description", "execution_time_ms", "checksum", "applied_by")
|
|
102
93
|
.values(version, description, execution_time_ms, checksum, applied_by)
|
|
103
|
-
)
|
|
94
|
+
)
|
|
104
95
|
|
|
105
|
-
def _get_remove_migration_sql(self, version: str) ->
|
|
106
|
-
"""Get SQL for removing a migration record.
|
|
96
|
+
def _get_remove_migration_sql(self, version: str) -> Delete:
|
|
97
|
+
"""Get SQL builder for removing a migration record.
|
|
107
98
|
|
|
108
99
|
Args:
|
|
109
100
|
version: Version number to remove.
|
|
110
101
|
|
|
111
102
|
Returns:
|
|
112
|
-
SQL object for delete.
|
|
103
|
+
SQL builder object for delete.
|
|
113
104
|
"""
|
|
114
|
-
|
|
115
|
-
return (sql.delete().from_(self.version_table).where(sql.version_num == version)).to_statement()
|
|
105
|
+
return sql.delete().from_(self.version_table).where(sql.version_num == version)
|
|
116
106
|
|
|
117
107
|
@abstractmethod
|
|
118
108
|
def ensure_tracking_table(self, driver: DriverT) -> Any:
|
|
@@ -176,7 +166,6 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
|
|
|
176
166
|
Returns:
|
|
177
167
|
MD5 checksum hex string.
|
|
178
168
|
"""
|
|
179
|
-
import hashlib
|
|
180
169
|
|
|
181
170
|
return hashlib.md5(content.encode()).hexdigest() # noqa: S324
|
|
182
171
|
|
|
@@ -226,7 +215,7 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
|
|
|
226
215
|
has_upgrade, has_downgrade = self.loader.has_query(up_query), self.loader.has_query(down_query)
|
|
227
216
|
else:
|
|
228
217
|
try:
|
|
229
|
-
has_downgrade = bool(
|
|
218
|
+
has_downgrade = bool(await_(loader.get_down_sql, raise_sync_error=False)(file_path))
|
|
230
219
|
except Exception:
|
|
231
220
|
has_downgrade = False
|
|
232
221
|
|
|
@@ -240,7 +229,7 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
|
|
|
240
229
|
"loader": loader,
|
|
241
230
|
}
|
|
242
231
|
|
|
243
|
-
def _get_migration_sql(self, migration: "dict[str, Any]", direction: str) -> Optional[
|
|
232
|
+
def _get_migration_sql(self, migration: "dict[str, Any]", direction: str) -> "Optional[list[str]]":
|
|
244
233
|
"""Get migration SQL for given direction.
|
|
245
234
|
|
|
246
235
|
Args:
|
|
@@ -261,7 +250,7 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
|
|
|
261
250
|
|
|
262
251
|
try:
|
|
263
252
|
method = loader.get_up_sql if direction == "up" else loader.get_down_sql
|
|
264
|
-
sql_statements =
|
|
253
|
+
sql_statements = await_(method, raise_sync_error=False)(file_path)
|
|
265
254
|
|
|
266
255
|
except Exception as e:
|
|
267
256
|
if direction == "down":
|
|
@@ -271,7 +260,7 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
|
|
|
271
260
|
raise ValueError(msg) from e
|
|
272
261
|
else:
|
|
273
262
|
if sql_statements:
|
|
274
|
-
return
|
|
263
|
+
return cast("list[str]", sql_statements)
|
|
275
264
|
return None
|
|
276
265
|
|
|
277
266
|
@abstractmethod
|
|
@@ -312,7 +301,7 @@ class BaseMigrationCommands(ABC, Generic[ConfigT, DriverT]):
|
|
|
312
301
|
self.config = config
|
|
313
302
|
migration_config = getattr(self.config, "migration_config", {}) or {}
|
|
314
303
|
|
|
315
|
-
self.version_table = migration_config.get("version_table_name", "
|
|
304
|
+
self.version_table = migration_config.get("version_table_name", "ddl_migrations")
|
|
316
305
|
self.migrations_path = Path(migration_config.get("script_location", "migrations"))
|
|
317
306
|
self.project_root = Path(migration_config["project_root"]) if "project_root" in migration_config else None
|
|
318
307
|
|
sqlspec/migrations/commands.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
This module provides the main command interface for database migrations.
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
|
-
from typing import TYPE_CHECKING, Any, Union, cast
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
|
7
7
|
|
|
8
8
|
from rich.console import Console
|
|
9
9
|
from rich.table import Table
|
|
@@ -11,7 +11,6 @@ from rich.table import Table
|
|
|
11
11
|
from sqlspec._sql import sql
|
|
12
12
|
from sqlspec.migrations.base import BaseMigrationCommands
|
|
13
13
|
from sqlspec.migrations.runner import AsyncMigrationRunner, SyncMigrationRunner
|
|
14
|
-
from sqlspec.migrations.tracker import AsyncMigrationTracker, SyncMigrationTracker
|
|
15
14
|
from sqlspec.migrations.utils import create_migration_file
|
|
16
15
|
from sqlspec.utils.logging import get_logger
|
|
17
16
|
from sqlspec.utils.sync_tools import await_
|
|
@@ -26,7 +25,7 @@ console = Console()
|
|
|
26
25
|
|
|
27
26
|
|
|
28
27
|
class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
|
|
29
|
-
"""
|
|
28
|
+
"""Synchronous migration commands."""
|
|
30
29
|
|
|
31
30
|
def __init__(self, config: "SyncConfigT") -> None:
|
|
32
31
|
"""Initialize migration commands.
|
|
@@ -35,7 +34,7 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
|
|
|
35
34
|
config: The SQLSpec configuration.
|
|
36
35
|
"""
|
|
37
36
|
super().__init__(config)
|
|
38
|
-
self.tracker =
|
|
37
|
+
self.tracker = config.migration_tracker_type(self.version_table)
|
|
39
38
|
self.runner = SyncMigrationRunner(self.migrations_path)
|
|
40
39
|
|
|
41
40
|
def init(self, directory: str, package: bool = True) -> None:
|
|
@@ -47,11 +46,14 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
|
|
|
47
46
|
"""
|
|
48
47
|
self.init_directory(directory, package)
|
|
49
48
|
|
|
50
|
-
def current(self, verbose: bool = False) ->
|
|
49
|
+
def current(self, verbose: bool = False) -> "Optional[str]":
|
|
51
50
|
"""Show current migration version.
|
|
52
51
|
|
|
53
52
|
Args:
|
|
54
53
|
verbose: Whether to show detailed migration history.
|
|
54
|
+
|
|
55
|
+
Returns:
|
|
56
|
+
The current migration version or None if no migrations applied.
|
|
55
57
|
"""
|
|
56
58
|
with self.config.provide_session() as driver:
|
|
57
59
|
self.tracker.ensure_tracking_table(driver)
|
|
@@ -59,7 +61,7 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
|
|
|
59
61
|
current = self.tracker.get_current_version(driver)
|
|
60
62
|
if not current:
|
|
61
63
|
console.print("[yellow]No migrations applied yet[/]")
|
|
62
|
-
return
|
|
64
|
+
return None
|
|
63
65
|
|
|
64
66
|
console.print(f"[green]Current version:[/] {current}")
|
|
65
67
|
|
|
@@ -84,6 +86,8 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
|
|
|
84
86
|
|
|
85
87
|
console.print(table)
|
|
86
88
|
|
|
89
|
+
return cast("Optional[str]", current)
|
|
90
|
+
|
|
87
91
|
def upgrade(self, revision: str = "head") -> None:
|
|
88
92
|
"""Upgrade to a target revision.
|
|
89
93
|
|
|
@@ -137,6 +141,8 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
|
|
|
137
141
|
to_revert = []
|
|
138
142
|
if revision == "-1":
|
|
139
143
|
to_revert = [applied[-1]]
|
|
144
|
+
elif revision == "base":
|
|
145
|
+
to_revert = list(reversed(applied))
|
|
140
146
|
else:
|
|
141
147
|
for migration in reversed(applied):
|
|
142
148
|
if migration["version_num"] > revision:
|
|
@@ -195,7 +201,7 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
|
|
|
195
201
|
|
|
196
202
|
|
|
197
203
|
class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
|
|
198
|
-
"""
|
|
204
|
+
"""Asynchronous migration commands."""
|
|
199
205
|
|
|
200
206
|
def __init__(self, sqlspec_config: "AsyncConfigT") -> None:
|
|
201
207
|
"""Initialize migration commands.
|
|
@@ -204,7 +210,7 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
|
|
|
204
210
|
sqlspec_config: The SQLSpec configuration.
|
|
205
211
|
"""
|
|
206
212
|
super().__init__(sqlspec_config)
|
|
207
|
-
self.tracker =
|
|
213
|
+
self.tracker = sqlspec_config.migration_tracker_type(self.version_table)
|
|
208
214
|
self.runner = AsyncMigrationRunner(self.migrations_path)
|
|
209
215
|
|
|
210
216
|
async def init(self, directory: str, package: bool = True) -> None:
|
|
@@ -216,11 +222,14 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
|
|
|
216
222
|
"""
|
|
217
223
|
self.init_directory(directory, package)
|
|
218
224
|
|
|
219
|
-
async def current(self, verbose: bool = False) ->
|
|
225
|
+
async def current(self, verbose: bool = False) -> "Optional[str]":
|
|
220
226
|
"""Show current migration version.
|
|
221
227
|
|
|
222
228
|
Args:
|
|
223
229
|
verbose: Whether to show detailed migration history.
|
|
230
|
+
|
|
231
|
+
Returns:
|
|
232
|
+
The current migration version or None if no migrations applied.
|
|
224
233
|
"""
|
|
225
234
|
async with self.config.provide_session() as driver:
|
|
226
235
|
await self.tracker.ensure_tracking_table(driver)
|
|
@@ -228,7 +237,7 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
|
|
|
228
237
|
current = await self.tracker.get_current_version(driver)
|
|
229
238
|
if not current:
|
|
230
239
|
console.print("[yellow]No migrations applied yet[/]")
|
|
231
|
-
return
|
|
240
|
+
return None
|
|
232
241
|
|
|
233
242
|
console.print(f"[green]Current version:[/] {current}")
|
|
234
243
|
if verbose:
|
|
@@ -249,6 +258,8 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
|
|
|
249
258
|
)
|
|
250
259
|
console.print(table)
|
|
251
260
|
|
|
261
|
+
return cast("Optional[str]", current)
|
|
262
|
+
|
|
252
263
|
async def upgrade(self, revision: str = "head") -> None:
|
|
253
264
|
"""Upgrade to a target revision.
|
|
254
265
|
|
|
@@ -297,6 +308,8 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
|
|
|
297
308
|
to_revert = []
|
|
298
309
|
if revision == "-1":
|
|
299
310
|
to_revert = [applied[-1]]
|
|
311
|
+
elif revision == "base":
|
|
312
|
+
to_revert = list(reversed(applied))
|
|
300
313
|
else:
|
|
301
314
|
for migration in reversed(applied):
|
|
302
315
|
if migration["version_num"] > revision:
|
|
@@ -382,20 +395,26 @@ class MigrationCommands:
|
|
|
382
395
|
package: Whether to create __init__.py file.
|
|
383
396
|
"""
|
|
384
397
|
if self._is_async:
|
|
385
|
-
await_(cast("AsyncMigrationCommands[Any]", self._impl).init
|
|
398
|
+
await_(cast("AsyncMigrationCommands[Any]", self._impl).init, raise_sync_error=False)(
|
|
399
|
+
directory, package=package
|
|
400
|
+
)
|
|
386
401
|
else:
|
|
387
402
|
cast("SyncMigrationCommands[Any]", self._impl).init(directory, package=package)
|
|
388
403
|
|
|
389
|
-
def current(self, verbose: bool = False) ->
|
|
404
|
+
def current(self, verbose: bool = False) -> "Optional[str]":
|
|
390
405
|
"""Show current migration version.
|
|
391
406
|
|
|
392
407
|
Args:
|
|
393
408
|
verbose: Whether to show detailed migration history.
|
|
409
|
+
|
|
410
|
+
Returns:
|
|
411
|
+
The current migration version or None if no migrations applied.
|
|
394
412
|
"""
|
|
395
413
|
if self._is_async:
|
|
396
|
-
await_(cast("AsyncMigrationCommands[Any]", self._impl).current, raise_sync_error=False)(
|
|
397
|
-
|
|
398
|
-
|
|
414
|
+
return await_(cast("AsyncMigrationCommands[Any]", self._impl).current, raise_sync_error=False)(
|
|
415
|
+
verbose=verbose
|
|
416
|
+
)
|
|
417
|
+
return cast("SyncMigrationCommands[Any]", self._impl).current(verbose=verbose)
|
|
399
418
|
|
|
400
419
|
def upgrade(self, revision: str = "head") -> None:
|
|
401
420
|
"""Upgrade to a target revision.
|
sqlspec/migrations/loaders.py
CHANGED
sqlspec/migrations/runner.py
CHANGED
|
@@ -5,13 +5,13 @@ This module handles migration file loading and execution using SQLFileLoader.
|
|
|
5
5
|
|
|
6
6
|
import time
|
|
7
7
|
from pathlib import Path
|
|
8
|
-
from typing import TYPE_CHECKING, Any, Optional
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Optional, cast
|
|
9
9
|
|
|
10
10
|
from sqlspec.core.statement import SQL
|
|
11
11
|
from sqlspec.migrations.base import BaseMigrationRunner
|
|
12
12
|
from sqlspec.migrations.loaders import get_migration_loader
|
|
13
13
|
from sqlspec.utils.logging import get_logger
|
|
14
|
-
from sqlspec.utils.sync_tools import
|
|
14
|
+
from sqlspec.utils.sync_tools import await_
|
|
15
15
|
|
|
16
16
|
if TYPE_CHECKING:
|
|
17
17
|
from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
|
|
@@ -22,7 +22,7 @@ logger = get_logger("migrations.runner")
|
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
|
|
25
|
-
"""
|
|
25
|
+
"""Synchronous migration executor."""
|
|
26
26
|
|
|
27
27
|
def get_migration_files(self) -> "list[tuple[str, Path]]":
|
|
28
28
|
"""Get all migration files sorted by version.
|
|
@@ -55,12 +55,15 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
|
|
|
55
55
|
Returns:
|
|
56
56
|
Tuple of (sql_content, execution_time_ms).
|
|
57
57
|
"""
|
|
58
|
-
|
|
59
|
-
if
|
|
58
|
+
upgrade_sql_list = self._get_migration_sql(migration, "up")
|
|
59
|
+
if upgrade_sql_list is None:
|
|
60
60
|
return None, 0
|
|
61
61
|
|
|
62
62
|
start_time = time.time()
|
|
63
|
-
|
|
63
|
+
|
|
64
|
+
for sql_statement in upgrade_sql_list:
|
|
65
|
+
if sql_statement.strip():
|
|
66
|
+
driver.execute_script(sql_statement)
|
|
64
67
|
execution_time = int((time.time() - start_time) * 1000)
|
|
65
68
|
return None, execution_time
|
|
66
69
|
|
|
@@ -76,12 +79,15 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
|
|
|
76
79
|
Returns:
|
|
77
80
|
Tuple of (sql_content, execution_time_ms).
|
|
78
81
|
"""
|
|
79
|
-
|
|
80
|
-
if
|
|
82
|
+
downgrade_sql_list = self._get_migration_sql(migration, "down")
|
|
83
|
+
if downgrade_sql_list is None:
|
|
81
84
|
return None, 0
|
|
82
85
|
|
|
83
86
|
start_time = time.time()
|
|
84
|
-
|
|
87
|
+
|
|
88
|
+
for sql_statement in downgrade_sql_list:
|
|
89
|
+
if sql_statement.strip():
|
|
90
|
+
driver.execute_script(sql_statement)
|
|
85
91
|
execution_time = int((time.time() - start_time) * 1000)
|
|
86
92
|
return None, execution_time
|
|
87
93
|
|
|
@@ -103,8 +109,8 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
|
|
|
103
109
|
loader = get_migration_loader(file_path, self.migrations_path, self.project_root)
|
|
104
110
|
|
|
105
111
|
try:
|
|
106
|
-
up_sql =
|
|
107
|
-
down_sql =
|
|
112
|
+
up_sql = await_(loader.get_up_sql, raise_sync_error=False)(file_path)
|
|
113
|
+
down_sql = await_(loader.get_down_sql, raise_sync_error=False)(file_path)
|
|
108
114
|
|
|
109
115
|
if up_sql:
|
|
110
116
|
all_queries[f"migrate-{version}-up"] = SQL(up_sql[0])
|
|
@@ -118,7 +124,7 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
|
|
|
118
124
|
|
|
119
125
|
|
|
120
126
|
class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
|
|
121
|
-
"""
|
|
127
|
+
"""Asynchronous migration executor."""
|
|
122
128
|
|
|
123
129
|
async def get_migration_files(self) -> "list[tuple[str, Path]]":
|
|
124
130
|
"""Get all migration files sorted by version.
|
|
@@ -137,7 +143,80 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
|
|
|
137
143
|
Returns:
|
|
138
144
|
Dictionary containing migration metadata.
|
|
139
145
|
"""
|
|
140
|
-
return self.
|
|
146
|
+
return await self._load_migration_metadata_async(file_path)
|
|
147
|
+
|
|
148
|
+
async def _load_migration_metadata_async(self, file_path: Path) -> "dict[str, Any]":
|
|
149
|
+
"""Load migration metadata from file (async version).
|
|
150
|
+
|
|
151
|
+
Args:
|
|
152
|
+
file_path: Path to the migration file.
|
|
153
|
+
|
|
154
|
+
Returns:
|
|
155
|
+
Migration metadata dictionary.
|
|
156
|
+
"""
|
|
157
|
+
loader = get_migration_loader(file_path, self.migrations_path, self.project_root)
|
|
158
|
+
loader.validate_migration_file(file_path)
|
|
159
|
+
content = file_path.read_text(encoding="utf-8")
|
|
160
|
+
checksum = self._calculate_checksum(content)
|
|
161
|
+
version = self._extract_version(file_path.name)
|
|
162
|
+
description = file_path.stem.split("_", 1)[1] if "_" in file_path.stem else ""
|
|
163
|
+
|
|
164
|
+
has_upgrade, has_downgrade = True, False
|
|
165
|
+
|
|
166
|
+
if file_path.suffix == ".sql":
|
|
167
|
+
up_query, down_query = f"migrate-{version}-up", f"migrate-{version}-down"
|
|
168
|
+
self.loader.clear_cache()
|
|
169
|
+
self.loader.load_sql(file_path)
|
|
170
|
+
has_upgrade, has_downgrade = self.loader.has_query(up_query), self.loader.has_query(down_query)
|
|
171
|
+
else:
|
|
172
|
+
try:
|
|
173
|
+
has_downgrade = bool(await loader.get_down_sql(file_path))
|
|
174
|
+
except Exception:
|
|
175
|
+
has_downgrade = False
|
|
176
|
+
|
|
177
|
+
return {
|
|
178
|
+
"version": version,
|
|
179
|
+
"description": description,
|
|
180
|
+
"file_path": file_path,
|
|
181
|
+
"checksum": checksum,
|
|
182
|
+
"has_upgrade": has_upgrade,
|
|
183
|
+
"has_downgrade": has_downgrade,
|
|
184
|
+
"loader": loader,
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
async def _get_migration_sql_async(self, migration: "dict[str, Any]", direction: str) -> "Optional[list[str]]":
|
|
188
|
+
"""Get migration SQL for given direction (async version).
|
|
189
|
+
|
|
190
|
+
Args:
|
|
191
|
+
migration: Migration metadata.
|
|
192
|
+
direction: Either 'up' or 'down'.
|
|
193
|
+
|
|
194
|
+
Returns:
|
|
195
|
+
SQL statements for the migration.
|
|
196
|
+
"""
|
|
197
|
+
if not migration.get(f"has_{direction}grade"):
|
|
198
|
+
if direction == "down":
|
|
199
|
+
logger.warning("Migration %s has no downgrade query", migration["version"])
|
|
200
|
+
return None
|
|
201
|
+
msg = f"Migration {migration['version']} has no upgrade query"
|
|
202
|
+
raise ValueError(msg)
|
|
203
|
+
|
|
204
|
+
file_path, loader = migration["file_path"], migration["loader"]
|
|
205
|
+
|
|
206
|
+
try:
|
|
207
|
+
method = loader.get_up_sql if direction == "up" else loader.get_down_sql
|
|
208
|
+
sql_statements = await method(file_path)
|
|
209
|
+
|
|
210
|
+
except Exception as e:
|
|
211
|
+
if direction == "down":
|
|
212
|
+
logger.warning("Failed to load downgrade for migration %s: %s", migration["version"], e)
|
|
213
|
+
return None
|
|
214
|
+
msg = f"Failed to load upgrade for migration {migration['version']}: {e}"
|
|
215
|
+
raise ValueError(msg) from e
|
|
216
|
+
else:
|
|
217
|
+
if sql_statements:
|
|
218
|
+
return cast("list[str]", sql_statements)
|
|
219
|
+
return None
|
|
141
220
|
|
|
142
221
|
async def execute_upgrade(
|
|
143
222
|
self, driver: "AsyncDriverAdapterBase", migration: "dict[str, Any]"
|
|
@@ -151,12 +230,15 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
|
|
|
151
230
|
Returns:
|
|
152
231
|
Tuple of (sql_content, execution_time_ms).
|
|
153
232
|
"""
|
|
154
|
-
|
|
155
|
-
if
|
|
233
|
+
upgrade_sql_list = await self._get_migration_sql_async(migration, "up")
|
|
234
|
+
if upgrade_sql_list is None:
|
|
156
235
|
return None, 0
|
|
157
236
|
|
|
158
237
|
start_time = time.time()
|
|
159
|
-
|
|
238
|
+
|
|
239
|
+
for sql_statement in upgrade_sql_list:
|
|
240
|
+
if sql_statement.strip():
|
|
241
|
+
await driver.execute_script(sql_statement)
|
|
160
242
|
execution_time = int((time.time() - start_time) * 1000)
|
|
161
243
|
return None, execution_time
|
|
162
244
|
|
|
@@ -172,12 +254,15 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
|
|
|
172
254
|
Returns:
|
|
173
255
|
Tuple of (sql_content, execution_time_ms).
|
|
174
256
|
"""
|
|
175
|
-
|
|
176
|
-
if
|
|
257
|
+
downgrade_sql_list = await self._get_migration_sql_async(migration, "down")
|
|
258
|
+
if downgrade_sql_list is None:
|
|
177
259
|
return None, 0
|
|
178
260
|
|
|
179
261
|
start_time = time.time()
|
|
180
|
-
|
|
262
|
+
|
|
263
|
+
for sql_statement in downgrade_sql_list:
|
|
264
|
+
if sql_statement.strip():
|
|
265
|
+
await driver.execute_script(sql_statement)
|
|
181
266
|
execution_time = int((time.time() - start_time) * 1000)
|
|
182
267
|
return None, execution_time
|
|
183
268
|
|