sqlspec 0.24.1__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic; see the package's registry page for more details.
- sqlspec/_serialization.py +223 -21
- sqlspec/_sql.py +20 -62
- sqlspec/_typing.py +11 -0
- sqlspec/adapters/adbc/config.py +8 -1
- sqlspec/adapters/adbc/data_dictionary.py +290 -0
- sqlspec/adapters/adbc/driver.py +129 -20
- sqlspec/adapters/adbc/type_converter.py +159 -0
- sqlspec/adapters/aiosqlite/config.py +3 -0
- sqlspec/adapters/aiosqlite/data_dictionary.py +117 -0
- sqlspec/adapters/aiosqlite/driver.py +17 -3
- sqlspec/adapters/asyncmy/_types.py +1 -1
- sqlspec/adapters/asyncmy/config.py +11 -8
- sqlspec/adapters/asyncmy/data_dictionary.py +122 -0
- sqlspec/adapters/asyncmy/driver.py +31 -7
- sqlspec/adapters/asyncpg/config.py +3 -0
- sqlspec/adapters/asyncpg/data_dictionary.py +134 -0
- sqlspec/adapters/asyncpg/driver.py +19 -4
- sqlspec/adapters/bigquery/config.py +3 -0
- sqlspec/adapters/bigquery/data_dictionary.py +109 -0
- sqlspec/adapters/bigquery/driver.py +21 -3
- sqlspec/adapters/bigquery/type_converter.py +93 -0
- sqlspec/adapters/duckdb/_types.py +1 -1
- sqlspec/adapters/duckdb/config.py +2 -0
- sqlspec/adapters/duckdb/data_dictionary.py +124 -0
- sqlspec/adapters/duckdb/driver.py +32 -5
- sqlspec/adapters/duckdb/pool.py +1 -1
- sqlspec/adapters/duckdb/type_converter.py +103 -0
- sqlspec/adapters/oracledb/config.py +6 -0
- sqlspec/adapters/oracledb/data_dictionary.py +442 -0
- sqlspec/adapters/oracledb/driver.py +68 -9
- sqlspec/adapters/oracledb/migrations.py +51 -67
- sqlspec/adapters/oracledb/type_converter.py +132 -0
- sqlspec/adapters/psqlpy/config.py +3 -0
- sqlspec/adapters/psqlpy/data_dictionary.py +133 -0
- sqlspec/adapters/psqlpy/driver.py +23 -179
- sqlspec/adapters/psqlpy/type_converter.py +73 -0
- sqlspec/adapters/psycopg/config.py +8 -4
- sqlspec/adapters/psycopg/data_dictionary.py +257 -0
- sqlspec/adapters/psycopg/driver.py +40 -5
- sqlspec/adapters/sqlite/config.py +3 -0
- sqlspec/adapters/sqlite/data_dictionary.py +117 -0
- sqlspec/adapters/sqlite/driver.py +18 -3
- sqlspec/adapters/sqlite/pool.py +13 -4
- sqlspec/base.py +3 -4
- sqlspec/builder/_base.py +130 -48
- sqlspec/builder/_column.py +66 -24
- sqlspec/builder/_ddl.py +91 -41
- sqlspec/builder/_insert.py +40 -58
- sqlspec/builder/_parsing_utils.py +127 -12
- sqlspec/builder/_select.py +147 -2
- sqlspec/builder/_update.py +1 -1
- sqlspec/builder/mixins/_cte_and_set_ops.py +31 -23
- sqlspec/builder/mixins/_delete_operations.py +12 -7
- sqlspec/builder/mixins/_insert_operations.py +50 -36
- sqlspec/builder/mixins/_join_operations.py +15 -30
- sqlspec/builder/mixins/_merge_operations.py +210 -78
- sqlspec/builder/mixins/_order_limit_operations.py +4 -10
- sqlspec/builder/mixins/_pivot_operations.py +1 -0
- sqlspec/builder/mixins/_select_operations.py +44 -22
- sqlspec/builder/mixins/_update_operations.py +30 -37
- sqlspec/builder/mixins/_where_clause.py +52 -70
- sqlspec/cli.py +246 -140
- sqlspec/config.py +33 -19
- sqlspec/core/__init__.py +3 -2
- sqlspec/core/cache.py +298 -352
- sqlspec/core/compiler.py +61 -4
- sqlspec/core/filters.py +246 -213
- sqlspec/core/hashing.py +9 -11
- sqlspec/core/parameters.py +27 -10
- sqlspec/core/statement.py +72 -12
- sqlspec/core/type_conversion.py +234 -0
- sqlspec/driver/__init__.py +6 -3
- sqlspec/driver/_async.py +108 -5
- sqlspec/driver/_common.py +186 -17
- sqlspec/driver/_sync.py +108 -5
- sqlspec/driver/mixins/_result_tools.py +60 -7
- sqlspec/exceptions.py +5 -0
- sqlspec/loader.py +8 -9
- sqlspec/migrations/__init__.py +4 -3
- sqlspec/migrations/base.py +153 -14
- sqlspec/migrations/commands.py +34 -96
- sqlspec/migrations/context.py +145 -0
- sqlspec/migrations/loaders.py +25 -8
- sqlspec/migrations/runner.py +352 -82
- sqlspec/storage/backends/fsspec.py +1 -0
- sqlspec/typing.py +4 -0
- sqlspec/utils/config_resolver.py +153 -0
- sqlspec/utils/serializers.py +50 -2
- {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/METADATA +1 -1
- sqlspec-0.26.0.dist-info/RECORD +157 -0
- sqlspec-0.24.1.dist-info/RECORD +0 -139
- {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/migrations/base.py
CHANGED
|
@@ -15,6 +15,7 @@ from sqlspec.builder._ddl import CreateTable
|
|
|
15
15
|
from sqlspec.loader import SQLFileLoader
|
|
16
16
|
from sqlspec.migrations.loaders import get_migration_loader
|
|
17
17
|
from sqlspec.utils.logging import get_logger
|
|
18
|
+
from sqlspec.utils.module_loader import module_to_os_path
|
|
18
19
|
from sqlspec.utils.sync_tools import await_
|
|
19
20
|
|
|
20
21
|
__all__ = ("BaseMigrationCommands", "BaseMigrationRunner", "BaseMigrationTracker")
|
|
@@ -135,15 +136,29 @@ class BaseMigrationTracker(ABC, Generic[DriverT]):
|
|
|
135
136
|
class BaseMigrationRunner(ABC, Generic[DriverT]):
|
|
136
137
|
"""Base class for migration execution."""
|
|
137
138
|
|
|
138
|
-
|
|
139
|
+
extension_configs: "dict[str, dict[str, Any]]"
|
|
140
|
+
|
|
141
|
+
def __init__(
    self,
    migrations_path: Path,
    extension_migrations: "Optional[dict[str, Path]]" = None,
    context: "Optional[Any]" = None,
    extension_configs: "Optional[dict[str, dict[str, Any]]]" = None,
) -> None:
    """Prepare the runner's migration sources and per-run state.

    Args:
        migrations_path: Path to the directory containing migration files.
        extension_migrations: Optional mapping of extension names to their
            migration paths; a falsy value becomes an empty dict.
        context: Optional migration context for Python migrations.
        extension_configs: Optional mapping of extension names to their
            configurations; a falsy value becomes an empty dict.
    """
    # Project-owned migrations first, then any extension-provided directories.
    self.migrations_path = migrations_path
    self.extension_migrations = extension_migrations if extension_migrations else {}

    self.loader = SQLFileLoader()
    # No project root is assigned at construction time.
    self.project_root: Optional[Path] = None

    self.context = context
    self.extension_configs = extension_configs if extension_configs else {}
|
|
147
162
|
|
|
148
163
|
def _extract_version(self, filename: str) -> Optional[str]:
|
|
149
164
|
"""Extract version from filename.
|
|
@@ -154,6 +169,12 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
|
|
|
154
169
|
Returns:
|
|
155
170
|
The extracted version string or None.
|
|
156
171
|
"""
|
|
172
|
+
# Handle extension-prefixed versions (e.g., "ext_litestar_0001")
|
|
173
|
+
if filename.startswith("ext_"):
|
|
174
|
+
# This is already a prefixed version, return as-is
|
|
175
|
+
return filename
|
|
176
|
+
|
|
177
|
+
# Regular version extraction
|
|
157
178
|
parts = filename.split("_", 1)
|
|
158
179
|
return parts[0].zfill(4) if parts and parts[0].isdigit() else None
|
|
159
180
|
|
|
@@ -175,17 +196,31 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
|
|
|
175
196
|
Returns:
|
|
176
197
|
List of tuples containing (version, file_path).
|
|
177
198
|
"""
|
|
178
|
-
if not self.migrations_path.exists():
|
|
179
|
-
return []
|
|
180
|
-
|
|
181
199
|
migrations = []
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
200
|
+
|
|
201
|
+
# Scan primary migration path
|
|
202
|
+
if self.migrations_path.exists():
|
|
203
|
+
for pattern in ("*.sql", "*.py"):
|
|
204
|
+
for file_path in self.migrations_path.glob(pattern):
|
|
205
|
+
if file_path.name.startswith("."):
|
|
206
|
+
continue
|
|
207
|
+
version = self._extract_version(file_path.name)
|
|
208
|
+
if version:
|
|
209
|
+
migrations.append((version, file_path))
|
|
210
|
+
|
|
211
|
+
# Scan extension migration paths
|
|
212
|
+
for ext_name, ext_path in self.extension_migrations.items():
|
|
213
|
+
if ext_path.exists():
|
|
214
|
+
for pattern in ("*.sql", "*.py"):
|
|
215
|
+
for file_path in ext_path.glob(pattern):
|
|
216
|
+
if file_path.name.startswith("."):
|
|
217
|
+
continue
|
|
218
|
+
# Prefix extension migrations to avoid version conflicts
|
|
219
|
+
version = self._extract_version(file_path.name)
|
|
220
|
+
if version:
|
|
221
|
+
# Use ext_ prefix to distinguish extension migrations
|
|
222
|
+
prefixed_version = f"ext_{ext_name}_{version}"
|
|
223
|
+
migrations.append((prefixed_version, file_path))
|
|
189
224
|
|
|
190
225
|
return sorted(migrations, key=operator.itemgetter(0))
|
|
191
226
|
|
|
@@ -199,7 +234,45 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
|
|
|
199
234
|
Migration metadata dictionary.
|
|
200
235
|
"""
|
|
201
236
|
|
|
202
|
-
|
|
237
|
+
# Check if this is an extension migration and update context accordingly
|
|
238
|
+
context_to_use = self.context
|
|
239
|
+
if context_to_use and file_path.name.startswith("ext_"):
|
|
240
|
+
# Try to extract extension name from the version
|
|
241
|
+
version = self._extract_version(file_path.name)
|
|
242
|
+
if version and version.startswith("ext_"):
|
|
243
|
+
# Parse extension name from version like "ext_litestar_0001"
|
|
244
|
+
min_extension_version_parts = 3
|
|
245
|
+
parts = version.split("_", 2)
|
|
246
|
+
if len(parts) >= min_extension_version_parts:
|
|
247
|
+
ext_name = parts[1]
|
|
248
|
+
if ext_name in self.extension_configs:
|
|
249
|
+
# Create a new context with the extension config
|
|
250
|
+
from sqlspec.migrations.context import MigrationContext
|
|
251
|
+
|
|
252
|
+
context_to_use = MigrationContext(
|
|
253
|
+
dialect=self.context.dialect if self.context else None,
|
|
254
|
+
config=self.context.config if self.context else None,
|
|
255
|
+
driver=self.context.driver if self.context else None,
|
|
256
|
+
metadata=self.context.metadata.copy() if self.context and self.context.metadata else {},
|
|
257
|
+
extension_config=self.extension_configs[ext_name],
|
|
258
|
+
)
|
|
259
|
+
|
|
260
|
+
# For extension migrations, check by path
|
|
261
|
+
for ext_name, ext_path in self.extension_migrations.items():
|
|
262
|
+
if file_path.parent == ext_path:
|
|
263
|
+
if ext_name in self.extension_configs and self.context:
|
|
264
|
+
from sqlspec.migrations.context import MigrationContext
|
|
265
|
+
|
|
266
|
+
context_to_use = MigrationContext(
|
|
267
|
+
dialect=self.context.dialect,
|
|
268
|
+
config=self.context.config,
|
|
269
|
+
driver=self.context.driver,
|
|
270
|
+
metadata=self.context.metadata.copy() if self.context.metadata else {},
|
|
271
|
+
extension_config=self.extension_configs[ext_name],
|
|
272
|
+
)
|
|
273
|
+
break
|
|
274
|
+
|
|
275
|
+
loader = get_migration_loader(file_path, self.migrations_path, self.project_root, context_to_use)
|
|
203
276
|
loader.validate_migration_file(file_path)
|
|
204
277
|
content = file_path.read_text(encoding="utf-8")
|
|
205
278
|
checksum = self._calculate_checksum(content)
|
|
@@ -292,6 +365,8 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
|
|
|
292
365
|
class BaseMigrationCommands(ABC, Generic[ConfigT, DriverT]):
|
|
293
366
|
"""Base class for migration commands."""
|
|
294
367
|
|
|
368
|
+
extension_configs: "dict[str, dict[str, Any]]"
|
|
369
|
+
|
|
295
370
|
def __init__(self, config: ConfigT) -> None:
|
|
296
371
|
"""Initialize migration commands.
|
|
297
372
|
|
|
@@ -304,6 +379,72 @@ class BaseMigrationCommands(ABC, Generic[ConfigT, DriverT]):
|
|
|
304
379
|
self.version_table = migration_config.get("version_table_name", "ddl_migrations")
|
|
305
380
|
self.migrations_path = Path(migration_config.get("script_location", "migrations"))
|
|
306
381
|
self.project_root = Path(migration_config["project_root"]) if "project_root" in migration_config else None
|
|
382
|
+
self.include_extensions = migration_config.get("include_extensions", [])
|
|
383
|
+
self.extension_configs = self._parse_extension_configs()
|
|
384
|
+
|
|
385
|
+
def _parse_extension_configs(self) -> "dict[str, dict[str, Any]]":
    """Normalize ``include_extensions`` into an extension-name -> options mapping.

    Each entry may be either a bare extension name (``"litestar"``) or a dict
    carrying a ``"name"`` key plus extension-specific options, e.g.
    ``{"name": "litestar", "session_table": "custom_sessions"}``. Entries of any
    other type, and dicts whose ``"name"`` is missing/empty or not a string,
    are logged and skipped. The ``litestar`` extension receives
    ``session_table="litestar_sessions"`` unless one is supplied. A later entry
    with the same name replaces an earlier one.

    Returns:
        Dictionary mapping extension names to their configurations.
    """
    configs: "dict[str, dict[str, Any]]" = {}

    # ``include_extensions`` may be explicitly set to None in migration_config.
    for ext_config in self.include_extensions or []:
        if isinstance(ext_config, str):
            # Simple string format: just the extension name.
            ext_name = ext_config
            ext_options = {}
        elif isinstance(ext_config, dict):
            # Dict format: {"name": "litestar", "session_table": "custom_sessions"}
            ext_name_raw = ext_config.get("name")
            if not ext_name_raw:
                logger.warning("Extension configuration missing 'name' field: %s", ext_config)
                continue
            # Validate explicitly instead of via ``assert``, which ``python -O`` strips.
            if not isinstance(ext_name_raw, str):
                logger.warning("Extension configuration 'name' must be a string: %s", ext_config)
                continue
            ext_name = ext_name_raw
            ext_options = {k: v for k, v in ext_config.items() if k != "name"}
        else:
            logger.warning("Invalid extension configuration format: %s", ext_config)
            continue

        # Apply default configurations for known extensions.
        if ext_name == "litestar":
            ext_options.setdefault("session_table", "litestar_sessions")

        configs[ext_name] = ext_options

    return configs
|
|
422
|
+
|
|
423
|
+
def _discover_extension_migrations(self) -> "dict[str, Path]":
    """Locate the bundled ``migrations`` directory of each configured extension.

    Each key of ``self.extension_configs`` is resolved to the module
    ``sqlspec.extensions.<name>`` via ``module_to_os_path``. If that module's
    directory contains a ``migrations`` sub-directory, it is recorded.
    Extensions whose module lookup fails with ``TypeError``, or that have no
    ``migrations`` directory, are logged and skipped.

    Returns:
        Dictionary mapping extension names to their migration paths.
    """
    extension_migrations: "dict[str, Path]" = {}

    for ext_name in self.extension_configs:
        # Every extension, litestar included, lives under sqlspec.extensions.
        module_name = f"sqlspec.extensions.{ext_name}"

        # Only the module lookup is expected to raise; keep the try narrow.
        try:
            module_path = module_to_os_path(module_name)
        except TypeError:
            logger.warning("Extension %s not found", ext_name)
            continue

        migrations_dir = module_path / "migrations"
        # is_dir() (not exists()) so a stray file named "migrations" is not accepted.
        if migrations_dir.is_dir():
            extension_migrations[ext_name] = migrations_dir
            logger.debug("Found migrations for extension %s at %s", ext_name, migrations_dir)
        else:
            logger.warning("No migrations directory found for extension %s", ext_name)

    return extension_migrations
|
|
307
448
|
|
|
308
449
|
def _get_init_readme_content(self) -> str:
|
|
309
450
|
"""Get README content for migration directory initialization.
|
|
@@ -368,8 +509,6 @@ This naming ensures proper sorting and avoids conflicts when loading multiple fi
|
|
|
368
509
|
readme = migrations_dir / "README.md"
|
|
369
510
|
readme.write_text(self._get_init_readme_content())
|
|
370
511
|
|
|
371
|
-
(migrations_dir / ".gitkeep").touch()
|
|
372
|
-
|
|
373
512
|
console.print(f"[green]Initialized migrations in {directory}[/]")
|
|
374
513
|
|
|
375
514
|
@abstractmethod
|
sqlspec/migrations/commands.py
CHANGED
|
@@ -10,15 +10,15 @@ from rich.table import Table
|
|
|
10
10
|
|
|
11
11
|
from sqlspec._sql import sql
|
|
12
12
|
from sqlspec.migrations.base import BaseMigrationCommands
|
|
13
|
+
from sqlspec.migrations.context import MigrationContext
|
|
13
14
|
from sqlspec.migrations.runner import AsyncMigrationRunner, SyncMigrationRunner
|
|
14
15
|
from sqlspec.migrations.utils import create_migration_file
|
|
15
16
|
from sqlspec.utils.logging import get_logger
|
|
16
|
-
from sqlspec.utils.sync_tools import await_
|
|
17
17
|
|
|
18
18
|
if TYPE_CHECKING:
|
|
19
19
|
from sqlspec.config import AsyncConfigT, SyncConfigT
|
|
20
20
|
|
|
21
|
-
__all__ = ("AsyncMigrationCommands", "
|
|
21
|
+
__all__ = ("AsyncMigrationCommands", "SyncMigrationCommands", "create_migration_commands")
|
|
22
22
|
|
|
23
23
|
logger = get_logger("migrations.commands")
|
|
24
24
|
console = Console()
|
|
@@ -35,7 +35,14 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
|
|
|
35
35
|
"""
|
|
36
36
|
super().__init__(config)
|
|
37
37
|
self.tracker = config.migration_tracker_type(self.version_table)
|
|
38
|
-
|
|
38
|
+
|
|
39
|
+
# Create context with extension configurations
|
|
40
|
+
context = MigrationContext.from_config(config)
|
|
41
|
+
context.extension_config = self.extension_configs
|
|
42
|
+
|
|
43
|
+
self.runner = SyncMigrationRunner(
|
|
44
|
+
self.migrations_path, self._discover_extension_migrations(), context, self.extension_configs
|
|
45
|
+
)
|
|
39
46
|
|
|
40
47
|
def init(self, directory: str, package: bool = True) -> None:
|
|
41
48
|
"""Initialize migration directory structure.
|
|
@@ -203,15 +210,22 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
|
|
|
203
210
|
class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
|
|
204
211
|
"""Asynchronous migration commands."""
|
|
205
212
|
|
|
206
|
-
def __init__(self,
|
|
213
|
+
def __init__(self, config: "AsyncConfigT") -> None:
|
|
207
214
|
"""Initialize migration commands.
|
|
208
215
|
|
|
209
216
|
Args:
|
|
210
|
-
|
|
217
|
+
config: The SQLSpec configuration.
|
|
211
218
|
"""
|
|
212
|
-
super().__init__(
|
|
213
|
-
self.tracker =
|
|
214
|
-
|
|
219
|
+
super().__init__(config)
|
|
220
|
+
self.tracker = config.migration_tracker_type(self.version_table)
|
|
221
|
+
|
|
222
|
+
# Create context with extension configurations
|
|
223
|
+
context = MigrationContext.from_config(config)
|
|
224
|
+
context.extension_config = self.extension_configs
|
|
225
|
+
|
|
226
|
+
self.runner = AsyncMigrationRunner(
|
|
227
|
+
self.migrations_path, self._discover_extension_migrations(), context, self.extension_configs
|
|
228
|
+
)
|
|
215
229
|
|
|
216
230
|
async def init(self, directory: str, package: bool = True) -> None:
|
|
217
231
|
"""Initialize migration directory structure.
|
|
@@ -370,93 +384,17 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
|
|
|
370
384
|
console.print(f"[green]Created migration:[/] {file_path}")
|
|
371
385
|
|
|
372
386
|
|
|
373
|
-
|
|
374
|
-
"
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
"""Initialize migration commands with sync/async implementation.
|
|
378
|
-
|
|
379
|
-
Args:
|
|
380
|
-
config: The SQLSpec configuration.
|
|
381
|
-
"""
|
|
382
|
-
if config.is_async:
|
|
383
|
-
self._impl: Union[AsyncMigrationCommands[Any], SyncMigrationCommands[Any]] = AsyncMigrationCommands(
|
|
384
|
-
cast("AsyncConfigT", config)
|
|
385
|
-
)
|
|
386
|
-
else:
|
|
387
|
-
self._impl = SyncMigrationCommands(cast("SyncConfigT", config))
|
|
388
|
-
self._is_async = config.is_async
|
|
389
|
-
|
|
390
|
-
def init(self, directory: str, package: bool = True) -> None:
|
|
391
|
-
"""Initialize migration directory structure.
|
|
387
|
+
def create_migration_commands(
|
|
388
|
+
config: "Union[SyncConfigT, AsyncConfigT]",
|
|
389
|
+
) -> "Union[SyncMigrationCommands[Any], AsyncMigrationCommands[Any]]":
|
|
390
|
+
"""Factory function to create the appropriate migration commands.
|
|
392
391
|
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
package: Whether to create __init__.py file.
|
|
396
|
-
"""
|
|
397
|
-
if self._is_async:
|
|
398
|
-
await_(cast("AsyncMigrationCommands[Any]", self._impl).init, raise_sync_error=False)(
|
|
399
|
-
directory, package=package
|
|
400
|
-
)
|
|
401
|
-
else:
|
|
402
|
-
cast("SyncMigrationCommands[Any]", self._impl).init(directory, package=package)
|
|
403
|
-
|
|
404
|
-
def current(self, verbose: bool = False) -> "Optional[str]":
|
|
405
|
-
"""Show current migration version.
|
|
406
|
-
|
|
407
|
-
Args:
|
|
408
|
-
verbose: Whether to show detailed migration history.
|
|
409
|
-
|
|
410
|
-
Returns:
|
|
411
|
-
The current migration version or None if no migrations applied.
|
|
412
|
-
"""
|
|
413
|
-
if self._is_async:
|
|
414
|
-
return await_(cast("AsyncMigrationCommands[Any]", self._impl).current, raise_sync_error=False)(
|
|
415
|
-
verbose=verbose
|
|
416
|
-
)
|
|
417
|
-
return cast("SyncMigrationCommands[Any]", self._impl).current(verbose=verbose)
|
|
392
|
+
Args:
|
|
393
|
+
config: The SQLSpec configuration.
|
|
418
394
|
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
if self._is_async:
|
|
426
|
-
await_(cast("AsyncMigrationCommands[Any]", self._impl).upgrade, raise_sync_error=False)(revision=revision)
|
|
427
|
-
else:
|
|
428
|
-
cast("SyncMigrationCommands[Any]", self._impl).upgrade(revision=revision)
|
|
429
|
-
|
|
430
|
-
def downgrade(self, revision: str = "-1") -> None:
|
|
431
|
-
"""Downgrade to a target revision.
|
|
432
|
-
|
|
433
|
-
Args:
|
|
434
|
-
revision: Target revision or "-1" for one step back.
|
|
435
|
-
"""
|
|
436
|
-
if self._is_async:
|
|
437
|
-
await_(cast("AsyncMigrationCommands[Any]", self._impl).downgrade, raise_sync_error=False)(revision=revision)
|
|
438
|
-
else:
|
|
439
|
-
cast("SyncMigrationCommands[Any]", self._impl).downgrade(revision=revision)
|
|
440
|
-
|
|
441
|
-
def stamp(self, revision: str) -> None:
|
|
442
|
-
"""Mark database as being at a specific revision without running migrations.
|
|
443
|
-
|
|
444
|
-
Args:
|
|
445
|
-
revision: The revision to stamp.
|
|
446
|
-
"""
|
|
447
|
-
if self._is_async:
|
|
448
|
-
await_(cast("AsyncMigrationCommands[Any]", self._impl).stamp, raise_sync_error=False)(revision)
|
|
449
|
-
else:
|
|
450
|
-
cast("SyncMigrationCommands[Any]", self._impl).stamp(revision)
|
|
451
|
-
|
|
452
|
-
def revision(self, message: str, file_type: str = "sql") -> None:
|
|
453
|
-
"""Create a new migration file.
|
|
454
|
-
|
|
455
|
-
Args:
|
|
456
|
-
message: Description for the migration.
|
|
457
|
-
file_type: Type of migration file to create ('sql' or 'py').
|
|
458
|
-
"""
|
|
459
|
-
if self._is_async:
|
|
460
|
-
await_(cast("AsyncMigrationCommands[Any]", self._impl).revision, raise_sync_error=False)(message, file_type)
|
|
461
|
-
else:
|
|
462
|
-
cast("SyncMigrationCommands[Any]", self._impl).revision(message, file_type)
|
|
395
|
+
Returns:
|
|
396
|
+
Appropriate migration commands instance.
|
|
397
|
+
"""
|
|
398
|
+
if config.is_async:
|
|
399
|
+
return AsyncMigrationCommands(cast("AsyncConfigT", config))
|
|
400
|
+
return SyncMigrationCommands(cast("SyncConfigT", config))
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""Migration context for passing runtime information to migrations."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import inspect
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Optional, Union
|
|
7
|
+
|
|
8
|
+
from sqlspec.utils.logging import get_logger
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
|
|
12
|
+
|
|
13
|
+
logger = get_logger("migrations.context")
|
|
14
|
+
|
|
15
|
+
__all__ = ("MigrationContext",)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class MigrationContext:
    """Context object passed to migration functions.

    Bundles the database configuration, dialect, optional driver, free-form
    metadata and extension options so migration functions can generate
    dialect-specific SQL.
    """

    # Database configuration object.
    config: "Optional[Any]" = None
    # Database dialect (e.g. 'postgres', 'mysql', 'sqlite').
    dialect: "Optional[str]" = None
    # Additional metadata for the migration; None becomes a fresh empty dict.
    metadata: "Optional[dict[str, Any]]" = None
    # Extension-specific configuration options; None becomes a fresh empty dict.
    extension_config: "Optional[dict[str, Any]]" = None

    # Database driver instance (available during execution).
    driver: "Optional[Union[SyncDriverAdapterBase, AsyncDriverAdapterBase]]" = None

    # Internal execution metadata, written via set_execution_metadata().
    _execution_metadata: "dict[str, Any]" = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Default ``metadata`` and ``extension_config`` to empty dicts when omitted.

        Only ``None`` is replaced. A caller-supplied dict, even an empty one,
        is kept by reference so writes through the context remain visible to
        the caller.
        """
        if self.metadata is None:
            self.metadata = {}
        if self.extension_config is None:
            self.extension_config = {}

    @classmethod
    def from_config(cls, config: Any) -> "MigrationContext":
        """Create context from database configuration.

        The dialect is read from ``config.statement_config.dialect`` when a
        truthy ``statement_config`` exists. Otherwise, if the config has a
        callable ``_create_statement_config``, its result's ``dialect`` is
        used. Any exception during this probing is logged at debug level and
        the dialect is left as ``None``.

        Args:
            config: Database configuration object.

        Returns:
            Migration context with dialect information.
        """
        dialect = None
        try:
            if hasattr(config, "statement_config") and config.statement_config:
                dialect = getattr(config.statement_config, "dialect", None)
            elif hasattr(config, "_create_statement_config") and callable(config._create_statement_config):
                stmt_config = config._create_statement_config()
                dialect = getattr(stmt_config, "dialect", None)
        except Exception:
            # Deliberate best-effort: a context without a dialect is still usable.
            logger.debug("Unable to extract dialect from config")

        return cls(dialect=dialect, config=config)

    @property
    def is_async_execution(self) -> bool:
        """Check whether an asyncio event loop is running in the current thread.

        ``asyncio.current_task()`` raises ``RuntimeError`` when no loop is
        running, which is mapped to ``False``.

        Returns:
            True if executing in an async context.
        """
        try:
            asyncio.current_task()
        except RuntimeError:
            return False
        else:
            return True

    @property
    def is_async_driver(self) -> bool:
        """Check if the current driver is async.

        Returns:
            True if a driver is set and its ``execute_script`` is a coroutine function.
        """
        if self.driver is None:
            return False

        execute_method = getattr(self.driver, "execute_script", None)
        return execute_method is not None and inspect.iscoroutinefunction(execute_method)

    @property
    def execution_mode(self) -> str:
        """Get the current execution mode.

        Returns:
            'async' if in async context, 'sync' otherwise.
        """
        return "async" if self.is_async_execution else "sync"

    def set_execution_metadata(self, key: str, value: Any) -> None:
        """Set execution metadata for tracking migration state.

        Args:
            key: Metadata key.
            value: Metadata value.
        """
        self._execution_metadata[key] = value

    def get_execution_metadata(self, key: str, default: Any = None) -> Any:
        """Get execution metadata.

        Args:
            key: Metadata key.
            default: Default value if key not found.

        Returns:
            Metadata value or default.
        """
        return self._execution_metadata.get(key, default)

    def validate_async_usage(self, migration_func: Any) -> None:
        """Check a migration function against the current execution context.

        This method never raises; mismatches are only reported:

        * an ``async def`` migration with neither a running event loop nor an
          async driver logs a warning;
        * a sync migration combined with an async driver records
          ``mixed_execution=True`` in the execution metadata and logs at
          debug level.

        Args:
            migration_func: The migration function to validate.
        """
        is_coroutine = inspect.iscoroutinefunction(migration_func)

        if is_coroutine and not self.is_async_execution and not self.is_async_driver:
            msg = (
                "Async migration function detected but execution context is sync. "
                "Consider using async database configuration or sync migration functions."
            )
            logger.warning(msg)

        if not is_coroutine and self.is_async_driver:
            self.set_execution_metadata("mixed_execution", value=True)
            logger.debug("Sync migration function in async driver context - using compatibility mode")
|
sqlspec/migrations/loaders.py
CHANGED
|
@@ -164,17 +164,21 @@ class SQLFileLoader(BaseMigrationLoader):
|
|
|
164
164
|
class PythonFileLoader(BaseMigrationLoader):
|
|
165
165
|
"""Loader for Python migration files."""
|
|
166
166
|
|
|
167
|
-
__slots__ = ("migrations_dir", "project_root")
|
|
167
|
+
__slots__ = ("context", "migrations_dir", "project_root")
|
|
168
168
|
|
|
169
|
-
def __init__(
|
|
169
|
+
def __init__(
    self,
    migrations_dir: Path,
    project_root: "Optional[Path]" = None,
    context: "Optional[Any]" = None,
) -> None:
    """Create a loader for Python migration modules.

    Args:
        migrations_dir: Directory containing migration files.
        project_root: Optional project root directory for imports; when
            omitted it is derived with ``self._find_project_root(migrations_dir)``.
        context: Optional migration context to pass to functions.
    """
    self.migrations_dir = migrations_dir
    if project_root is None:
        project_root = self._find_project_root(migrations_dir)
    self.project_root = project_root
    self.context = context
|
|
178
182
|
|
|
179
183
|
async def get_up_sql(self, path: Path) -> list[str]:
|
|
180
184
|
"""Load Python migration and execute upgrade function.
|
|
@@ -208,10 +212,16 @@ class PythonFileLoader(BaseMigrationLoader):
|
|
|
208
212
|
msg = f"'{func_name}' is not callable in {path}"
|
|
209
213
|
raise MigrationLoadError(msg)
|
|
210
214
|
|
|
215
|
+
# Check if function accepts context parameter
|
|
216
|
+
sig = inspect.signature(upgrade_func)
|
|
217
|
+
accepts_context = "context" in sig.parameters or len(sig.parameters) > 0
|
|
218
|
+
|
|
211
219
|
if inspect.iscoroutinefunction(upgrade_func):
|
|
212
|
-
sql_result =
|
|
220
|
+
sql_result = (
|
|
221
|
+
await upgrade_func(self.context) if accepts_context and self.context else await upgrade_func()
|
|
222
|
+
)
|
|
213
223
|
else:
|
|
214
|
-
sql_result = upgrade_func()
|
|
224
|
+
sql_result = upgrade_func(self.context) if accepts_context and self.context else upgrade_func()
|
|
215
225
|
|
|
216
226
|
return self._normalize_and_validate_sql(sql_result, path)
|
|
217
227
|
|
|
@@ -239,10 +249,16 @@ class PythonFileLoader(BaseMigrationLoader):
|
|
|
239
249
|
if not callable(downgrade_func):
|
|
240
250
|
return []
|
|
241
251
|
|
|
252
|
+
# Check if function accepts context parameter
|
|
253
|
+
sig = inspect.signature(downgrade_func)
|
|
254
|
+
accepts_context = "context" in sig.parameters or len(sig.parameters) > 0
|
|
255
|
+
|
|
242
256
|
if inspect.iscoroutinefunction(downgrade_func):
|
|
243
|
-
sql_result =
|
|
257
|
+
sql_result = (
|
|
258
|
+
await downgrade_func(self.context) if accepts_context and self.context else await downgrade_func()
|
|
259
|
+
)
|
|
244
260
|
else:
|
|
245
|
-
sql_result = downgrade_func()
|
|
261
|
+
sql_result = downgrade_func(self.context) if accepts_context and self.context else downgrade_func()
|
|
246
262
|
|
|
247
263
|
return self._normalize_and_validate_sql(sql_result, path)
|
|
248
264
|
|
|
@@ -380,7 +396,7 @@ class PythonFileLoader(BaseMigrationLoader):
|
|
|
380
396
|
|
|
381
397
|
|
|
382
398
|
def get_migration_loader(
|
|
383
|
-
file_path: Path, migrations_dir: Path, project_root: "Optional[Path]" = None
|
|
399
|
+
file_path: Path, migrations_dir: Path, project_root: "Optional[Path]" = None, context: "Optional[Any]" = None
|
|
384
400
|
) -> BaseMigrationLoader:
|
|
385
401
|
"""Factory function to get appropriate loader for migration file.
|
|
386
402
|
|
|
@@ -388,6 +404,7 @@ def get_migration_loader(
|
|
|
388
404
|
file_path: Path to the migration file.
|
|
389
405
|
migrations_dir: Directory containing migration files.
|
|
390
406
|
project_root: Optional project root directory for Python imports.
|
|
407
|
+
context: Optional migration context to pass to Python migrations.
|
|
391
408
|
|
|
392
409
|
Returns:
|
|
393
410
|
Appropriate loader instance for the file type.
|
|
@@ -398,7 +415,7 @@ def get_migration_loader(
|
|
|
398
415
|
suffix = file_path.suffix
|
|
399
416
|
|
|
400
417
|
if suffix == ".py":
|
|
401
|
-
return PythonFileLoader(migrations_dir, project_root)
|
|
418
|
+
return PythonFileLoader(migrations_dir, project_root, context)
|
|
402
419
|
if suffix == ".sql":
|
|
403
420
|
return SQLFileLoader()
|
|
404
421
|
msg = f"Unsupported migration file type: {suffix}"
|