sqlspec 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +7 -15
- sqlspec/_serialization.py +256 -24
- sqlspec/_typing.py +71 -52
- sqlspec/adapters/adbc/_types.py +1 -1
- sqlspec/adapters/adbc/adk/__init__.py +5 -0
- sqlspec/adapters/adbc/adk/store.py +870 -0
- sqlspec/adapters/adbc/config.py +69 -12
- sqlspec/adapters/adbc/data_dictionary.py +340 -0
- sqlspec/adapters/adbc/driver.py +266 -58
- sqlspec/adapters/adbc/litestar/__init__.py +5 -0
- sqlspec/adapters/adbc/litestar/store.py +504 -0
- sqlspec/adapters/adbc/type_converter.py +153 -0
- sqlspec/adapters/aiosqlite/_types.py +1 -1
- sqlspec/adapters/aiosqlite/adk/__init__.py +5 -0
- sqlspec/adapters/aiosqlite/adk/store.py +527 -0
- sqlspec/adapters/aiosqlite/config.py +88 -15
- sqlspec/adapters/aiosqlite/data_dictionary.py +149 -0
- sqlspec/adapters/aiosqlite/driver.py +143 -40
- sqlspec/adapters/aiosqlite/litestar/__init__.py +5 -0
- sqlspec/adapters/aiosqlite/litestar/store.py +281 -0
- sqlspec/adapters/aiosqlite/pool.py +7 -7
- sqlspec/adapters/asyncmy/__init__.py +7 -1
- sqlspec/adapters/asyncmy/_types.py +2 -2
- sqlspec/adapters/asyncmy/adk/__init__.py +5 -0
- sqlspec/adapters/asyncmy/adk/store.py +493 -0
- sqlspec/adapters/asyncmy/config.py +68 -23
- sqlspec/adapters/asyncmy/data_dictionary.py +161 -0
- sqlspec/adapters/asyncmy/driver.py +313 -58
- sqlspec/adapters/asyncmy/litestar/__init__.py +5 -0
- sqlspec/adapters/asyncmy/litestar/store.py +296 -0
- sqlspec/adapters/asyncpg/__init__.py +2 -1
- sqlspec/adapters/asyncpg/_type_handlers.py +71 -0
- sqlspec/adapters/asyncpg/_types.py +11 -7
- sqlspec/adapters/asyncpg/adk/__init__.py +5 -0
- sqlspec/adapters/asyncpg/adk/store.py +450 -0
- sqlspec/adapters/asyncpg/config.py +59 -35
- sqlspec/adapters/asyncpg/data_dictionary.py +173 -0
- sqlspec/adapters/asyncpg/driver.py +170 -25
- sqlspec/adapters/asyncpg/litestar/__init__.py +5 -0
- sqlspec/adapters/asyncpg/litestar/store.py +253 -0
- sqlspec/adapters/bigquery/_types.py +1 -1
- sqlspec/adapters/bigquery/adk/__init__.py +5 -0
- sqlspec/adapters/bigquery/adk/store.py +576 -0
- sqlspec/adapters/bigquery/config.py +27 -10
- sqlspec/adapters/bigquery/data_dictionary.py +149 -0
- sqlspec/adapters/bigquery/driver.py +368 -142
- sqlspec/adapters/bigquery/litestar/__init__.py +5 -0
- sqlspec/adapters/bigquery/litestar/store.py +327 -0
- sqlspec/adapters/bigquery/type_converter.py +125 -0
- sqlspec/adapters/duckdb/_types.py +1 -1
- sqlspec/adapters/duckdb/adk/__init__.py +14 -0
- sqlspec/adapters/duckdb/adk/store.py +553 -0
- sqlspec/adapters/duckdb/config.py +80 -20
- sqlspec/adapters/duckdb/data_dictionary.py +163 -0
- sqlspec/adapters/duckdb/driver.py +167 -45
- sqlspec/adapters/duckdb/litestar/__init__.py +5 -0
- sqlspec/adapters/duckdb/litestar/store.py +332 -0
- sqlspec/adapters/duckdb/pool.py +4 -4
- sqlspec/adapters/duckdb/type_converter.py +133 -0
- sqlspec/adapters/oracledb/_numpy_handlers.py +133 -0
- sqlspec/adapters/oracledb/_types.py +20 -2
- sqlspec/adapters/oracledb/adk/__init__.py +5 -0
- sqlspec/adapters/oracledb/adk/store.py +1745 -0
- sqlspec/adapters/oracledb/config.py +122 -32
- sqlspec/adapters/oracledb/data_dictionary.py +509 -0
- sqlspec/adapters/oracledb/driver.py +353 -91
- sqlspec/adapters/oracledb/litestar/__init__.py +5 -0
- sqlspec/adapters/oracledb/litestar/store.py +767 -0
- sqlspec/adapters/oracledb/migrations.py +348 -73
- sqlspec/adapters/oracledb/type_converter.py +207 -0
- sqlspec/adapters/psqlpy/_type_handlers.py +44 -0
- sqlspec/adapters/psqlpy/_types.py +2 -1
- sqlspec/adapters/psqlpy/adk/__init__.py +5 -0
- sqlspec/adapters/psqlpy/adk/store.py +482 -0
- sqlspec/adapters/psqlpy/config.py +46 -17
- sqlspec/adapters/psqlpy/data_dictionary.py +172 -0
- sqlspec/adapters/psqlpy/driver.py +123 -209
- sqlspec/adapters/psqlpy/litestar/__init__.py +5 -0
- sqlspec/adapters/psqlpy/litestar/store.py +272 -0
- sqlspec/adapters/psqlpy/type_converter.py +102 -0
- sqlspec/adapters/psycopg/_type_handlers.py +80 -0
- sqlspec/adapters/psycopg/_types.py +2 -1
- sqlspec/adapters/psycopg/adk/__init__.py +5 -0
- sqlspec/adapters/psycopg/adk/store.py +944 -0
- sqlspec/adapters/psycopg/config.py +69 -35
- sqlspec/adapters/psycopg/data_dictionary.py +331 -0
- sqlspec/adapters/psycopg/driver.py +238 -81
- sqlspec/adapters/psycopg/litestar/__init__.py +5 -0
- sqlspec/adapters/psycopg/litestar/store.py +554 -0
- sqlspec/adapters/sqlite/__init__.py +2 -1
- sqlspec/adapters/sqlite/_type_handlers.py +86 -0
- sqlspec/adapters/sqlite/_types.py +1 -1
- sqlspec/adapters/sqlite/adk/__init__.py +5 -0
- sqlspec/adapters/sqlite/adk/store.py +572 -0
- sqlspec/adapters/sqlite/config.py +87 -15
- sqlspec/adapters/sqlite/data_dictionary.py +149 -0
- sqlspec/adapters/sqlite/driver.py +137 -54
- sqlspec/adapters/sqlite/litestar/__init__.py +5 -0
- sqlspec/adapters/sqlite/litestar/store.py +318 -0
- sqlspec/adapters/sqlite/pool.py +18 -9
- sqlspec/base.py +45 -26
- sqlspec/builder/__init__.py +73 -4
- sqlspec/builder/_base.py +162 -89
- sqlspec/builder/_column.py +62 -29
- sqlspec/builder/_ddl.py +180 -121
- sqlspec/builder/_delete.py +5 -4
- sqlspec/builder/_dml.py +388 -0
- sqlspec/{_sql.py → builder/_factory.py} +53 -94
- sqlspec/builder/_insert.py +32 -131
- sqlspec/builder/_join.py +375 -0
- sqlspec/builder/_merge.py +446 -11
- sqlspec/builder/_parsing_utils.py +111 -17
- sqlspec/builder/_select.py +1457 -24
- sqlspec/builder/_update.py +11 -42
- sqlspec/cli.py +307 -194
- sqlspec/config.py +252 -67
- sqlspec/core/__init__.py +5 -4
- sqlspec/core/cache.py +17 -17
- sqlspec/core/compiler.py +62 -9
- sqlspec/core/filters.py +37 -37
- sqlspec/core/hashing.py +9 -9
- sqlspec/core/parameters.py +83 -48
- sqlspec/core/result.py +102 -46
- sqlspec/core/splitter.py +16 -17
- sqlspec/core/statement.py +36 -30
- sqlspec/core/type_conversion.py +235 -0
- sqlspec/driver/__init__.py +7 -6
- sqlspec/driver/_async.py +188 -151
- sqlspec/driver/_common.py +285 -80
- sqlspec/driver/_sync.py +188 -152
- sqlspec/driver/mixins/_result_tools.py +20 -236
- sqlspec/driver/mixins/_sql_translator.py +4 -4
- sqlspec/exceptions.py +75 -7
- sqlspec/extensions/adk/__init__.py +53 -0
- sqlspec/extensions/adk/_types.py +51 -0
- sqlspec/extensions/adk/converters.py +172 -0
- sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +144 -0
- sqlspec/extensions/adk/migrations/__init__.py +0 -0
- sqlspec/extensions/adk/service.py +181 -0
- sqlspec/extensions/adk/store.py +536 -0
- sqlspec/extensions/aiosql/adapter.py +73 -53
- sqlspec/extensions/litestar/__init__.py +21 -4
- sqlspec/extensions/litestar/cli.py +54 -10
- sqlspec/extensions/litestar/config.py +59 -266
- sqlspec/extensions/litestar/handlers.py +46 -17
- sqlspec/extensions/litestar/migrations/0001_create_session_table.py +137 -0
- sqlspec/extensions/litestar/migrations/__init__.py +3 -0
- sqlspec/extensions/litestar/plugin.py +324 -223
- sqlspec/extensions/litestar/providers.py +25 -25
- sqlspec/extensions/litestar/store.py +265 -0
- sqlspec/loader.py +30 -49
- sqlspec/migrations/__init__.py +4 -3
- sqlspec/migrations/base.py +302 -39
- sqlspec/migrations/commands.py +611 -144
- sqlspec/migrations/context.py +142 -0
- sqlspec/migrations/fix.py +199 -0
- sqlspec/migrations/loaders.py +68 -23
- sqlspec/migrations/runner.py +543 -107
- sqlspec/migrations/tracker.py +237 -21
- sqlspec/migrations/utils.py +51 -3
- sqlspec/migrations/validation.py +177 -0
- sqlspec/protocols.py +66 -36
- sqlspec/storage/_utils.py +98 -0
- sqlspec/storage/backends/fsspec.py +134 -106
- sqlspec/storage/backends/local.py +78 -51
- sqlspec/storage/backends/obstore.py +278 -162
- sqlspec/storage/registry.py +75 -39
- sqlspec/typing.py +16 -84
- sqlspec/utils/config_resolver.py +153 -0
- sqlspec/utils/correlation.py +4 -5
- sqlspec/utils/data_transformation.py +3 -2
- sqlspec/utils/deprecation.py +9 -8
- sqlspec/utils/fixtures.py +4 -4
- sqlspec/utils/logging.py +46 -6
- sqlspec/utils/module_loader.py +2 -2
- sqlspec/utils/schema.py +288 -0
- sqlspec/utils/serializers.py +50 -2
- sqlspec/utils/sync_tools.py +21 -17
- sqlspec/utils/text.py +1 -2
- sqlspec/utils/type_guards.py +111 -20
- sqlspec/utils/version.py +433 -0
- {sqlspec-0.25.0.dist-info → sqlspec-0.27.0.dist-info}/METADATA +40 -21
- sqlspec-0.27.0.dist-info/RECORD +207 -0
- sqlspec/builder/mixins/__init__.py +0 -55
- sqlspec/builder/mixins/_cte_and_set_ops.py +0 -254
- sqlspec/builder/mixins/_delete_operations.py +0 -50
- sqlspec/builder/mixins/_insert_operations.py +0 -282
- sqlspec/builder/mixins/_join_operations.py +0 -389
- sqlspec/builder/mixins/_merge_operations.py +0 -592
- sqlspec/builder/mixins/_order_limit_operations.py +0 -152
- sqlspec/builder/mixins/_pivot_operations.py +0 -157
- sqlspec/builder/mixins/_select_operations.py +0 -936
- sqlspec/builder/mixins/_update_operations.py +0 -218
- sqlspec/builder/mixins/_where_clause.py +0 -1304
- sqlspec-0.25.0.dist-info/RECORD +0 -139
- sqlspec-0.25.0.dist-info/licenses/NOTICE +0 -29
- {sqlspec-0.25.0.dist-info → sqlspec-0.27.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.25.0.dist-info → sqlspec-0.27.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.25.0.dist-info → sqlspec-0.27.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
"""Migration context for passing runtime information to migrations."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import inspect
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
|
+
|
|
8
|
+
from sqlspec.utils.logging import get_logger
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
|
|
12
|
+
|
|
13
|
+
logger = get_logger("migrations.context")
|
|
14
|
+
|
|
15
|
+
__all__ = ("MigrationContext",)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class MigrationContext:
    """Context object passed to migration functions.

    Provides runtime information about the database environment
    to migration functions, allowing them to generate dialect-specific SQL.
    """

    config: "Any | None" = None
    """Database configuration object."""
    dialect: "str | None" = None
    """Database dialect (e.g., 'postgres', 'mysql', 'sqlite')."""
    metadata: "dict[str, Any] | None" = None
    """Additional metadata for the migration."""
    extension_config: "dict[str, Any] | None" = None
    """Extension-specific configuration options."""

    driver: "SyncDriverAdapterBase | AsyncDriverAdapterBase | None" = None
    """Database driver instance (available during execution)."""

    _execution_metadata: "dict[str, Any]" = field(default_factory=dict)
    """Internal execution metadata for tracking async operations."""

    def __post_init__(self) -> None:
        """Initialize metadata and extension config if not provided.

        Only ``None`` is replaced with a fresh dict. A caller-supplied dict
        (even an empty one) is kept by identity, so mutations made by
        migrations remain visible to the caller.
        """
        if self.metadata is None:
            self.metadata = {}
        if self.extension_config is None:
            self.extension_config = {}

    @classmethod
    def from_config(cls, config: Any) -> "MigrationContext":
        """Create context from database configuration.

        The dialect is read from ``config.statement_config`` when it is set
        and truthy; otherwise ``config._create_statement_config()`` is called
        if it exists and is callable. Extraction is best-effort: any error
        leaves ``dialect`` as ``None``.

        Args:
            config: Database configuration object.

        Returns:
            Migration context with dialect information.
        """
        dialect = None
        try:
            statement_config = getattr(config, "statement_config", None)
            if statement_config:
                dialect = getattr(statement_config, "dialect", None)
            else:
                factory = getattr(config, "_create_statement_config", None)
                if callable(factory):
                    dialect = getattr(factory(), "dialect", None)
        except Exception:
            # Best-effort: keep the traceback for debugging instead of losing it.
            logger.debug("Unable to extract dialect from config", exc_info=True)

        return cls(dialect=dialect, config=config)

    @property
    def is_async_execution(self) -> bool:
        """Check if migrations are running in an async execution context.

        Returns:
            True if an asyncio event loop is running in the current thread
            (``asyncio.current_task()`` did not raise ``RuntimeError``).
        """
        try:
            asyncio.current_task()
        except RuntimeError:
            return False
        else:
            return True

    @property
    def is_async_driver(self) -> bool:
        """Check if the current driver is async.

        Returns:
            True if a driver is set and its ``execute_script`` attribute is a
            coroutine function.
        """
        if self.driver is None:
            return False

        execute_method = getattr(self.driver, "execute_script", None)
        return execute_method is not None and inspect.iscoroutinefunction(execute_method)

    @property
    def execution_mode(self) -> str:
        """Get the current execution mode.

        Returns:
            'async' if in async context, 'sync' otherwise.
        """
        return "async" if self.is_async_execution else "sync"

    def set_execution_metadata(self, key: str, value: Any) -> None:
        """Set execution metadata for tracking migration state.

        Args:
            key: Metadata key.
            value: Metadata value.
        """
        self._execution_metadata[key] = value

    def get_execution_metadata(self, key: str, default: Any = None) -> Any:
        """Get execution metadata.

        Args:
            key: Metadata key.
            default: Default value if key not found.

        Returns:
            Metadata value or default.
        """
        return self._execution_metadata.get(key, default)

    def validate_async_usage(self, migration_func: Any) -> None:
        """Validate proper usage of async functions in migration context.

        Logs a warning when an async migration function is used with neither
        a running event loop nor an async driver. Records ``mixed_execution``
        in the execution metadata when a sync function runs with an async
        driver.

        Args:
            migration_func: The migration function to validate.
        """
        is_coroutine = inspect.iscoroutinefunction(migration_func)

        if is_coroutine and not self.is_async_execution and not self.is_async_driver:
            msg = (
                "Async migration function detected but execution context is sync. "
                "Consider using async database configuration or sync migration functions."
            )
            logger.warning(msg)

        if not is_coroutine and self.is_async_driver:
            self.set_execution_metadata("mixed_execution", value=True)
            logger.debug("Sync migration function in async driver context - using compatibility mode")
|
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
"""Migration file fix operations for converting timestamp to sequential versions.
|
|
2
|
+
|
|
3
|
+
This module provides utilities to convert timestamp-format migration files to
|
|
4
|
+
sequential format, supporting the hybrid versioning workflow where development
|
|
5
|
+
uses timestamps and production uses sequential numbers.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
import re
|
|
10
|
+
import shutil
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from datetime import datetime, timezone
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
|
|
15
|
+
__all__ = ("MigrationFixer", "MigrationRename")
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class MigrationRename:
    """Represents a planned migration file rename operation.

    Attributes:
        old_path: Current file path.
        new_path: Target file path after rename.
        old_version: Current version string.
        new_version: Target version string.
        needs_content_update: Whether file content needs updating.
            True for SQL files that contain query names.
    """

    old_path: Path
    new_path: Path
    old_version: str
    new_version: str
    needs_content_update: bool


class MigrationFixer:
    """Handles atomic migration file conversion operations.

    Provides backup/rollback functionality and manages conversion from
    timestamp-based migration files to sequential format.
    """

    def __init__(self, migrations_path: Path) -> None:
        """Initialize migration fixer.

        Args:
            migrations_path: Path to migrations directory.
        """
        self.migrations_path = migrations_path
        self.backup_path: "Path | None" = None

    def plan_renames(self, conversion_map: "dict[str, str]") -> "list[MigrationRename]":
        """Plan all file rename operations from conversion map.

        Scans the migration directory and builds a list of MigrationRename
        objects for all files that need conversion. Rejects targets that
        already exist on disk, and targets claimed by more than one source
        file in the same plan.

        Args:
            conversion_map: Dictionary mapping old versions to new versions.

        Returns:
            List of planned rename operations.

        Raises:
            ValueError: If target file already exists or collision detected.
        """
        if not conversion_map:
            return []

        renames: list[MigrationRename] = []
        # Maps each planned target to the file that claimed it, to catch
        # two sources being renamed onto the same name.
        planned_targets: dict[Path, Path] = {}

        for old_version, new_version in conversion_map.items():
            prefix = f"{old_version}_"

            # Sorted for a deterministic plan regardless of directory order.
            for old_path in sorted(self.migrations_path.glob(f"{prefix}*")):
                suffix = old_path.suffix
                # Strip only the leading version prefix; str.replace would also
                # delete the version if it reappeared inside the description.
                description = old_path.stem[len(prefix):]

                new_path = self.migrations_path / f"{new_version}_{description}{suffix}"

                if new_path.exists() and new_path != old_path:
                    msg = f"Target file already exists: {new_path}"
                    raise ValueError(msg)

                previous = planned_targets.get(new_path)
                if previous is not None:
                    msg = f"Rename collision: {previous.name} and {old_path.name} both map to {new_path.name}"
                    raise ValueError(msg)
                planned_targets[new_path] = old_path

                renames.append(
                    MigrationRename(
                        old_path=old_path,
                        new_path=new_path,
                        old_version=old_version,
                        new_version=new_version,
                        needs_content_update=suffix == ".sql",
                    )
                )

        return renames

    def create_backup(self) -> Path:
        """Create timestamped backup directory with all migration files.

        Copies every non-hidden regular file in the migrations directory
        into ``.backup_<UTC timestamp>`` inside that directory.

        Returns:
            Path to created backup directory.

        Raises:
            FileExistsError: If a backup with the same timestamp already exists.
        """
        timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S")
        backup_dir = self.migrations_path / f".backup_{timestamp}"

        backup_dir.mkdir(parents=True, exist_ok=False)

        for file_path in self.migrations_path.iterdir():
            if file_path.is_file() and not file_path.name.startswith("."):
                shutil.copy2(file_path, backup_dir / file_path.name)

        self.backup_path = backup_dir
        return backup_dir

    def apply_renames(self, renames: "list[MigrationRename]", dry_run: bool = False) -> None:
        """Execute planned rename operations.

        SQL files have their query names and version comments rewritten
        before being renamed.

        Args:
            renames: List of planned rename operations.
            dry_run: If True, log operations without executing.
        """
        for rename in renames:
            if dry_run:
                logger.info("[dry-run] Would rename %s -> %s", rename.old_path.name, rename.new_path.name)
                continue

            if rename.needs_content_update:
                self.update_file_content(rename.old_path, rename.old_version, rename.new_version)

            rename.old_path.rename(rename.new_path)
            logger.debug("Renamed %s -> %s", rename.old_path.name, rename.new_path.name)

    def update_file_content(self, file_path: Path, old_version: str, new_version: str) -> None:
        """Update SQL query names and version comments in file content.

        Transforms query names and version metadata from old version to new version:
            -- name: migrate-{old_version}-up → -- name: migrate-{new_version}-up
            -- name: migrate-{old_version}-down → -- name: migrate-{new_version}-down
            -- Version: {old_version} → -- Version: {new_version}

        Creates version-specific regex patterns to avoid unintended replacements
        of other migrate-* patterns in the file. The file is only rewritten
        when its content actually changes.

        Args:
            file_path: Path to file to update.
            old_version: Old version string.
            new_version: New version string.
        """
        content = file_path.read_text(encoding="utf-8")
        escaped = re.escape(old_version)

        name_pattern = re.compile(rf"(-- name:\s+migrate-){escaped}(-up|-down)")
        # (?!\d) keeps a shorter version from matching a prefix of a longer one.
        version_pattern = re.compile(rf"(-- Version:\s+){escaped}(?!\d)")

        # Callable replacements insert new_version literally, so any backslash
        # in it can never be interpreted as a group reference by re.sub.
        updated = name_pattern.sub(lambda m: f"{m.group(1)}{new_version}{m.group(2)}", content)
        updated = version_pattern.sub(lambda m: f"{m.group(1)}{new_version}", updated)

        if updated != content:
            file_path.write_text(updated, encoding="utf-8")
            logger.debug("Updated content in %s", file_path.name)

    def rollback(self) -> None:
        """Restore migration files from backup.

        Deletes current non-hidden migration files and restores every file
        from the backup directory. Does nothing if no backup exists. The
        backup directory itself is left in place.
        """
        if not self.backup_path or not self.backup_path.exists():
            return

        for file_path in self.migrations_path.iterdir():
            if file_path.is_file() and not file_path.name.startswith("."):
                file_path.unlink()

        for backup_file in self.backup_path.iterdir():
            if backup_file.is_file():
                shutil.copy2(backup_file, self.migrations_path / backup_file.name)

    def cleanup(self) -> None:
        """Remove backup directory after successful conversion.

        Only removes backup if it exists. Logs warning if no backup found.
        """
        if not self.backup_path or not self.backup_path.exists():
            logger.warning("No migration backup found to clean up")
            return

        shutil.rmtree(self.backup_path)
        self.backup_path = None
|
sqlspec/migrations/loaders.py
CHANGED
|
@@ -10,7 +10,7 @@ import types
|
|
|
10
10
|
from collections.abc import Iterator
|
|
11
11
|
from contextlib import contextmanager
|
|
12
12
|
from pathlib import Path
|
|
13
|
-
from typing import Any, Final
|
|
13
|
+
from typing import Any, Final
|
|
14
14
|
|
|
15
15
|
from sqlspec.loader import SQLFileLoader as CoreSQLFileLoader
|
|
16
16
|
|
|
@@ -77,13 +77,22 @@ class SQLFileLoader(BaseMigrationLoader):
|
|
|
77
77
|
|
|
78
78
|
__slots__ = ("sql_loader",)
|
|
79
79
|
|
|
80
|
-
def __init__(self) -> None:
|
|
81
|
-
"""Initialize SQL file loader.
|
|
82
|
-
|
|
80
|
+
def __init__(self, sql_loader: "CoreSQLFileLoader | None" = None) -> None:
|
|
81
|
+
"""Initialize SQL file loader.
|
|
82
|
+
|
|
83
|
+
Args:
|
|
84
|
+
sql_loader: Optional shared SQLFileLoader instance to reuse.
|
|
85
|
+
If not provided, creates a new instance.
|
|
86
|
+
"""
|
|
87
|
+
self.sql_loader: CoreSQLFileLoader = sql_loader if sql_loader is not None else CoreSQLFileLoader()
|
|
83
88
|
|
|
84
89
|
async def get_up_sql(self, path: Path) -> list[str]:
|
|
85
90
|
"""Extract the 'up' SQL from a SQL migration file.
|
|
86
91
|
|
|
92
|
+
The SQL file must already be loaded via validate_migration_file()
|
|
93
|
+
before calling this method. This design ensures the file is loaded
|
|
94
|
+
exactly once during the migration process.
|
|
95
|
+
|
|
87
96
|
Args:
|
|
88
97
|
path: Path to SQL migration file.
|
|
89
98
|
|
|
@@ -93,9 +102,6 @@ class SQLFileLoader(BaseMigrationLoader):
|
|
|
93
102
|
Raises:
|
|
94
103
|
MigrationLoadError: If migration file is invalid or missing up query.
|
|
95
104
|
"""
|
|
96
|
-
self.sql_loader.clear_cache()
|
|
97
|
-
self.sql_loader.load_sql(path)
|
|
98
|
-
|
|
99
105
|
version = self._extract_version(path.name)
|
|
100
106
|
up_query = f"migrate-{version}-up"
|
|
101
107
|
|
|
@@ -109,15 +115,16 @@ class SQLFileLoader(BaseMigrationLoader):
|
|
|
109
115
|
async def get_down_sql(self, path: Path) -> list[str]:
|
|
110
116
|
"""Extract the 'down' SQL from a SQL migration file.
|
|
111
117
|
|
|
118
|
+
The SQL file must already be loaded via validate_migration_file()
|
|
119
|
+
before calling this method. This design ensures the file is loaded
|
|
120
|
+
exactly once during the migration process.
|
|
121
|
+
|
|
112
122
|
Args:
|
|
113
123
|
path: Path to SQL migration file.
|
|
114
124
|
|
|
115
125
|
Returns:
|
|
116
126
|
List containing single SQL statement for downgrade, or empty list.
|
|
117
127
|
"""
|
|
118
|
-
self.sql_loader.clear_cache()
|
|
119
|
-
self.sql_loader.load_sql(path)
|
|
120
|
-
|
|
121
128
|
version = self._extract_version(path.name)
|
|
122
129
|
down_query = f"migrate-{version}-down"
|
|
123
130
|
|
|
@@ -141,7 +148,6 @@ class SQLFileLoader(BaseMigrationLoader):
|
|
|
141
148
|
msg = f"Invalid migration filename: {path.name}"
|
|
142
149
|
raise MigrationLoadError(msg)
|
|
143
150
|
|
|
144
|
-
self.sql_loader.clear_cache()
|
|
145
151
|
self.sql_loader.load_sql(path)
|
|
146
152
|
up_query = f"migrate-{version}-up"
|
|
147
153
|
if not self.sql_loader.has_query(up_query):
|
|
@@ -151,30 +157,49 @@ class SQLFileLoader(BaseMigrationLoader):
|
|
|
151
157
|
def _extract_version(self, filename: str) -> str:
|
|
152
158
|
"""Extract version from filename.
|
|
153
159
|
|
|
160
|
+
Supports sequential (0001), timestamp (20251011120000), and extension-prefixed
|
|
161
|
+
(ext_litestar_0001) version formats.
|
|
162
|
+
|
|
154
163
|
Args:
|
|
155
164
|
filename: Migration filename to parse.
|
|
156
165
|
|
|
157
166
|
Returns:
|
|
158
|
-
|
|
167
|
+
Version string or empty string if invalid.
|
|
159
168
|
"""
|
|
160
|
-
|
|
161
|
-
|
|
169
|
+
extension_version_parts = 3
|
|
170
|
+
timestamp_min_length = 4
|
|
171
|
+
|
|
172
|
+
name_without_ext = filename.rsplit(".", 1)[0]
|
|
173
|
+
|
|
174
|
+
if name_without_ext.startswith("ext_"):
|
|
175
|
+
parts = name_without_ext.split("_", 3)
|
|
176
|
+
if len(parts) >= extension_version_parts:
|
|
177
|
+
return f"{parts[0]}_{parts[1]}_{parts[2]}"
|
|
178
|
+
return ""
|
|
179
|
+
|
|
180
|
+
parts = name_without_ext.split("_", 1)
|
|
181
|
+
if parts and parts[0].isdigit():
|
|
182
|
+
return parts[0] if len(parts[0]) > timestamp_min_length else parts[0].zfill(4)
|
|
183
|
+
|
|
184
|
+
return ""
|
|
162
185
|
|
|
163
186
|
|
|
164
187
|
class PythonFileLoader(BaseMigrationLoader):
|
|
165
188
|
"""Loader for Python migration files."""
|
|
166
189
|
|
|
167
|
-
__slots__ = ("migrations_dir", "project_root")
|
|
190
|
+
__slots__ = ("context", "migrations_dir", "project_root")
|
|
168
191
|
|
|
169
|
-
def __init__(self, migrations_dir: Path, project_root: "
|
|
192
|
+
def __init__(self, migrations_dir: Path, project_root: "Path | None" = None, context: "Any | None" = None) -> None:
|
|
170
193
|
"""Initialize Python file loader.
|
|
171
194
|
|
|
172
195
|
Args:
|
|
173
196
|
migrations_dir: Directory containing migration files.
|
|
174
197
|
project_root: Optional project root directory for imports.
|
|
198
|
+
context: Optional migration context to pass to functions.
|
|
175
199
|
"""
|
|
176
200
|
self.migrations_dir = migrations_dir
|
|
177
201
|
self.project_root = project_root if project_root is not None else self._find_project_root(migrations_dir)
|
|
202
|
+
self.context = context
|
|
178
203
|
|
|
179
204
|
async def get_up_sql(self, path: Path) -> list[str]:
|
|
180
205
|
"""Load Python migration and execute upgrade function.
|
|
@@ -208,10 +233,16 @@ class PythonFileLoader(BaseMigrationLoader):
|
|
|
208
233
|
msg = f"'{func_name}' is not callable in {path}"
|
|
209
234
|
raise MigrationLoadError(msg)
|
|
210
235
|
|
|
236
|
+
# Check if function accepts context parameter
|
|
237
|
+
sig = inspect.signature(upgrade_func)
|
|
238
|
+
accepts_context = "context" in sig.parameters or len(sig.parameters) > 0
|
|
239
|
+
|
|
211
240
|
if inspect.iscoroutinefunction(upgrade_func):
|
|
212
|
-
sql_result =
|
|
241
|
+
sql_result = (
|
|
242
|
+
await upgrade_func(self.context) if accepts_context and self.context else await upgrade_func()
|
|
243
|
+
)
|
|
213
244
|
else:
|
|
214
|
-
sql_result = upgrade_func()
|
|
245
|
+
sql_result = upgrade_func(self.context) if accepts_context and self.context else upgrade_func()
|
|
215
246
|
|
|
216
247
|
return self._normalize_and_validate_sql(sql_result, path)
|
|
217
248
|
|
|
@@ -239,10 +270,16 @@ class PythonFileLoader(BaseMigrationLoader):
|
|
|
239
270
|
if not callable(downgrade_func):
|
|
240
271
|
return []
|
|
241
272
|
|
|
273
|
+
# Check if function accepts context parameter
|
|
274
|
+
sig = inspect.signature(downgrade_func)
|
|
275
|
+
accepts_context = "context" in sig.parameters or len(sig.parameters) > 0
|
|
276
|
+
|
|
242
277
|
if inspect.iscoroutinefunction(downgrade_func):
|
|
243
|
-
sql_result =
|
|
278
|
+
sql_result = (
|
|
279
|
+
await downgrade_func(self.context) if accepts_context and self.context else await downgrade_func()
|
|
280
|
+
)
|
|
244
281
|
else:
|
|
245
|
-
sql_result = downgrade_func()
|
|
282
|
+
sql_result = downgrade_func(self.context) if accepts_context and self.context else downgrade_func()
|
|
246
283
|
|
|
247
284
|
return self._normalize_and_validate_sql(sql_result, path)
|
|
248
285
|
|
|
@@ -380,7 +417,11 @@ class PythonFileLoader(BaseMigrationLoader):
|
|
|
380
417
|
|
|
381
418
|
|
|
382
419
|
def get_migration_loader(
|
|
383
|
-
file_path: Path,
|
|
420
|
+
file_path: Path,
|
|
421
|
+
migrations_dir: Path,
|
|
422
|
+
project_root: "Path | None" = None,
|
|
423
|
+
context: "Any | None" = None,
|
|
424
|
+
sql_loader: "CoreSQLFileLoader | None" = None,
|
|
384
425
|
) -> BaseMigrationLoader:
|
|
385
426
|
"""Factory function to get appropriate loader for migration file.
|
|
386
427
|
|
|
@@ -388,6 +429,10 @@ def get_migration_loader(
|
|
|
388
429
|
file_path: Path to the migration file.
|
|
389
430
|
migrations_dir: Directory containing migration files.
|
|
390
431
|
project_root: Optional project root directory for Python imports.
|
|
432
|
+
context: Optional migration context to pass to Python migrations.
|
|
433
|
+
sql_loader: Optional shared SQLFileLoader instance for SQL migrations.
|
|
434
|
+
When provided, SQL files are loaded using this shared instance,
|
|
435
|
+
avoiding redundant file parsing.
|
|
391
436
|
|
|
392
437
|
Returns:
|
|
393
438
|
Appropriate loader instance for the file type.
|
|
@@ -398,8 +443,8 @@ def get_migration_loader(
|
|
|
398
443
|
suffix = file_path.suffix
|
|
399
444
|
|
|
400
445
|
if suffix == ".py":
|
|
401
|
-
return PythonFileLoader(migrations_dir, project_root)
|
|
446
|
+
return PythonFileLoader(migrations_dir, project_root, context)
|
|
402
447
|
if suffix == ".sql":
|
|
403
|
-
return SQLFileLoader()
|
|
448
|
+
return SQLFileLoader(sql_loader)
|
|
404
449
|
msg = f"Unsupported migration file type: {suffix}"
|
|
405
450
|
raise MigrationLoadError(msg)
|