sqlspec 0.13.1__py3-none-any.whl → 0.14.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +39 -1
- sqlspec/__main__.py +12 -0
- sqlspec/adapters/adbc/config.py +16 -40
- sqlspec/adapters/adbc/driver.py +43 -16
- sqlspec/adapters/adbc/transformers.py +108 -0
- sqlspec/adapters/aiosqlite/config.py +2 -20
- sqlspec/adapters/aiosqlite/driver.py +36 -18
- sqlspec/adapters/asyncmy/config.py +2 -33
- sqlspec/adapters/asyncmy/driver.py +23 -16
- sqlspec/adapters/asyncpg/config.py +5 -39
- sqlspec/adapters/asyncpg/driver.py +41 -18
- sqlspec/adapters/bigquery/config.py +2 -43
- sqlspec/adapters/bigquery/driver.py +26 -14
- sqlspec/adapters/duckdb/config.py +2 -49
- sqlspec/adapters/duckdb/driver.py +35 -16
- sqlspec/adapters/oracledb/config.py +4 -83
- sqlspec/adapters/oracledb/driver.py +54 -27
- sqlspec/adapters/psqlpy/config.py +2 -55
- sqlspec/adapters/psqlpy/driver.py +28 -8
- sqlspec/adapters/psycopg/config.py +4 -73
- sqlspec/adapters/psycopg/driver.py +69 -24
- sqlspec/adapters/sqlite/config.py +3 -21
- sqlspec/adapters/sqlite/driver.py +50 -26
- sqlspec/cli.py +248 -0
- sqlspec/config.py +18 -20
- sqlspec/driver/_async.py +28 -10
- sqlspec/driver/_common.py +5 -4
- sqlspec/driver/_sync.py +28 -10
- sqlspec/driver/mixins/__init__.py +6 -0
- sqlspec/driver/mixins/_cache.py +114 -0
- sqlspec/driver/mixins/_pipeline.py +0 -4
- sqlspec/{service/base.py → driver/mixins/_query_tools.py} +86 -421
- sqlspec/driver/mixins/_result_utils.py +0 -2
- sqlspec/driver/mixins/_sql_translator.py +0 -2
- sqlspec/driver/mixins/_storage.py +4 -18
- sqlspec/driver/mixins/_type_coercion.py +0 -2
- sqlspec/driver/parameters.py +4 -4
- sqlspec/extensions/aiosql/adapter.py +4 -4
- sqlspec/extensions/litestar/__init__.py +2 -1
- sqlspec/extensions/litestar/cli.py +48 -0
- sqlspec/extensions/litestar/plugin.py +3 -0
- sqlspec/loader.py +1 -1
- sqlspec/migrations/__init__.py +23 -0
- sqlspec/migrations/base.py +390 -0
- sqlspec/migrations/commands.py +525 -0
- sqlspec/migrations/runner.py +215 -0
- sqlspec/migrations/tracker.py +153 -0
- sqlspec/migrations/utils.py +89 -0
- sqlspec/protocols.py +37 -3
- sqlspec/statement/builder/__init__.py +8 -8
- sqlspec/statement/builder/{column.py → _column.py} +82 -52
- sqlspec/statement/builder/{ddl.py → _ddl.py} +5 -5
- sqlspec/statement/builder/_ddl_utils.py +1 -1
- sqlspec/statement/builder/{delete.py → _delete.py} +1 -1
- sqlspec/statement/builder/{insert.py → _insert.py} +1 -1
- sqlspec/statement/builder/{merge.py → _merge.py} +1 -1
- sqlspec/statement/builder/_parsing_utils.py +5 -3
- sqlspec/statement/builder/{select.py → _select.py} +59 -61
- sqlspec/statement/builder/{update.py → _update.py} +2 -2
- sqlspec/statement/builder/mixins/__init__.py +24 -30
- sqlspec/statement/builder/mixins/{_set_ops.py → _cte_and_set_ops.py} +86 -2
- sqlspec/statement/builder/mixins/{_delete_from.py → _delete_operations.py} +2 -0
- sqlspec/statement/builder/mixins/{_insert_values.py → _insert_operations.py} +70 -1
- sqlspec/statement/builder/mixins/{_merge_clauses.py → _merge_operations.py} +2 -0
- sqlspec/statement/builder/mixins/_order_limit_operations.py +123 -0
- sqlspec/statement/builder/mixins/{_pivot.py → _pivot_operations.py} +71 -2
- sqlspec/statement/builder/mixins/_select_operations.py +612 -0
- sqlspec/statement/builder/mixins/{_update_set.py → _update_operations.py} +73 -2
- sqlspec/statement/builder/mixins/_where_clause.py +536 -0
- sqlspec/statement/cache.py +50 -0
- sqlspec/statement/filters.py +37 -8
- sqlspec/statement/parameters.py +143 -54
- sqlspec/statement/pipelines/__init__.py +1 -1
- sqlspec/statement/pipelines/context.py +4 -10
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +3 -3
- sqlspec/statement/pipelines/validators/_parameter_style.py +22 -22
- sqlspec/statement/pipelines/validators/_performance.py +1 -5
- sqlspec/statement/sql.py +246 -176
- sqlspec/utils/__init__.py +2 -1
- sqlspec/utils/statement_hashing.py +203 -0
- sqlspec/utils/type_guards.py +32 -0
- {sqlspec-0.13.1.dist-info → sqlspec-0.14.1.dist-info}/METADATA +1 -1
- sqlspec-0.14.1.dist-info/RECORD +145 -0
- sqlspec-0.14.1.dist-info/entry_points.txt +2 -0
- sqlspec/service/__init__.py +0 -4
- sqlspec/service/_util.py +0 -147
- sqlspec/service/pagination.py +0 -26
- sqlspec/statement/builder/mixins/_aggregate_functions.py +0 -250
- sqlspec/statement/builder/mixins/_case_builder.py +0 -91
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -90
- sqlspec/statement/builder/mixins/_from.py +0 -63
- sqlspec/statement/builder/mixins/_group_by.py +0 -118
- sqlspec/statement/builder/mixins/_having.py +0 -35
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -47
- sqlspec/statement/builder/mixins/_insert_into.py +0 -36
- sqlspec/statement/builder/mixins/_limit_offset.py +0 -53
- sqlspec/statement/builder/mixins/_order_by.py +0 -46
- sqlspec/statement/builder/mixins/_returning.py +0 -37
- sqlspec/statement/builder/mixins/_select_columns.py +0 -61
- sqlspec/statement/builder/mixins/_unpivot.py +0 -77
- sqlspec/statement/builder/mixins/_update_from.py +0 -55
- sqlspec/statement/builder/mixins/_update_table.py +0 -29
- sqlspec/statement/builder/mixins/_where.py +0 -401
- sqlspec/statement/builder/mixins/_window_functions.py +0 -86
- sqlspec/statement/parameter_manager.py +0 -220
- sqlspec/statement/sql_compiler.py +0 -140
- sqlspec-0.13.1.dist-info/RECORD +0 -150
- /sqlspec/statement/builder/{base.py → _base.py} +0 -0
- /sqlspec/statement/builder/mixins/{_join.py → _join_operations.py} +0 -0
- {sqlspec-0.13.1.dist-info → sqlspec-0.14.1.dist-info}/WHEEL +0 -0
- {sqlspec-0.13.1.dist-info → sqlspec-0.14.1.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.13.1.dist-info → sqlspec-0.14.1.dist-info}/licenses/NOTICE +0 -0
|
@@ -10,8 +10,6 @@ __all__ = ("SQLTranslatorMixin",)
|
|
|
10
10
|
class SQLTranslatorMixin:
|
|
11
11
|
"""Mixin for drivers supporting SQL translation."""
|
|
12
12
|
|
|
13
|
-
__slots__ = ()
|
|
14
|
-
|
|
15
13
|
def convert_to_dialect(self, statement: "Statement", to_dialect: DialectType = None, pretty: bool = True) -> str:
|
|
16
14
|
parsed_expression: exp.Expression
|
|
17
15
|
if statement is not None and isinstance(statement, SQL):
|
|
@@ -44,8 +44,6 @@ WINDOWS_PATH_MIN_LENGTH = 3
|
|
|
44
44
|
class StorageMixinBase(ABC):
|
|
45
45
|
"""Base class with common storage functionality."""
|
|
46
46
|
|
|
47
|
-
__slots__ = ()
|
|
48
|
-
|
|
49
47
|
config: Any
|
|
50
48
|
_connection: Any
|
|
51
49
|
dialect: "DialectType"
|
|
@@ -144,8 +142,6 @@ class StorageMixinBase(ABC):
|
|
|
144
142
|
class SyncStorageMixin(StorageMixinBase):
|
|
145
143
|
"""Unified storage operations for synchronous drivers."""
|
|
146
144
|
|
|
147
|
-
__slots__ = ()
|
|
148
|
-
|
|
149
145
|
def ingest_arrow_table(self, table: "ArrowTable", table_name: str, mode: str = "create", **options: Any) -> int:
|
|
150
146
|
"""Ingest an Arrow table into the database.
|
|
151
147
|
|
|
@@ -211,7 +207,7 @@ class SyncStorageMixin(StorageMixinBase):
|
|
|
211
207
|
# disable parameter validation entirely to allow transformer-added parameters
|
|
212
208
|
if params is None and _config and _config.enable_transformations:
|
|
213
209
|
# Disable validation entirely for transformer-generated parameters
|
|
214
|
-
_config = replace(_config,
|
|
210
|
+
_config = replace(_config, enable_validation=False)
|
|
215
211
|
|
|
216
212
|
# Only pass params if it's not None to avoid adding None as a parameter
|
|
217
213
|
if params is not None:
|
|
@@ -294,12 +290,8 @@ class SyncStorageMixin(StorageMixinBase):
|
|
|
294
290
|
_config = self.config
|
|
295
291
|
if _config and not _config.dialect:
|
|
296
292
|
_config = replace(_config, dialect=self.dialect)
|
|
297
|
-
if _config and _config.enable_transformations:
|
|
298
|
-
_config = replace(_config, enable_transformations=False)
|
|
299
293
|
|
|
300
|
-
sql = (
|
|
301
|
-
SQL(statement, parameters=params, config=_config) if params is not None else SQL(statement, config=_config)
|
|
302
|
-
)
|
|
294
|
+
sql = SQL(statement, *params, config=_config) if params else SQL(statement, config=_config)
|
|
303
295
|
for filter_ in filters:
|
|
304
296
|
sql = sql.filter(filter_)
|
|
305
297
|
|
|
@@ -576,8 +568,6 @@ class SyncStorageMixin(StorageMixinBase):
|
|
|
576
568
|
class AsyncStorageMixin(StorageMixinBase):
|
|
577
569
|
"""Unified storage operations for asynchronous drivers."""
|
|
578
570
|
|
|
579
|
-
__slots__ = ()
|
|
580
|
-
|
|
581
571
|
async def ingest_arrow_table(
|
|
582
572
|
self, table: "ArrowTable", table_name: str, mode: str = "create", **options: Any
|
|
583
573
|
) -> int:
|
|
@@ -650,7 +640,7 @@ class AsyncStorageMixin(StorageMixinBase):
|
|
|
650
640
|
# disable parameter validation entirely to allow transformer-added parameters
|
|
651
641
|
if params is None and _config and _config.enable_transformations:
|
|
652
642
|
# Disable validation entirely for transformer-generated parameters
|
|
653
|
-
_config = replace(_config,
|
|
643
|
+
_config = replace(_config, enable_validation=False)
|
|
654
644
|
|
|
655
645
|
# Only pass params if it's not None to avoid adding None as a parameter
|
|
656
646
|
if params is not None:
|
|
@@ -703,12 +693,8 @@ class AsyncStorageMixin(StorageMixinBase):
|
|
|
703
693
|
_config = self.config
|
|
704
694
|
if _config and not _config.dialect:
|
|
705
695
|
_config = replace(_config, dialect=self.dialect)
|
|
706
|
-
if _config and _config.enable_transformations:
|
|
707
|
-
_config = replace(_config, enable_transformations=False)
|
|
708
696
|
|
|
709
|
-
sql = (
|
|
710
|
-
SQL(statement, parameters=params, config=_config) if params is not None else SQL(statement, config=_config)
|
|
711
|
-
)
|
|
697
|
+
sql = SQL(statement, *params, config=_config) if params else SQL(statement, config=_config)
|
|
712
698
|
for filter_ in filters:
|
|
713
699
|
sql = sql.filter(filter_)
|
|
714
700
|
|
|
@@ -22,8 +22,6 @@ class TypeCoercionMixin:
|
|
|
22
22
|
and convert values to database-specific types.
|
|
23
23
|
"""
|
|
24
24
|
|
|
25
|
-
__slots__ = ()
|
|
26
|
-
|
|
27
25
|
def _process_parameters(self, parameters: "SQLParameterType") -> "SQLParameterType":
|
|
28
26
|
"""Process parameters, extracting values from TypedParameter objects.
|
|
29
27
|
|
sqlspec/driver/parameters.py
CHANGED
|
@@ -13,8 +13,8 @@ if TYPE_CHECKING:
|
|
|
13
13
|
from sqlspec.typing import StatementParameters
|
|
14
14
|
|
|
15
15
|
__all__ = (
|
|
16
|
+
"convert_parameter_sequence",
|
|
16
17
|
"convert_parameters_to_positional",
|
|
17
|
-
"normalize_parameter_sequence",
|
|
18
18
|
"process_execute_many_parameters",
|
|
19
19
|
"separate_filters_and_parameters",
|
|
20
20
|
"should_use_transaction",
|
|
@@ -62,19 +62,19 @@ def process_execute_many_parameters(
|
|
|
62
62
|
param_sequence = param_values[0] if param_values else None
|
|
63
63
|
|
|
64
64
|
# Normalize the parameter sequence
|
|
65
|
-
param_sequence =
|
|
65
|
+
param_sequence = convert_parameter_sequence(param_sequence)
|
|
66
66
|
|
|
67
67
|
return filters, param_sequence
|
|
68
68
|
|
|
69
69
|
|
|
70
|
-
def
|
|
70
|
+
def convert_parameter_sequence(params: Any) -> Optional[list[Any]]:
|
|
71
71
|
"""Normalize a parameter sequence to a list format.
|
|
72
72
|
|
|
73
73
|
Args:
|
|
74
74
|
params: Parameter sequence in various formats
|
|
75
75
|
|
|
76
76
|
Returns:
|
|
77
|
-
|
|
77
|
+
converted list of parameters or None
|
|
78
78
|
"""
|
|
79
79
|
if params is None:
|
|
80
80
|
return None
|
|
@@ -38,7 +38,7 @@ def _normalize_dialect(dialect: "Union[str, Any, None]") -> str:
|
|
|
38
38
|
dialect: Original dialect name (can be str, Dialect, type[Dialect], or None)
|
|
39
39
|
|
|
40
40
|
Returns:
|
|
41
|
-
|
|
41
|
+
converted dialect name
|
|
42
42
|
"""
|
|
43
43
|
if dialect is None:
|
|
44
44
|
return "sql"
|
|
@@ -84,9 +84,9 @@ class _AiosqlAdapterBase:
|
|
|
84
84
|
|
|
85
85
|
def _create_sql_object(self, sql: str, parameters: "Any" = None) -> SQL:
|
|
86
86
|
"""Create SQL object with proper configuration."""
|
|
87
|
-
config = SQLConfig(
|
|
88
|
-
|
|
89
|
-
return SQL(sql, parameters, config=config, dialect=
|
|
87
|
+
config = SQLConfig(enable_validation=False)
|
|
88
|
+
converted_dialect = _normalize_dialect(self.driver.dialect)
|
|
89
|
+
return SQL(sql, parameters, config=config, dialect=converted_dialect)
|
|
90
90
|
|
|
91
91
|
|
|
92
92
|
class AiosqlSyncAdapter(_AiosqlAdapterBase):
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from sqlspec.extensions.litestar import handlers, providers
|
|
2
|
+
from sqlspec.extensions.litestar.cli import database_group
|
|
2
3
|
from sqlspec.extensions.litestar.config import DatabaseConfig
|
|
3
4
|
from sqlspec.extensions.litestar.plugin import SQLSpec
|
|
4
5
|
|
|
5
|
-
__all__ = ("DatabaseConfig", "SQLSpec", "handlers", "providers")
|
|
6
|
+
__all__ = ("DatabaseConfig", "SQLSpec", "database_group", "handlers", "providers")
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""Litestar CLI integration for SQLSpec migrations."""
|
|
2
|
+
|
|
3
|
+
from contextlib import suppress
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
from litestar.cli._utils import LitestarGroup
|
|
7
|
+
|
|
8
|
+
from sqlspec.cli import add_migration_commands
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
import rich_click as click
|
|
12
|
+
except ImportError:
|
|
13
|
+
import click # type: ignore[no-redef]
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from litestar import Litestar
|
|
17
|
+
|
|
18
|
+
from sqlspec.extensions.litestar.plugin import SQLSpec
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_database_migration_plugin(app: "Litestar") -> "SQLSpec":
|
|
22
|
+
"""Retrieve the SQLSpec plugin from the Litestar application's plugins.
|
|
23
|
+
|
|
24
|
+
Args:
|
|
25
|
+
app: The Litestar application
|
|
26
|
+
|
|
27
|
+
Returns:
|
|
28
|
+
The SQLSpec plugin
|
|
29
|
+
|
|
30
|
+
Raises:
|
|
31
|
+
ImproperConfigurationError: If the SQLSpec plugin is not found
|
|
32
|
+
"""
|
|
33
|
+
from sqlspec.exceptions import ImproperConfigurationError
|
|
34
|
+
from sqlspec.extensions.litestar.plugin import SQLSpec
|
|
35
|
+
|
|
36
|
+
with suppress(KeyError):
|
|
37
|
+
return app.plugins.get(SQLSpec)
|
|
38
|
+
msg = "Failed to initialize database migrations. The required SQLSpec plugin is missing."
|
|
39
|
+
raise ImproperConfigurationError(msg)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@click.group(cls=LitestarGroup, name="database")
|
|
43
|
+
def database_group(ctx: "click.Context") -> None:
|
|
44
|
+
"""Manage SQLSpec database components."""
|
|
45
|
+
ctx.obj = {"app": ctx.obj, "configs": get_database_migration_plugin(ctx.obj.app).config}
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
add_migration_commands(database_group)
|
|
@@ -47,6 +47,9 @@ class SQLSpec(InitPluginProtocol, SQLSpecBase):
|
|
|
47
47
|
|
|
48
48
|
def on_cli_init(self, cli: "Group") -> None:
|
|
49
49
|
"""Configure the CLI for use with SQLSpec."""
|
|
50
|
+
from sqlspec.extensions.litestar.cli import database_group
|
|
51
|
+
|
|
52
|
+
cli.add_command(database_group)
|
|
50
53
|
|
|
51
54
|
def on_app_init(self, app_config: "AppConfig") -> "AppConfig":
|
|
52
55
|
"""Configure application for use with SQLSpec.
|
sqlspec/loader.py
CHANGED
|
@@ -40,7 +40,7 @@ def _normalize_query_name(name: str) -> str:
|
|
|
40
40
|
name: Raw query name from SQL file
|
|
41
41
|
|
|
42
42
|
Returns:
|
|
43
|
-
|
|
43
|
+
converted query name suitable as Python identifier
|
|
44
44
|
"""
|
|
45
45
|
# Strip trailing non-alphanumeric characters (excluding underscore) and replace hyphens
|
|
46
46
|
return TRIM_TRAILING_SPECIAL_CHARS.sub("", name).replace("-", "_")
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""SQLSpec Migration Tool.
|
|
2
|
+
|
|
3
|
+
A native migration system for SQLSpec that leverages the SQLFileLoader
|
|
4
|
+
and driver architecture for database versioning.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from sqlspec.migrations.commands import AsyncMigrationCommands, MigrationCommands, SyncMigrationCommands
|
|
8
|
+
from sqlspec.migrations.runner import AsyncMigrationRunner, SyncMigrationRunner
|
|
9
|
+
from sqlspec.migrations.tracker import AsyncMigrationTracker, SyncMigrationTracker
|
|
10
|
+
from sqlspec.migrations.utils import create_migration_file, drop_all, get_author
|
|
11
|
+
|
|
12
|
+
__all__ = (
|
|
13
|
+
"AsyncMigrationCommands",
|
|
14
|
+
"AsyncMigrationRunner",
|
|
15
|
+
"AsyncMigrationTracker",
|
|
16
|
+
"MigrationCommands",
|
|
17
|
+
"SyncMigrationCommands",
|
|
18
|
+
"SyncMigrationRunner",
|
|
19
|
+
"SyncMigrationTracker",
|
|
20
|
+
"create_migration_file",
|
|
21
|
+
"drop_all",
|
|
22
|
+
"get_author",
|
|
23
|
+
)
|
|
@@ -0,0 +1,390 @@
|
|
|
1
|
+
"""Base classes for SQLSpec migrations.
|
|
2
|
+
|
|
3
|
+
This module provides abstract base classes for migration components.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Generic, Optional, TypeVar
|
|
9
|
+
|
|
10
|
+
from sqlspec.loader import SQLFileLoader
|
|
11
|
+
from sqlspec.statement.sql import SQL
|
|
12
|
+
from sqlspec.utils.logging import get_logger
|
|
13
|
+
|
|
14
|
+
__all__ = ("BaseMigrationCommands", "BaseMigrationRunner", "BaseMigrationTracker")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
logger = get_logger("migrations.base")
|
|
18
|
+
|
|
19
|
+
# Type variables for generic driver and config types
|
|
20
|
+
DriverT = TypeVar("DriverT")
|
|
21
|
+
ConfigT = TypeVar("ConfigT")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class BaseMigrationTracker(ABC, Generic[DriverT]):
|
|
25
|
+
"""Base class for migration version tracking."""
|
|
26
|
+
|
|
27
|
+
def __init__(self, version_table_name: str = "ddl_migrations") -> None:
|
|
28
|
+
"""Initialize the migration tracker.
|
|
29
|
+
|
|
30
|
+
Args:
|
|
31
|
+
version_table_name: Name of the table to track migrations.
|
|
32
|
+
"""
|
|
33
|
+
self.version_table = version_table_name
|
|
34
|
+
|
|
35
|
+
def _get_create_table_sql(self) -> SQL:
|
|
36
|
+
"""Get SQL for creating the tracking table.
|
|
37
|
+
|
|
38
|
+
Returns:
|
|
39
|
+
SQL object for table creation.
|
|
40
|
+
"""
|
|
41
|
+
return SQL(
|
|
42
|
+
f"""
|
|
43
|
+
CREATE TABLE IF NOT EXISTS {self.version_table} (
|
|
44
|
+
version_num VARCHAR(32) PRIMARY KEY,
|
|
45
|
+
description TEXT,
|
|
46
|
+
applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
47
|
+
execution_time_ms INTEGER,
|
|
48
|
+
checksum VARCHAR(64),
|
|
49
|
+
applied_by VARCHAR(255)
|
|
50
|
+
)
|
|
51
|
+
"""
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
def _get_current_version_sql(self) -> SQL:
|
|
55
|
+
"""Get SQL for retrieving current version.
|
|
56
|
+
|
|
57
|
+
Returns:
|
|
58
|
+
SQL object for version query.
|
|
59
|
+
"""
|
|
60
|
+
return SQL(f"SELECT version_num FROM {self.version_table} ORDER BY version_num DESC LIMIT 1")
|
|
61
|
+
|
|
62
|
+
def _get_applied_migrations_sql(self) -> SQL:
|
|
63
|
+
"""Get SQL for retrieving all applied migrations.
|
|
64
|
+
|
|
65
|
+
Returns:
|
|
66
|
+
SQL object for migrations query.
|
|
67
|
+
"""
|
|
68
|
+
return SQL(f"SELECT * FROM {self.version_table} ORDER BY version_num")
|
|
69
|
+
|
|
70
|
+
def _get_record_migration_sql(
|
|
71
|
+
self, version: str, description: str, execution_time_ms: int, checksum: str, applied_by: str
|
|
72
|
+
) -> SQL:
|
|
73
|
+
"""Get SQL for recording a migration.
|
|
74
|
+
|
|
75
|
+
Args:
|
|
76
|
+
version: Version number of the migration.
|
|
77
|
+
description: Description of the migration.
|
|
78
|
+
execution_time_ms: Execution time in milliseconds.
|
|
79
|
+
checksum: MD5 checksum of the migration content.
|
|
80
|
+
applied_by: User who applied the migration.
|
|
81
|
+
|
|
82
|
+
Returns:
|
|
83
|
+
SQL object for insert.
|
|
84
|
+
"""
|
|
85
|
+
return SQL(
|
|
86
|
+
f"INSERT INTO {self.version_table} (version_num, description, execution_time_ms, checksum, applied_by) VALUES (?, ?, ?, ?, ?)",
|
|
87
|
+
version,
|
|
88
|
+
description,
|
|
89
|
+
execution_time_ms,
|
|
90
|
+
checksum,
|
|
91
|
+
applied_by,
|
|
92
|
+
)
|
|
93
|
+
|
|
94
|
+
def _get_remove_migration_sql(self, version: str) -> SQL:
|
|
95
|
+
"""Get SQL for removing a migration record.
|
|
96
|
+
|
|
97
|
+
Args:
|
|
98
|
+
version: Version number to remove.
|
|
99
|
+
|
|
100
|
+
Returns:
|
|
101
|
+
SQL object for delete.
|
|
102
|
+
"""
|
|
103
|
+
return SQL(f"DELETE FROM {self.version_table} WHERE version_num = ?", version)
|
|
104
|
+
|
|
105
|
+
@abstractmethod
|
|
106
|
+
def ensure_tracking_table(self, driver: DriverT) -> Any:
|
|
107
|
+
"""Create the migration tracking table if it doesn't exist."""
|
|
108
|
+
...
|
|
109
|
+
|
|
110
|
+
@abstractmethod
|
|
111
|
+
def get_current_version(self, driver: DriverT) -> Any:
|
|
112
|
+
"""Get the latest applied migration version."""
|
|
113
|
+
...
|
|
114
|
+
|
|
115
|
+
@abstractmethod
|
|
116
|
+
def get_applied_migrations(self, driver: DriverT) -> Any:
|
|
117
|
+
"""Get all applied migrations in order."""
|
|
118
|
+
...
|
|
119
|
+
|
|
120
|
+
@abstractmethod
|
|
121
|
+
def record_migration(
|
|
122
|
+
self, driver: DriverT, version: str, description: str, execution_time_ms: int, checksum: str
|
|
123
|
+
) -> Any:
|
|
124
|
+
"""Record a successfully applied migration."""
|
|
125
|
+
...
|
|
126
|
+
|
|
127
|
+
@abstractmethod
|
|
128
|
+
def remove_migration(self, driver: DriverT, version: str) -> Any:
|
|
129
|
+
"""Remove a migration record."""
|
|
130
|
+
...
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class BaseMigrationRunner(ABC, Generic[DriverT]):
|
|
134
|
+
"""Base class for migration execution."""
|
|
135
|
+
|
|
136
|
+
def __init__(self, migrations_path: Path) -> None:
|
|
137
|
+
"""Initialize the migration runner.
|
|
138
|
+
|
|
139
|
+
Args:
|
|
140
|
+
migrations_path: Path to the directory containing migration files.
|
|
141
|
+
"""
|
|
142
|
+
self.migrations_path = migrations_path
|
|
143
|
+
self.loader = SQLFileLoader()
|
|
144
|
+
|
|
145
|
+
def _extract_version(self, filename: str) -> Optional[str]:
|
|
146
|
+
"""Extract version from filename (e.g., '0001_initial.sql' -> '0001').
|
|
147
|
+
|
|
148
|
+
Args:
|
|
149
|
+
filename: The migration filename.
|
|
150
|
+
|
|
151
|
+
Returns:
|
|
152
|
+
The extracted version string or None.
|
|
153
|
+
"""
|
|
154
|
+
parts = filename.split("_", 1)
|
|
155
|
+
if parts and parts[0].isdigit():
|
|
156
|
+
return parts[0].zfill(4)
|
|
157
|
+
return None
|
|
158
|
+
|
|
159
|
+
def _calculate_checksum(self, content: str) -> str:
|
|
160
|
+
"""Calculate MD5 checksum of migration content.
|
|
161
|
+
|
|
162
|
+
Args:
|
|
163
|
+
content: The migration file content.
|
|
164
|
+
|
|
165
|
+
Returns:
|
|
166
|
+
The MD5 checksum hex string.
|
|
167
|
+
"""
|
|
168
|
+
import hashlib
|
|
169
|
+
|
|
170
|
+
return hashlib.md5(content.encode()).hexdigest() # noqa: S324
|
|
171
|
+
|
|
172
|
+
def _get_migration_files_sync(self) -> "list[tuple[str, Path]]":
|
|
173
|
+
"""Get all migration files sorted by version (sync version).
|
|
174
|
+
|
|
175
|
+
Returns:
|
|
176
|
+
List of tuples containing (version, file_path).
|
|
177
|
+
"""
|
|
178
|
+
if not self.migrations_path.exists():
|
|
179
|
+
return []
|
|
180
|
+
|
|
181
|
+
migrations = []
|
|
182
|
+
for file_path in self.migrations_path.glob("*.sql"):
|
|
183
|
+
if file_path.name.startswith("."):
|
|
184
|
+
continue
|
|
185
|
+
version = self._extract_version(file_path.name)
|
|
186
|
+
if version:
|
|
187
|
+
migrations.append((version, file_path))
|
|
188
|
+
|
|
189
|
+
return sorted(migrations, key=lambda x: x[0])
|
|
190
|
+
|
|
191
|
+
def _load_migration_metadata(self, file_path: Path) -> "dict[str, Any]":
|
|
192
|
+
"""Load migration metadata from file.
|
|
193
|
+
|
|
194
|
+
Args:
|
|
195
|
+
file_path: Path to the migration file.
|
|
196
|
+
|
|
197
|
+
Returns:
|
|
198
|
+
Dictionary containing migration metadata.
|
|
199
|
+
"""
|
|
200
|
+
self.loader.clear_cache()
|
|
201
|
+
self.loader.load_sql(file_path)
|
|
202
|
+
|
|
203
|
+
# Read raw content for checksum
|
|
204
|
+
content = file_path.read_text()
|
|
205
|
+
checksum = self._calculate_checksum(content)
|
|
206
|
+
|
|
207
|
+
# Extract metadata
|
|
208
|
+
version = self._extract_version(file_path.name)
|
|
209
|
+
description = file_path.stem.split("_", 1)[1] if "_" in file_path.stem else ""
|
|
210
|
+
|
|
211
|
+
# Query names use versioned pattern
|
|
212
|
+
up_query = f"migrate-{version}-up"
|
|
213
|
+
down_query = f"migrate-{version}-down"
|
|
214
|
+
|
|
215
|
+
return {
|
|
216
|
+
"version": version,
|
|
217
|
+
"description": description,
|
|
218
|
+
"file_path": file_path,
|
|
219
|
+
"checksum": checksum,
|
|
220
|
+
"up_query": up_query,
|
|
221
|
+
"down_query": down_query,
|
|
222
|
+
"has_upgrade": self.loader.has_query(up_query),
|
|
223
|
+
"has_downgrade": self.loader.has_query(down_query),
|
|
224
|
+
}
|
|
225
|
+
|
|
226
|
+
def _get_migration_sql(self, migration: "dict[str, Any]", direction: str) -> Optional[SQL]:
|
|
227
|
+
"""Get migration SQL for given direction.
|
|
228
|
+
|
|
229
|
+
Args:
|
|
230
|
+
migration: Migration metadata.
|
|
231
|
+
direction: Either 'up' or 'down'.
|
|
232
|
+
|
|
233
|
+
Returns:
|
|
234
|
+
SQL object for the migration.
|
|
235
|
+
"""
|
|
236
|
+
query_key = f"{direction}_query"
|
|
237
|
+
has_key = f"has_{direction}grade"
|
|
238
|
+
|
|
239
|
+
if not migration.get(has_key):
|
|
240
|
+
if direction == "down":
|
|
241
|
+
logger.warning("Migration %s has no downgrade query", migration["version"])
|
|
242
|
+
return None
|
|
243
|
+
msg = f"Migration {migration['version']} has no upgrade query"
|
|
244
|
+
raise ValueError(msg)
|
|
245
|
+
|
|
246
|
+
return self.loader.get_sql(migration[query_key])
|
|
247
|
+
|
|
248
|
+
@abstractmethod
|
|
249
|
+
def get_migration_files(self) -> Any:
|
|
250
|
+
"""Get all migration files sorted by version."""
|
|
251
|
+
...
|
|
252
|
+
|
|
253
|
+
@abstractmethod
|
|
254
|
+
def load_migration(self, file_path: Path) -> Any:
|
|
255
|
+
"""Load a migration file and extract its components."""
|
|
256
|
+
...
|
|
257
|
+
|
|
258
|
+
@abstractmethod
|
|
259
|
+
def execute_upgrade(self, driver: DriverT, migration: "dict[str, Any]") -> Any:
|
|
260
|
+
"""Execute an upgrade migration."""
|
|
261
|
+
...
|
|
262
|
+
|
|
263
|
+
@abstractmethod
|
|
264
|
+
def execute_downgrade(self, driver: DriverT, migration: "dict[str, Any]") -> Any:
|
|
265
|
+
"""Execute a downgrade migration."""
|
|
266
|
+
...
|
|
267
|
+
|
|
268
|
+
@abstractmethod
|
|
269
|
+
def load_all_migrations(self) -> Any:
|
|
270
|
+
"""Load all migrations into a single namespace for bulk operations."""
|
|
271
|
+
...
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
class BaseMigrationCommands(ABC, Generic[ConfigT, DriverT]):
|
|
275
|
+
"""Base class for migration commands."""
|
|
276
|
+
|
|
277
|
+
def __init__(self, config: ConfigT) -> None:
|
|
278
|
+
"""Initialize migration commands.
|
|
279
|
+
|
|
280
|
+
Args:
|
|
281
|
+
config: The SQLSpec configuration.
|
|
282
|
+
"""
|
|
283
|
+
self.config = config
|
|
284
|
+
|
|
285
|
+
# Get migration settings from config
|
|
286
|
+
migration_config = getattr(self.config, "migration_config", {})
|
|
287
|
+
if migration_config is None:
|
|
288
|
+
migration_config = {}
|
|
289
|
+
|
|
290
|
+
self.version_table = migration_config.get("version_table_name", "sqlspec_migrations")
|
|
291
|
+
self.migrations_path = Path(migration_config.get("script_location", "migrations"))
|
|
292
|
+
|
|
293
|
+
def _get_init_readme_content(self) -> str:
|
|
294
|
+
"""Get the README content for migration directory initialization.
|
|
295
|
+
|
|
296
|
+
Returns:
|
|
297
|
+
The README markdown content.
|
|
298
|
+
"""
|
|
299
|
+
return """# SQLSpec Migrations
|
|
300
|
+
|
|
301
|
+
This directory contains database migration files.
|
|
302
|
+
|
|
303
|
+
## File Format
|
|
304
|
+
|
|
305
|
+
Migration files use SQLFileLoader's named query syntax with versioned names:
|
|
306
|
+
|
|
307
|
+
```sql
|
|
308
|
+
-- name: migrate-0001-up
|
|
309
|
+
CREATE TABLE example (
|
|
310
|
+
id INTEGER PRIMARY KEY,
|
|
311
|
+
name TEXT NOT NULL
|
|
312
|
+
);
|
|
313
|
+
|
|
314
|
+
-- name: migrate-0001-down
|
|
315
|
+
DROP TABLE example;
|
|
316
|
+
```
|
|
317
|
+
|
|
318
|
+
## Naming Conventions
|
|
319
|
+
|
|
320
|
+
### File Names
|
|
321
|
+
|
|
322
|
+
Format: `{version}_{description}.sql`
|
|
323
|
+
|
|
324
|
+
- Version: Zero-padded 4-digit number (0001, 0002, etc.)
|
|
325
|
+
- Description: Brief description using underscores
|
|
326
|
+
- Example: `0001_create_users_table.sql`
|
|
327
|
+
|
|
328
|
+
### Query Names
|
|
329
|
+
|
|
330
|
+
- Upgrade: `migrate-{version}-up`
|
|
331
|
+
- Downgrade: `migrate-{version}-down`
|
|
332
|
+
|
|
333
|
+
This naming ensures proper sorting and avoids conflicts when loading multiple files.
|
|
334
|
+
"""
|
|
335
|
+
|
|
336
|
+
def init_directory(self, directory: str, package: bool = True) -> None:
|
|
337
|
+
"""Initialize migration directory structure (sync implementation).
|
|
338
|
+
|
|
339
|
+
Args:
|
|
340
|
+
directory: Directory to initialize migrations in.
|
|
341
|
+
package: Whether to create __init__.py file.
|
|
342
|
+
"""
|
|
343
|
+
from rich.console import Console
|
|
344
|
+
|
|
345
|
+
console = Console()
|
|
346
|
+
|
|
347
|
+
migrations_dir = Path(directory)
|
|
348
|
+
migrations_dir.mkdir(parents=True, exist_ok=True)
|
|
349
|
+
|
|
350
|
+
if package:
|
|
351
|
+
(migrations_dir / "__init__.py").touch()
|
|
352
|
+
|
|
353
|
+
# Create README
|
|
354
|
+
readme = migrations_dir / "README.md"
|
|
355
|
+
readme.write_text(self._get_init_readme_content())
|
|
356
|
+
|
|
357
|
+
# Create .gitkeep for empty directory
|
|
358
|
+
(migrations_dir / ".gitkeep").touch()
|
|
359
|
+
|
|
360
|
+
console.print(f"[green]Initialized migrations in {directory}[/]")
|
|
361
|
+
|
|
362
|
+
@abstractmethod
|
|
363
|
+
def init(self, directory: str, package: bool = True) -> Any:
|
|
364
|
+
"""Initialize migration directory structure."""
|
|
365
|
+
...
|
|
366
|
+
|
|
367
|
+
@abstractmethod
|
|
368
|
+
def current(self, verbose: bool = False) -> Any:
|
|
369
|
+
"""Show current migration version."""
|
|
370
|
+
...
|
|
371
|
+
|
|
372
|
+
@abstractmethod
|
|
373
|
+
def upgrade(self, revision: str = "head") -> Any:
|
|
374
|
+
"""Upgrade to a target revision."""
|
|
375
|
+
...
|
|
376
|
+
|
|
377
|
+
@abstractmethod
|
|
378
|
+
def downgrade(self, revision: str = "-1") -> Any:
|
|
379
|
+
"""Downgrade to a target revision."""
|
|
380
|
+
...
|
|
381
|
+
|
|
382
|
+
@abstractmethod
|
|
383
|
+
def stamp(self, revision: str) -> Any:
|
|
384
|
+
"""Mark database as being at a specific revision without running migrations."""
|
|
385
|
+
...
|
|
386
|
+
|
|
387
|
+
@abstractmethod
|
|
388
|
+
def revision(self, message: str) -> Any:
|
|
389
|
+
"""Create a new migration file."""
|
|
390
|
+
...
|