sqlspec 0.24.1__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_serialization.py +223 -21
- sqlspec/_sql.py +20 -62
- sqlspec/_typing.py +11 -0
- sqlspec/adapters/adbc/config.py +8 -1
- sqlspec/adapters/adbc/data_dictionary.py +290 -0
- sqlspec/adapters/adbc/driver.py +129 -20
- sqlspec/adapters/adbc/type_converter.py +159 -0
- sqlspec/adapters/aiosqlite/config.py +3 -0
- sqlspec/adapters/aiosqlite/data_dictionary.py +117 -0
- sqlspec/adapters/aiosqlite/driver.py +17 -3
- sqlspec/adapters/asyncmy/_types.py +1 -1
- sqlspec/adapters/asyncmy/config.py +11 -8
- sqlspec/adapters/asyncmy/data_dictionary.py +122 -0
- sqlspec/adapters/asyncmy/driver.py +31 -7
- sqlspec/adapters/asyncpg/config.py +3 -0
- sqlspec/adapters/asyncpg/data_dictionary.py +134 -0
- sqlspec/adapters/asyncpg/driver.py +19 -4
- sqlspec/adapters/bigquery/config.py +3 -0
- sqlspec/adapters/bigquery/data_dictionary.py +109 -0
- sqlspec/adapters/bigquery/driver.py +21 -3
- sqlspec/adapters/bigquery/type_converter.py +93 -0
- sqlspec/adapters/duckdb/_types.py +1 -1
- sqlspec/adapters/duckdb/config.py +2 -0
- sqlspec/adapters/duckdb/data_dictionary.py +124 -0
- sqlspec/adapters/duckdb/driver.py +32 -5
- sqlspec/adapters/duckdb/pool.py +1 -1
- sqlspec/adapters/duckdb/type_converter.py +103 -0
- sqlspec/adapters/oracledb/config.py +6 -0
- sqlspec/adapters/oracledb/data_dictionary.py +442 -0
- sqlspec/adapters/oracledb/driver.py +68 -9
- sqlspec/adapters/oracledb/migrations.py +51 -67
- sqlspec/adapters/oracledb/type_converter.py +132 -0
- sqlspec/adapters/psqlpy/config.py +3 -0
- sqlspec/adapters/psqlpy/data_dictionary.py +133 -0
- sqlspec/adapters/psqlpy/driver.py +23 -179
- sqlspec/adapters/psqlpy/type_converter.py +73 -0
- sqlspec/adapters/psycopg/config.py +8 -4
- sqlspec/adapters/psycopg/data_dictionary.py +257 -0
- sqlspec/adapters/psycopg/driver.py +40 -5
- sqlspec/adapters/sqlite/config.py +3 -0
- sqlspec/adapters/sqlite/data_dictionary.py +117 -0
- sqlspec/adapters/sqlite/driver.py +18 -3
- sqlspec/adapters/sqlite/pool.py +13 -4
- sqlspec/base.py +3 -4
- sqlspec/builder/_base.py +130 -48
- sqlspec/builder/_column.py +66 -24
- sqlspec/builder/_ddl.py +91 -41
- sqlspec/builder/_insert.py +40 -58
- sqlspec/builder/_parsing_utils.py +127 -12
- sqlspec/builder/_select.py +147 -2
- sqlspec/builder/_update.py +1 -1
- sqlspec/builder/mixins/_cte_and_set_ops.py +31 -23
- sqlspec/builder/mixins/_delete_operations.py +12 -7
- sqlspec/builder/mixins/_insert_operations.py +50 -36
- sqlspec/builder/mixins/_join_operations.py +15 -30
- sqlspec/builder/mixins/_merge_operations.py +210 -78
- sqlspec/builder/mixins/_order_limit_operations.py +4 -10
- sqlspec/builder/mixins/_pivot_operations.py +1 -0
- sqlspec/builder/mixins/_select_operations.py +44 -22
- sqlspec/builder/mixins/_update_operations.py +30 -37
- sqlspec/builder/mixins/_where_clause.py +52 -70
- sqlspec/cli.py +246 -140
- sqlspec/config.py +33 -19
- sqlspec/core/__init__.py +3 -2
- sqlspec/core/cache.py +298 -352
- sqlspec/core/compiler.py +61 -4
- sqlspec/core/filters.py +246 -213
- sqlspec/core/hashing.py +9 -11
- sqlspec/core/parameters.py +27 -10
- sqlspec/core/statement.py +72 -12
- sqlspec/core/type_conversion.py +234 -0
- sqlspec/driver/__init__.py +6 -3
- sqlspec/driver/_async.py +108 -5
- sqlspec/driver/_common.py +186 -17
- sqlspec/driver/_sync.py +108 -5
- sqlspec/driver/mixins/_result_tools.py +60 -7
- sqlspec/exceptions.py +5 -0
- sqlspec/loader.py +8 -9
- sqlspec/migrations/__init__.py +4 -3
- sqlspec/migrations/base.py +153 -14
- sqlspec/migrations/commands.py +34 -96
- sqlspec/migrations/context.py +145 -0
- sqlspec/migrations/loaders.py +25 -8
- sqlspec/migrations/runner.py +352 -82
- sqlspec/storage/backends/fsspec.py +1 -0
- sqlspec/typing.py +4 -0
- sqlspec/utils/config_resolver.py +153 -0
- sqlspec/utils/serializers.py +50 -2
- {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/METADATA +1 -1
- sqlspec-0.26.0.dist-info/RECORD +157 -0
- sqlspec-0.24.1.dist-info/RECORD +0 -139
- {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/migrations/runner.py
CHANGED
|
@@ -1,28 +1,135 @@
|
|
|
1
1
|
"""Migration execution engine for SQLSpec.
|
|
2
2
|
|
|
3
|
-
This module
|
|
3
|
+
This module provides separate sync and async migration runners with clean separation
|
|
4
|
+
of concerns and proper type safety.
|
|
4
5
|
"""
|
|
5
6
|
|
|
7
|
+
import operator
|
|
6
8
|
import time
|
|
9
|
+
from abc import ABC, abstractmethod
|
|
7
10
|
from pathlib import Path
|
|
8
|
-
from typing import TYPE_CHECKING, Any, Optional, cast
|
|
11
|
+
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast, overload
|
|
9
12
|
|
|
10
13
|
from sqlspec.core.statement import SQL
|
|
11
|
-
from sqlspec.migrations.
|
|
14
|
+
from sqlspec.migrations.context import MigrationContext
|
|
12
15
|
from sqlspec.migrations.loaders import get_migration_loader
|
|
13
16
|
from sqlspec.utils.logging import get_logger
|
|
14
|
-
from sqlspec.utils.sync_tools import await_
|
|
17
|
+
from sqlspec.utils.sync_tools import async_, await_
|
|
15
18
|
|
|
16
19
|
if TYPE_CHECKING:
|
|
20
|
+
from collections.abc import Coroutine
|
|
21
|
+
|
|
17
22
|
from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
|
|
18
23
|
|
|
19
|
-
__all__ = ("AsyncMigrationRunner", "SyncMigrationRunner")
|
|
24
|
+
__all__ = ("AsyncMigrationRunner", "SyncMigrationRunner", "create_migration_runner")
|
|
20
25
|
|
|
21
26
|
logger = get_logger("migrations.runner")
|
|
22
27
|
|
|
23
28
|
|
|
24
|
-
class
|
|
25
|
-
"""
|
|
29
|
+
class BaseMigrationRunner(ABC):
|
|
30
|
+
"""Base migration runner with common functionality shared between sync and async implementations."""
|
|
31
|
+
|
|
32
|
+
def __init__(
|
|
33
|
+
self,
|
|
34
|
+
migrations_path: Path,
|
|
35
|
+
extension_migrations: "Optional[dict[str, Path]]" = None,
|
|
36
|
+
context: "Optional[MigrationContext]" = None,
|
|
37
|
+
extension_configs: "Optional[dict[str, dict[str, Any]]]" = None,
|
|
38
|
+
) -> None:
|
|
39
|
+
"""Initialize the migration runner.
|
|
40
|
+
|
|
41
|
+
Args:
|
|
42
|
+
migrations_path: Path to the directory containing migration files.
|
|
43
|
+
extension_migrations: Optional mapping of extension names to their migration paths.
|
|
44
|
+
context: Optional migration context for Python migrations.
|
|
45
|
+
extension_configs: Optional mapping of extension names to their configurations.
|
|
46
|
+
"""
|
|
47
|
+
self.migrations_path = migrations_path
|
|
48
|
+
self.extension_migrations = extension_migrations or {}
|
|
49
|
+
from sqlspec.loader import SQLFileLoader
|
|
50
|
+
|
|
51
|
+
self.loader = SQLFileLoader()
|
|
52
|
+
self.project_root: Optional[Path] = None
|
|
53
|
+
self.context = context
|
|
54
|
+
self.extension_configs = extension_configs or {}
|
|
55
|
+
|
|
56
|
+
def _extract_version(self, filename: str) -> "Optional[str]":
|
|
57
|
+
"""Extract version from filename.
|
|
58
|
+
|
|
59
|
+
Args:
|
|
60
|
+
filename: The migration filename.
|
|
61
|
+
|
|
62
|
+
Returns:
|
|
63
|
+
The extracted version string or None.
|
|
64
|
+
"""
|
|
65
|
+
# Handle extension-prefixed versions (e.g., "ext_litestar_0001")
|
|
66
|
+
if filename.startswith("ext_"):
|
|
67
|
+
# This is already a prefixed version, return as-is
|
|
68
|
+
return filename
|
|
69
|
+
|
|
70
|
+
# Regular version extraction
|
|
71
|
+
parts = filename.split("_", 1)
|
|
72
|
+
return parts[0].zfill(4) if parts and parts[0].isdigit() else None
|
|
73
|
+
|
|
74
|
+
def _calculate_checksum(self, content: str) -> str:
|
|
75
|
+
"""Calculate MD5 checksum of migration content.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
content: The migration file content.
|
|
79
|
+
|
|
80
|
+
Returns:
|
|
81
|
+
MD5 checksum hex string.
|
|
82
|
+
"""
|
|
83
|
+
import hashlib
|
|
84
|
+
|
|
85
|
+
return hashlib.md5(content.encode()).hexdigest() # noqa: S324
|
|
86
|
+
|
|
87
|
+
@abstractmethod
|
|
88
|
+
def load_migration(self, file_path: Path) -> Union["dict[str, Any]", "Coroutine[Any, Any, dict[str, Any]]"]:
|
|
89
|
+
"""Load a migration file and extract its components.
|
|
90
|
+
|
|
91
|
+
Args:
|
|
92
|
+
file_path: Path to the migration file.
|
|
93
|
+
|
|
94
|
+
Returns:
|
|
95
|
+
Dictionary containing migration metadata and queries.
|
|
96
|
+
For async implementations, returns a coroutine.
|
|
97
|
+
"""
|
|
98
|
+
|
|
99
|
+
def _get_migration_files_sync(self) -> "list[tuple[str, Path]]":
|
|
100
|
+
"""Get all migration files sorted by version.
|
|
101
|
+
|
|
102
|
+
Returns:
|
|
103
|
+
List of tuples containing (version, file_path).
|
|
104
|
+
"""
|
|
105
|
+
|
|
106
|
+
migrations = []
|
|
107
|
+
|
|
108
|
+
# Scan primary migration path
|
|
109
|
+
if self.migrations_path.exists():
|
|
110
|
+
for pattern in ("*.sql", "*.py"):
|
|
111
|
+
for file_path in self.migrations_path.glob(pattern):
|
|
112
|
+
if file_path.name.startswith("."):
|
|
113
|
+
continue
|
|
114
|
+
version = self._extract_version(file_path.name)
|
|
115
|
+
if version:
|
|
116
|
+
migrations.append((version, file_path))
|
|
117
|
+
|
|
118
|
+
# Scan extension migration paths
|
|
119
|
+
for ext_name, ext_path in self.extension_migrations.items():
|
|
120
|
+
if ext_path.exists():
|
|
121
|
+
for pattern in ("*.sql", "*.py"):
|
|
122
|
+
for file_path in ext_path.glob(pattern):
|
|
123
|
+
if file_path.name.startswith("."):
|
|
124
|
+
continue
|
|
125
|
+
# Prefix extension migrations to avoid version conflicts
|
|
126
|
+
version = self._extract_version(file_path.name)
|
|
127
|
+
if version:
|
|
128
|
+
# Use ext_ prefix to distinguish extension migrations
|
|
129
|
+
prefixed_version = f"ext_{ext_name}_{version}"
|
|
130
|
+
migrations.append((prefixed_version, file_path))
|
|
131
|
+
|
|
132
|
+
return sorted(migrations, key=operator.itemgetter(0))
|
|
26
133
|
|
|
27
134
|
def get_migration_files(self) -> "list[tuple[str, Path]]":
|
|
28
135
|
"""Get all migration files sorted by version.
|
|
@@ -32,6 +139,72 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
|
|
|
32
139
|
"""
|
|
33
140
|
return self._get_migration_files_sync()
|
|
34
141
|
|
|
142
|
+
def _load_migration_metadata_common(self, file_path: Path) -> "dict[str, Any]":
|
|
143
|
+
"""Load common migration metadata that doesn't require async operations.
|
|
144
|
+
|
|
145
|
+
Args:
|
|
146
|
+
file_path: Path to the migration file.
|
|
147
|
+
|
|
148
|
+
Returns:
|
|
149
|
+
Partial migration metadata dictionary.
|
|
150
|
+
"""
|
|
151
|
+
content = file_path.read_text(encoding="utf-8")
|
|
152
|
+
checksum = self._calculate_checksum(content)
|
|
153
|
+
version = self._extract_version(file_path.name)
|
|
154
|
+
description = file_path.stem.split("_", 1)[1] if "_" in file_path.stem else ""
|
|
155
|
+
|
|
156
|
+
return {
|
|
157
|
+
"version": version,
|
|
158
|
+
"description": description,
|
|
159
|
+
"file_path": file_path,
|
|
160
|
+
"checksum": checksum,
|
|
161
|
+
"content": content,
|
|
162
|
+
}
|
|
163
|
+
|
|
164
|
+
def _get_context_for_migration(self, file_path: Path) -> "Optional[MigrationContext]":
|
|
165
|
+
"""Get the appropriate context for a migration file.
|
|
166
|
+
|
|
167
|
+
Args:
|
|
168
|
+
file_path: Path to the migration file.
|
|
169
|
+
|
|
170
|
+
Returns:
|
|
171
|
+
Migration context to use, or None to use default.
|
|
172
|
+
"""
|
|
173
|
+
context_to_use = self.context
|
|
174
|
+
if context_to_use and file_path.name.startswith("ext_"):
|
|
175
|
+
version = self._extract_version(file_path.name)
|
|
176
|
+
if version and version.startswith("ext_"):
|
|
177
|
+
min_extension_version_parts = 3
|
|
178
|
+
parts = version.split("_", 2)
|
|
179
|
+
if len(parts) >= min_extension_version_parts:
|
|
180
|
+
ext_name = parts[1]
|
|
181
|
+
if ext_name in self.extension_configs:
|
|
182
|
+
context_to_use = MigrationContext(
|
|
183
|
+
dialect=self.context.dialect if self.context else None,
|
|
184
|
+
config=self.context.config if self.context else None,
|
|
185
|
+
driver=self.context.driver if self.context else None,
|
|
186
|
+
metadata=self.context.metadata.copy() if self.context and self.context.metadata else {},
|
|
187
|
+
extension_config=self.extension_configs[ext_name],
|
|
188
|
+
)
|
|
189
|
+
|
|
190
|
+
for ext_name, ext_path in self.extension_migrations.items():
|
|
191
|
+
if file_path.parent == ext_path:
|
|
192
|
+
if ext_name in self.extension_configs and self.context:
|
|
193
|
+
context_to_use = MigrationContext(
|
|
194
|
+
config=self.context.config,
|
|
195
|
+
dialect=self.context.dialect,
|
|
196
|
+
driver=self.context.driver,
|
|
197
|
+
metadata=self.context.metadata.copy() if self.context.metadata else {},
|
|
198
|
+
extension_config=self.extension_configs[ext_name],
|
|
199
|
+
)
|
|
200
|
+
break
|
|
201
|
+
|
|
202
|
+
return context_to_use
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class SyncMigrationRunner(BaseMigrationRunner):
|
|
206
|
+
"""Synchronous migration runner with pure sync methods."""
|
|
207
|
+
|
|
35
208
|
def load_migration(self, file_path: Path) -> "dict[str, Any]":
|
|
36
209
|
"""Load a migration file and extract its components.
|
|
37
210
|
|
|
@@ -41,7 +214,29 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
|
|
|
41
214
|
Returns:
|
|
42
215
|
Dictionary containing migration metadata and queries.
|
|
43
216
|
"""
|
|
44
|
-
|
|
217
|
+
# Get common metadata
|
|
218
|
+
metadata = self._load_migration_metadata_common(file_path)
|
|
219
|
+
context_to_use = self._get_context_for_migration(file_path)
|
|
220
|
+
|
|
221
|
+
loader = get_migration_loader(file_path, self.migrations_path, self.project_root, context_to_use)
|
|
222
|
+
loader.validate_migration_file(file_path)
|
|
223
|
+
|
|
224
|
+
has_upgrade, has_downgrade = True, False
|
|
225
|
+
|
|
226
|
+
if file_path.suffix == ".sql":
|
|
227
|
+
version = metadata["version"]
|
|
228
|
+
up_query, down_query = f"migrate-{version}-up", f"migrate-{version}-down"
|
|
229
|
+
self.loader.clear_cache()
|
|
230
|
+
self.loader.load_sql(file_path)
|
|
231
|
+
has_upgrade, has_downgrade = self.loader.has_query(up_query), self.loader.has_query(down_query)
|
|
232
|
+
else:
|
|
233
|
+
try:
|
|
234
|
+
has_downgrade = bool(self._get_migration_sql_sync({"loader": loader, "file_path": file_path}, "down"))
|
|
235
|
+
except Exception:
|
|
236
|
+
has_downgrade = False
|
|
237
|
+
|
|
238
|
+
metadata.update({"has_upgrade": has_upgrade, "has_downgrade": has_downgrade, "loader": loader})
|
|
239
|
+
return metadata
|
|
45
240
|
|
|
46
241
|
def execute_upgrade(
|
|
47
242
|
self, driver: "SyncDriverAdapterBase", migration: "dict[str, Any]"
|
|
@@ -49,13 +244,13 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
|
|
|
49
244
|
"""Execute an upgrade migration.
|
|
50
245
|
|
|
51
246
|
Args:
|
|
52
|
-
driver: The database driver to use.
|
|
247
|
+
driver: The sync database driver to use.
|
|
53
248
|
migration: Migration metadata dictionary.
|
|
54
249
|
|
|
55
250
|
Returns:
|
|
56
251
|
Tuple of (sql_content, execution_time_ms).
|
|
57
252
|
"""
|
|
58
|
-
upgrade_sql_list = self.
|
|
253
|
+
upgrade_sql_list = self._get_migration_sql_sync(migration, "up")
|
|
59
254
|
if upgrade_sql_list is None:
|
|
60
255
|
return None, 0
|
|
61
256
|
|
|
@@ -73,13 +268,13 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
|
|
|
73
268
|
"""Execute a downgrade migration.
|
|
74
269
|
|
|
75
270
|
Args:
|
|
76
|
-
driver: The database driver to use.
|
|
271
|
+
driver: The sync database driver to use.
|
|
77
272
|
migration: Migration metadata dictionary.
|
|
78
273
|
|
|
79
274
|
Returns:
|
|
80
275
|
Tuple of (sql_content, execution_time_ms).
|
|
81
276
|
"""
|
|
82
|
-
downgrade_sql_list = self.
|
|
277
|
+
downgrade_sql_list = self._get_migration_sql_sync(migration, "down")
|
|
83
278
|
if downgrade_sql_list is None:
|
|
84
279
|
return None, 0
|
|
85
280
|
|
|
@@ -91,6 +286,50 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
|
|
|
91
286
|
execution_time = int((time.time() - start_time) * 1000)
|
|
92
287
|
return None, execution_time
|
|
93
288
|
|
|
289
|
+
def _get_migration_sql_sync(self, migration: "dict[str, Any]", direction: str) -> "Optional[list[str]]":
|
|
290
|
+
"""Get migration SQL for given direction (sync version).
|
|
291
|
+
|
|
292
|
+
Args:
|
|
293
|
+
migration: Migration metadata.
|
|
294
|
+
direction: Either 'up' or 'down'.
|
|
295
|
+
|
|
296
|
+
Returns:
|
|
297
|
+
SQL statements for the migration.
|
|
298
|
+
"""
|
|
299
|
+
# If this is being called during migration loading (no has_*grade field yet),
|
|
300
|
+
# don't raise/warn - just proceed to check if the method exists
|
|
301
|
+
if f"has_{direction}grade" in migration and not migration.get(f"has_{direction}grade"):
|
|
302
|
+
if direction == "down":
|
|
303
|
+
logger.warning("Migration %s has no downgrade query", migration.get("version"))
|
|
304
|
+
return None
|
|
305
|
+
msg = f"Migration {migration.get('version')} has no upgrade query"
|
|
306
|
+
raise ValueError(msg)
|
|
307
|
+
|
|
308
|
+
file_path, loader = migration["file_path"], migration["loader"]
|
|
309
|
+
|
|
310
|
+
try:
|
|
311
|
+
method = loader.get_up_sql if direction == "up" else loader.get_down_sql
|
|
312
|
+
# Check if the method is async and handle appropriately
|
|
313
|
+
import inspect
|
|
314
|
+
|
|
315
|
+
if inspect.iscoroutinefunction(method):
|
|
316
|
+
# For async methods, use await_ to run in sync context
|
|
317
|
+
sql_statements = await_(method, raise_sync_error=False)(file_path)
|
|
318
|
+
else:
|
|
319
|
+
# For sync methods, call directly
|
|
320
|
+
sql_statements = method(file_path)
|
|
321
|
+
|
|
322
|
+
except Exception as e:
|
|
323
|
+
if direction == "down":
|
|
324
|
+
logger.warning("Failed to load downgrade for migration %s: %s", migration.get("version"), e)
|
|
325
|
+
return None
|
|
326
|
+
msg = f"Failed to load upgrade for migration {migration.get('version')}: {e}"
|
|
327
|
+
raise ValueError(msg) from e
|
|
328
|
+
else:
|
|
329
|
+
if sql_statements:
|
|
330
|
+
return cast("list[str]", sql_statements)
|
|
331
|
+
return None
|
|
332
|
+
|
|
94
333
|
def load_all_migrations(self) -> "dict[str, SQL]":
|
|
95
334
|
"""Load all migrations into a single namespace for bulk operations.
|
|
96
335
|
|
|
@@ -106,11 +345,11 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
|
|
|
106
345
|
for query_name in self.loader.list_queries():
|
|
107
346
|
all_queries[query_name] = self.loader.get_sql(query_name)
|
|
108
347
|
else:
|
|
109
|
-
loader = get_migration_loader(file_path, self.migrations_path, self.project_root)
|
|
348
|
+
loader = get_migration_loader(file_path, self.migrations_path, self.project_root, self.context)
|
|
110
349
|
|
|
111
350
|
try:
|
|
112
|
-
up_sql = await_(loader.get_up_sql
|
|
113
|
-
down_sql = await_(loader.get_down_sql
|
|
351
|
+
up_sql = await_(loader.get_up_sql)(file_path)
|
|
352
|
+
down_sql = await_(loader.get_down_sql)(file_path)
|
|
114
353
|
|
|
115
354
|
if up_sql:
|
|
116
355
|
all_queries[f"migrate-{version}-up"] = SQL(up_sql[0])
|
|
@@ -123,14 +362,14 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
|
|
|
123
362
|
return all_queries
|
|
124
363
|
|
|
125
364
|
|
|
126
|
-
class AsyncMigrationRunner(BaseMigrationRunner
|
|
127
|
-
"""Asynchronous migration
|
|
365
|
+
class AsyncMigrationRunner(BaseMigrationRunner):
|
|
366
|
+
"""Asynchronous migration runner with pure async methods."""
|
|
128
367
|
|
|
129
|
-
async def get_migration_files(self) -> "list[tuple[str, Path]]":
|
|
368
|
+
async def get_migration_files(self) -> "list[tuple[str, Path]]": # type: ignore[override]
|
|
130
369
|
"""Get all migration files sorted by version.
|
|
131
370
|
|
|
132
371
|
Returns:
|
|
133
|
-
List of
|
|
372
|
+
List of (version, path) tuples sorted by version.
|
|
134
373
|
"""
|
|
135
374
|
return self._get_migration_files_sync()
|
|
136
375
|
|
|
@@ -141,82 +380,33 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
|
|
|
141
380
|
file_path: Path to the migration file.
|
|
142
381
|
|
|
143
382
|
Returns:
|
|
144
|
-
Dictionary containing migration metadata.
|
|
383
|
+
Dictionary containing migration metadata and queries.
|
|
145
384
|
"""
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
"""Load migration metadata from file (async version).
|
|
150
|
-
|
|
151
|
-
Args:
|
|
152
|
-
file_path: Path to the migration file.
|
|
385
|
+
# Get common metadata
|
|
386
|
+
metadata = self._load_migration_metadata_common(file_path)
|
|
387
|
+
context_to_use = self._get_context_for_migration(file_path)
|
|
153
388
|
|
|
154
|
-
|
|
155
|
-
Migration metadata dictionary.
|
|
156
|
-
"""
|
|
157
|
-
loader = get_migration_loader(file_path, self.migrations_path, self.project_root)
|
|
389
|
+
loader = get_migration_loader(file_path, self.migrations_path, self.project_root, context_to_use)
|
|
158
390
|
loader.validate_migration_file(file_path)
|
|
159
|
-
content = file_path.read_text(encoding="utf-8")
|
|
160
|
-
checksum = self._calculate_checksum(content)
|
|
161
|
-
version = self._extract_version(file_path.name)
|
|
162
|
-
description = file_path.stem.split("_", 1)[1] if "_" in file_path.stem else ""
|
|
163
391
|
|
|
164
392
|
has_upgrade, has_downgrade = True, False
|
|
165
393
|
|
|
166
394
|
if file_path.suffix == ".sql":
|
|
395
|
+
version = metadata["version"]
|
|
167
396
|
up_query, down_query = f"migrate-{version}-up", f"migrate-{version}-down"
|
|
168
397
|
self.loader.clear_cache()
|
|
169
|
-
self.loader.load_sql(file_path)
|
|
398
|
+
await async_(self.loader.load_sql)(file_path)
|
|
170
399
|
has_upgrade, has_downgrade = self.loader.has_query(up_query), self.loader.has_query(down_query)
|
|
171
400
|
else:
|
|
172
401
|
try:
|
|
173
|
-
has_downgrade = bool(
|
|
402
|
+
has_downgrade = bool(
|
|
403
|
+
await self._get_migration_sql_async({"loader": loader, "file_path": file_path}, "down")
|
|
404
|
+
)
|
|
174
405
|
except Exception:
|
|
175
406
|
has_downgrade = False
|
|
176
407
|
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
"description": description,
|
|
180
|
-
"file_path": file_path,
|
|
181
|
-
"checksum": checksum,
|
|
182
|
-
"has_upgrade": has_upgrade,
|
|
183
|
-
"has_downgrade": has_downgrade,
|
|
184
|
-
"loader": loader,
|
|
185
|
-
}
|
|
186
|
-
|
|
187
|
-
async def _get_migration_sql_async(self, migration: "dict[str, Any]", direction: str) -> "Optional[list[str]]":
|
|
188
|
-
"""Get migration SQL for given direction (async version).
|
|
189
|
-
|
|
190
|
-
Args:
|
|
191
|
-
migration: Migration metadata.
|
|
192
|
-
direction: Either 'up' or 'down'.
|
|
193
|
-
|
|
194
|
-
Returns:
|
|
195
|
-
SQL statements for the migration.
|
|
196
|
-
"""
|
|
197
|
-
if not migration.get(f"has_{direction}grade"):
|
|
198
|
-
if direction == "down":
|
|
199
|
-
logger.warning("Migration %s has no downgrade query", migration["version"])
|
|
200
|
-
return None
|
|
201
|
-
msg = f"Migration {migration['version']} has no upgrade query"
|
|
202
|
-
raise ValueError(msg)
|
|
203
|
-
|
|
204
|
-
file_path, loader = migration["file_path"], migration["loader"]
|
|
205
|
-
|
|
206
|
-
try:
|
|
207
|
-
method = loader.get_up_sql if direction == "up" else loader.get_down_sql
|
|
208
|
-
sql_statements = await method(file_path)
|
|
209
|
-
|
|
210
|
-
except Exception as e:
|
|
211
|
-
if direction == "down":
|
|
212
|
-
logger.warning("Failed to load downgrade for migration %s: %s", migration["version"], e)
|
|
213
|
-
return None
|
|
214
|
-
msg = f"Failed to load upgrade for migration {migration['version']}: {e}"
|
|
215
|
-
raise ValueError(msg) from e
|
|
216
|
-
else:
|
|
217
|
-
if sql_statements:
|
|
218
|
-
return cast("list[str]", sql_statements)
|
|
219
|
-
return None
|
|
408
|
+
metadata.update({"has_upgrade": has_upgrade, "has_downgrade": has_downgrade, "loader": loader})
|
|
409
|
+
return metadata
|
|
220
410
|
|
|
221
411
|
async def execute_upgrade(
|
|
222
412
|
self, driver: "AsyncDriverAdapterBase", migration: "dict[str, Any]"
|
|
@@ -266,6 +456,42 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
|
|
|
266
456
|
execution_time = int((time.time() - start_time) * 1000)
|
|
267
457
|
return None, execution_time
|
|
268
458
|
|
|
459
|
+
async def _get_migration_sql_async(self, migration: "dict[str, Any]", direction: str) -> "Optional[list[str]]":
|
|
460
|
+
"""Get migration SQL for given direction (async version).
|
|
461
|
+
|
|
462
|
+
Args:
|
|
463
|
+
migration: Migration metadata.
|
|
464
|
+
direction: Either 'up' or 'down'.
|
|
465
|
+
|
|
466
|
+
Returns:
|
|
467
|
+
SQL statements for the migration.
|
|
468
|
+
"""
|
|
469
|
+
# If this is being called during migration loading (no has_*grade field yet),
|
|
470
|
+
# don't raise/warn - just proceed to check if the method exists
|
|
471
|
+
if f"has_{direction}grade" in migration and not migration.get(f"has_{direction}grade"):
|
|
472
|
+
if direction == "down":
|
|
473
|
+
logger.warning("Migration %s has no downgrade query", migration.get("version"))
|
|
474
|
+
return None
|
|
475
|
+
msg = f"Migration {migration.get('version')} has no upgrade query"
|
|
476
|
+
raise ValueError(msg)
|
|
477
|
+
|
|
478
|
+
file_path, loader = migration["file_path"], migration["loader"]
|
|
479
|
+
|
|
480
|
+
try:
|
|
481
|
+
method = loader.get_up_sql if direction == "up" else loader.get_down_sql
|
|
482
|
+
sql_statements = await method(file_path)
|
|
483
|
+
|
|
484
|
+
except Exception as e:
|
|
485
|
+
if direction == "down":
|
|
486
|
+
logger.warning("Failed to load downgrade for migration %s: %s", migration.get("version"), e)
|
|
487
|
+
return None
|
|
488
|
+
msg = f"Failed to load upgrade for migration {migration.get('version')}: {e}"
|
|
489
|
+
raise ValueError(msg) from e
|
|
490
|
+
else:
|
|
491
|
+
if sql_statements:
|
|
492
|
+
return cast("list[str]", sql_statements)
|
|
493
|
+
return None
|
|
494
|
+
|
|
269
495
|
async def load_all_migrations(self) -> "dict[str, SQL]":
|
|
270
496
|
"""Load all migrations into a single namespace for bulk operations.
|
|
271
497
|
|
|
@@ -277,11 +503,11 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
|
|
|
277
503
|
|
|
278
504
|
for version, file_path in migrations:
|
|
279
505
|
if file_path.suffix == ".sql":
|
|
280
|
-
self.loader.load_sql(file_path)
|
|
506
|
+
await async_(self.loader.load_sql)(file_path)
|
|
281
507
|
for query_name in self.loader.list_queries():
|
|
282
508
|
all_queries[query_name] = self.loader.get_sql(query_name)
|
|
283
509
|
else:
|
|
284
|
-
loader = get_migration_loader(file_path, self.migrations_path, self.project_root)
|
|
510
|
+
loader = get_migration_loader(file_path, self.migrations_path, self.project_root, self.context)
|
|
285
511
|
|
|
286
512
|
try:
|
|
287
513
|
up_sql = await loader.get_up_sql(file_path)
|
|
@@ -296,3 +522,47 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
|
|
|
296
522
|
logger.debug("Failed to load Python migration %s: %s", file_path, e)
|
|
297
523
|
|
|
298
524
|
return all_queries
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
@overload
|
|
528
|
+
def create_migration_runner(
|
|
529
|
+
migrations_path: Path,
|
|
530
|
+
extension_migrations: "dict[str, Path]",
|
|
531
|
+
context: "Optional[MigrationContext]",
|
|
532
|
+
extension_configs: "dict[str, Any]",
|
|
533
|
+
is_async: "Literal[False]" = False,
|
|
534
|
+
) -> SyncMigrationRunner: ...
|
|
535
|
+
|
|
536
|
+
|
|
537
|
+
@overload
|
|
538
|
+
def create_migration_runner(
|
|
539
|
+
migrations_path: Path,
|
|
540
|
+
extension_migrations: "dict[str, Path]",
|
|
541
|
+
context: "Optional[MigrationContext]",
|
|
542
|
+
extension_configs: "dict[str, Any]",
|
|
543
|
+
is_async: "Literal[True]",
|
|
544
|
+
) -> AsyncMigrationRunner: ...
|
|
545
|
+
|
|
546
|
+
|
|
547
|
+
def create_migration_runner(
|
|
548
|
+
migrations_path: Path,
|
|
549
|
+
extension_migrations: "dict[str, Path]",
|
|
550
|
+
context: "Optional[MigrationContext]",
|
|
551
|
+
extension_configs: "dict[str, Any]",
|
|
552
|
+
is_async: bool = False,
|
|
553
|
+
) -> "Union[SyncMigrationRunner, AsyncMigrationRunner]":
|
|
554
|
+
"""Factory function to create the appropriate migration runner.
|
|
555
|
+
|
|
556
|
+
Args:
|
|
557
|
+
migrations_path: Path to migrations directory.
|
|
558
|
+
extension_migrations: Extension migration paths.
|
|
559
|
+
context: Migration context.
|
|
560
|
+
extension_configs: Extension configurations.
|
|
561
|
+
is_async: Whether to create async or sync runner.
|
|
562
|
+
|
|
563
|
+
Returns:
|
|
564
|
+
Appropriate migration runner instance.
|
|
565
|
+
"""
|
|
566
|
+
if is_async:
|
|
567
|
+
return AsyncMigrationRunner(migrations_path, extension_migrations, context, extension_configs)
|
|
568
|
+
return SyncMigrationRunner(migrations_path, extension_migrations, context, extension_configs)
|
sqlspec/typing.py
CHANGED
|
@@ -12,8 +12,10 @@ from sqlspec._typing import (
|
|
|
12
12
|
FSSPEC_INSTALLED,
|
|
13
13
|
LITESTAR_INSTALLED,
|
|
14
14
|
MSGSPEC_INSTALLED,
|
|
15
|
+
NUMPY_INSTALLED,
|
|
15
16
|
OBSTORE_INSTALLED,
|
|
16
17
|
OPENTELEMETRY_INSTALLED,
|
|
18
|
+
ORJSON_INSTALLED,
|
|
17
19
|
PGVECTOR_INSTALLED,
|
|
18
20
|
PROMETHEUS_INSTALLED,
|
|
19
21
|
PYARROW_INSTALLED,
|
|
@@ -187,8 +189,10 @@ __all__ = (
|
|
|
187
189
|
"FSSPEC_INSTALLED",
|
|
188
190
|
"LITESTAR_INSTALLED",
|
|
189
191
|
"MSGSPEC_INSTALLED",
|
|
192
|
+
"NUMPY_INSTALLED",
|
|
190
193
|
"OBSTORE_INSTALLED",
|
|
191
194
|
"OPENTELEMETRY_INSTALLED",
|
|
195
|
+
"ORJSON_INSTALLED",
|
|
192
196
|
"PGVECTOR_INSTALLED",
|
|
193
197
|
"PROMETHEUS_INSTALLED",
|
|
194
198
|
"PYARROW_INSTALLED",
|