sqlspec 0.17.1__py3-none-any.whl → 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of sqlspec might be problematic.
- sqlspec/__init__.py +1 -1
- sqlspec/_sql.py +54 -159
- sqlspec/adapters/adbc/config.py +24 -30
- sqlspec/adapters/adbc/driver.py +42 -61
- sqlspec/adapters/aiosqlite/config.py +5 -10
- sqlspec/adapters/aiosqlite/driver.py +9 -25
- sqlspec/adapters/aiosqlite/pool.py +43 -35
- sqlspec/adapters/asyncmy/config.py +10 -7
- sqlspec/adapters/asyncmy/driver.py +18 -39
- sqlspec/adapters/asyncpg/config.py +4 -0
- sqlspec/adapters/asyncpg/driver.py +32 -79
- sqlspec/adapters/bigquery/config.py +12 -65
- sqlspec/adapters/bigquery/driver.py +39 -133
- sqlspec/adapters/duckdb/config.py +11 -15
- sqlspec/adapters/duckdb/driver.py +61 -85
- sqlspec/adapters/duckdb/pool.py +2 -5
- sqlspec/adapters/oracledb/_types.py +8 -1
- sqlspec/adapters/oracledb/config.py +55 -38
- sqlspec/adapters/oracledb/driver.py +35 -92
- sqlspec/adapters/oracledb/migrations.py +257 -0
- sqlspec/adapters/psqlpy/config.py +13 -9
- sqlspec/adapters/psqlpy/driver.py +28 -103
- sqlspec/adapters/psycopg/config.py +9 -5
- sqlspec/adapters/psycopg/driver.py +107 -175
- sqlspec/adapters/sqlite/config.py +7 -5
- sqlspec/adapters/sqlite/driver.py +37 -73
- sqlspec/adapters/sqlite/pool.py +3 -12
- sqlspec/base.py +19 -22
- sqlspec/builder/__init__.py +1 -1
- sqlspec/builder/_base.py +34 -20
- sqlspec/builder/_ddl.py +407 -183
- sqlspec/builder/_insert.py +1 -1
- sqlspec/builder/mixins/_insert_operations.py +26 -6
- sqlspec/builder/mixins/_merge_operations.py +1 -1
- sqlspec/builder/mixins/_select_operations.py +1 -5
- sqlspec/cli.py +281 -33
- sqlspec/config.py +183 -14
- sqlspec/core/__init__.py +89 -14
- sqlspec/core/cache.py +57 -104
- sqlspec/core/compiler.py +57 -112
- sqlspec/core/filters.py +1 -21
- sqlspec/core/hashing.py +13 -47
- sqlspec/core/parameters.py +272 -261
- sqlspec/core/result.py +12 -27
- sqlspec/core/splitter.py +17 -21
- sqlspec/core/statement.py +150 -159
- sqlspec/driver/_async.py +2 -15
- sqlspec/driver/_common.py +16 -95
- sqlspec/driver/_sync.py +2 -15
- sqlspec/driver/mixins/_result_tools.py +8 -29
- sqlspec/driver/mixins/_sql_translator.py +6 -8
- sqlspec/exceptions.py +1 -2
- sqlspec/extensions/litestar/plugin.py +15 -8
- sqlspec/loader.py +43 -115
- sqlspec/migrations/__init__.py +1 -1
- sqlspec/migrations/base.py +34 -45
- sqlspec/migrations/commands.py +34 -15
- sqlspec/migrations/loaders.py +1 -1
- sqlspec/migrations/runner.py +104 -19
- sqlspec/migrations/tracker.py +49 -2
- sqlspec/protocols.py +3 -6
- sqlspec/storage/__init__.py +4 -4
- sqlspec/storage/backends/fsspec.py +5 -6
- sqlspec/storage/backends/obstore.py +7 -8
- sqlspec/storage/registry.py +3 -3
- sqlspec/utils/__init__.py +2 -2
- sqlspec/utils/logging.py +6 -10
- sqlspec/utils/sync_tools.py +27 -4
- sqlspec/utils/text.py +6 -1
- {sqlspec-0.17.1.dist-info → sqlspec-0.19.0.dist-info}/METADATA +1 -1
- sqlspec-0.19.0.dist-info/RECORD +138 -0
- sqlspec/builder/_ddl_utils.py +0 -103
- sqlspec-0.17.1.dist-info/RECORD +0 -138
- {sqlspec-0.17.1.dist-info → sqlspec-0.19.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.17.1.dist-info → sqlspec-0.19.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.17.1.dist-info → sqlspec-0.19.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.17.1.dist-info → sqlspec-0.19.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/loader.py
CHANGED
@@ -1,17 +1,15 @@
-"""SQL file loader
+"""SQL file loader for managing SQL statements from files.
 
-
+Provides functionality to load, cache, and manage SQL statements
 from files using aiosql-style named queries.
 """
 
 import hashlib
 import re
 import time
-from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from difflib import get_close_matches
 from pathlib import Path
-from typing import Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Final, Optional, Union
 
 from sqlspec.core.cache import CacheKey, get_cache_config, get_default_cache
 from sqlspec.core.statement import SQL
@@ -21,11 +19,13 @@ from sqlspec.exceptions import (
     SQLFileParseError,
     StorageOperationFailedError,
 )
-from sqlspec.storage import storage_registry
-from sqlspec.storage.registry import StorageRegistry
+from sqlspec.storage.registry import storage_registry as default_storage_registry
 from sqlspec.utils.correlation import CorrelationContext
 from sqlspec.utils.logging import get_logger
 
+if TYPE_CHECKING:
+    from sqlspec.storage.registry import StorageRegistry
+
 __all__ = ("CachedSQLFile", "NamedStatement", "SQLFile", "SQLFileLoader")
 
 logger = get_logger("loader")
@@ -38,48 +38,8 @@ TRIM_SPECIAL_CHARS = re.compile(r"[^\w.-]")
 # Matches: -- dialect: dialect_name (optional dialect specification)
 DIALECT_PATTERN = re.compile(r"^\s*--\s*dialect\s*:\s*(?P<dialect>[a-zA-Z0-9_]+)\s*$", re.IGNORECASE | re.MULTILINE)
 
-# Supported SQL dialects (based on SQLGlot's available dialects)
-SUPPORTED_DIALECTS = {
-    # Core databases
-    "sqlite",
-    "postgresql",
-    "postgres",
-    "mysql",
-    "oracle",
-    "mssql",
-    "tsql",
-    # Cloud platforms
-    "bigquery",
-    "snowflake",
-    "redshift",
-    "athena",
-    "fabric",
-    # Analytics engines
-    "clickhouse",
-    "duckdb",
-    "databricks",
-    "spark",
-    "spark2",
-    "trino",
-    "presto",
-    # Specialized
-    "hive",
-    "drill",
-    "druid",
-    "materialize",
-    "teradata",
-    "dremio",
-    "doris",
-    "risingwave",
-    "singlestore",
-    "starrocks",
-    "tableau",
-    "exasol",
-    "dune",
-}
 
-
-DIALECT_ALIASES = {
+DIALECT_ALIASES: Final = {
     "postgresql": "postgres",
     "pg": "postgres",
     "pgplsql": "postgres",
@@ -88,7 +48,7 @@ DIALECT_ALIASES = {
     "tsql": "mssql",
 }
 
-MIN_QUERY_PARTS = 3
+MIN_QUERY_PARTS: Final = 3
 
 
 def _normalize_query_name(name: str) -> str:
@@ -129,19 +89,6 @@ def _normalize_dialect_for_sqlglot(dialect: str) -> str:
     return DIALECT_ALIASES.get(normalized, normalized)
 
 
-def _get_dialect_suggestions(invalid_dialect: str) -> "list[str]":
-    """Get dialect suggestions using fuzzy matching.
-
-    Args:
-        invalid_dialect: Invalid dialect name that was provided
-
-    Returns:
-        List of suggested dialect names (up to 3 suggestions)
-    """
-
-    return get_close_matches(invalid_dialect, SUPPORTED_DIALECTS, n=3, cutoff=0.6)
-
-
 class NamedStatement:
     """Represents a parsed SQL statement with metadata.
 
@@ -159,7 +106,6 @@ class NamedStatement:
         self.start_line = start_line
 
 
-@dataclass
 class SQLFile:
     """Represents a loaded SQL file with metadata.
 
@@ -167,28 +113,32 @@ class SQLFile:
     timestamps, and content hash.
     """
 
-    content
-    """The raw SQL content from the file."""
+    __slots__ = ("checksum", "content", "loaded_at", "metadata", "path")
 
-
-
+    def __init__(
+        self,
+        content: str,
+        path: str,
+        metadata: "Optional[dict[str, Any]]" = None,
+        loaded_at: "Optional[datetime]" = None,
+    ) -> None:
+        """Initialize SQLFile.
 
-
-
-
-
-
-
-
-
-
-
-        """Calculate checksum after initialization."""
+        Args:
+            content: The raw SQL content from the file.
+            path: Path where the SQL file was loaded from.
+            metadata: Optional metadata associated with the SQL file.
+            loaded_at: Timestamp when the file was loaded.
+        """
+        self.content = content
+        self.path = path
+        self.metadata = metadata or {}
+        self.loaded_at = loaded_at or datetime.now(timezone.utc)
         self.checksum = hashlib.md5(self.content.encode(), usedforsecurity=False).hexdigest()
 
 
 class CachedSQLFile:
-    """Cached SQL file with parsed statements
+    """Cached SQL file with parsed statements.
 
     Stored in the file cache to avoid re-parsing SQL files when their
     content hasn't changed.
@@ -205,17 +155,19 @@ class CachedSQLFile:
         """
         self.sql_file = sql_file
         self.parsed_statements = parsed_statements
-        self.statement_names =
+        self.statement_names = tuple(parsed_statements.keys())
 
 
 class SQLFileLoader:
     """Loads and parses SQL files with aiosql-style named queries.
 
-
-
+    Loads SQL files containing named queries (using -- name: syntax)
+    and retrieves them by name.
     """
 
-
+    __slots__ = ("_files", "_queries", "_query_to_file", "encoding", "storage_registry")
+
+    def __init__(self, *, encoding: str = "utf-8", storage_registry: "Optional[StorageRegistry]" = None) -> None:
         """Initialize the SQL file loader.
 
         Args:
@@ -223,7 +175,8 @@ class SQLFileLoader:
             storage_registry: Storage registry for handling file URIs.
         """
         self.encoding = encoding
-
+
+        self.storage_registry = storage_registry or default_storage_registry
         self._queries: dict[str, NamedStatement] = {}
         self._files: dict[str, SQLFile] = {}
         self._query_to_file: dict[str, str] = {}
@@ -309,7 +262,6 @@ class SQLFileLoader:
         except KeyError as e:
             raise SQLFileNotFoundError(path_str) from e
         except MissingDependencyError:
-            # Fall back to standard file reading when no storage backend is available
             try:
                 return path.read_text(encoding=self.encoding)  # type: ignore[union-attr]
             except FileNotFoundError as e:
@@ -350,7 +302,6 @@ class SQLFileLoader:
            or invalid dialect names are specified
         """
         statements: dict[str, NamedStatement] = {}
-        content.splitlines()
 
         name_matches = list(QUERY_NAME_PATTERN.finditer(content))
         if not name_matches:
@@ -379,20 +330,7 @@ class SQLFileLoader:
             if dialect_match:
                 declared_dialect = dialect_match.group("dialect").lower()
 
-
-
-                if normalized_dialect not in SUPPORTED_DIALECTS:
-                    suggestions = _get_dialect_suggestions(normalized_dialect)
-                    warning_msg = f"Unknown dialect '{declared_dialect}' at line {statement_start_line + 1}"
-                    if suggestions:
-                        warning_msg += f". Did you mean: {', '.join(suggestions)}?"
-                    warning_msg += (
-                        f". Supported dialects: {', '.join(sorted(SUPPORTED_DIALECTS))}. Using dialect as-is."
-                    )
-                    logger.warning(warning_msg)
-                    dialect = declared_dialect.lower()
-                else:
-                    dialect = normalized_dialect
+                dialect = _normalize_dialect(declared_dialect)
             remaining_lines = section_lines[1:]
             statement_sql = "\n".join(remaining_lines)
 
@@ -473,7 +411,7 @@ class SQLFileLoader:
             raise
 
     def _load_directory(self, dir_path: Path) -> int:
-        """Load all SQL files from a directory
+        """Load all SQL files from a directory."""
        sql_files = list(dir_path.rglob("*.sql"))
        if not sql_files:
            return 0
@@ -486,7 +424,7 @@ class SQLFileLoader:
         return len(sql_files)
 
     def _load_single_file(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
-        """Load a single SQL file with optional namespace
+        """Load a single SQL file with optional namespace.
 
         Args:
             file_path: Path to the SQL file.
@@ -543,7 +481,7 @@ class SQLFileLoader:
             unified_cache.put(cache_key, cached_file_data)
 
     def _load_file_without_cache(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
-        """Load a single SQL file without
+        """Load a single SQL file without using cache.
 
         Args:
             file_path: Path to the SQL file.
@@ -580,7 +518,7 @@ class SQLFileLoader:
         Raises:
             ValueError: If query name already exists.
         """
-
+
         normalized_name = _normalize_query_name(name)
 
         if normalized_name in self._queries:
@@ -589,17 +527,7 @@ class SQLFileLoader:
             raise ValueError(msg)
 
         if dialect is not None:
-
-            if normalized_dialect not in SUPPORTED_DIALECTS:
-                suggestions = _get_dialect_suggestions(normalized_dialect)
-                warning_msg = f"Unknown dialect '{dialect}'"
-                if suggestions:
-                    warning_msg += f". Did you mean: {', '.join(suggestions)}?"
-                warning_msg += f". Supported dialects: {', '.join(sorted(SUPPORTED_DIALECTS))}. Using dialect as-is."
-                logger.warning(warning_msg)
-                dialect = dialect.lower()
-            else:
-                dialect = normalized_dialect
+            dialect = _normalize_dialect(dialect)
 
         statement = NamedStatement(name=normalized_name, sql=sql.strip(), dialect=dialect, start_line=0)
         self._queries[normalized_name] = statement
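For orientation, here is a minimal sketch of the aiosql-style format this loader parses and of the new 0.19.0 constructor. The registration call used below (`add_named_sql`) is an assumption for illustration only; the diff shows its parameters (name, sql, dialect) but not its name. Note that dialect handling now runs through `DIALECT_ALIASES` alone, so aliases such as "pg" normalize to "postgres" and the old `SUPPORTED_DIALECTS` warning path is gone.

```python
from sqlspec.loader import SQLFileLoader

# "-- name:" introduces a named statement; an optional "-- dialect:" comment
# (matched by DIALECT_PATTERN above) pins the dialect for that statement.
EXAMPLE_SQL = """\
-- name: get_user
-- dialect: pg
SELECT id, email FROM users WHERE id = :id;
"""

# Keyword-only constructor arguments match the new __init__ shown in the diff.
loader = SQLFileLoader(encoding="utf-8")

# Hypothetical registration call: the method name is assumed, but the arguments
# mirror the NamedStatement fields; "pg" normalizes to "postgres" via DIALECT_ALIASES.
loader.add_named_sql("get_user", "SELECT id, email FROM users WHERE id = :id", dialect="pg")
```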
sqlspec/migrations/__init__.py
CHANGED
@@ -1,7 +1,7 @@
 """SQLSpec Migration Tool.
 
 A native migration system for SQLSpec that leverages the SQLFileLoader
-and driver
+and driver system for database versioning.
 """
 
 from sqlspec.migrations.commands import AsyncMigrationCommands, MigrationCommands, SyncMigrationCommands
sqlspec/migrations/base.py
CHANGED
@@ -3,18 +3,19 @@
 This module provides abstract base classes for migration components.
 """
 
+import hashlib
 import operator
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Generic, Optional, TypeVar
+from typing import Any, Generic, Optional, TypeVar, cast
 
 from sqlspec._sql import sql
+from sqlspec.builder import Delete, Insert, Select
 from sqlspec.builder._ddl import CreateTable
-from sqlspec.core.statement import SQL
 from sqlspec.loader import SQLFileLoader
 from sqlspec.migrations.loaders import get_migration_loader
 from sqlspec.utils.logging import get_logger
-from sqlspec.utils.sync_tools import
+from sqlspec.utils.sync_tools import await_
 
 __all__ = ("BaseMigrationCommands", "BaseMigrationRunner", "BaseMigrationTracker")
 
@@ -28,6 +29,8 @@ ConfigT = TypeVar("ConfigT")
 class BaseMigrationTracker(ABC, Generic[DriverT]):
     """Base class for migration version tracking."""
 
+    __slots__ = ("version_table",)
+
     def __init__(self, version_table_name: str = "ddl_migrations") -> None:
         """Initialize the migration tracker.
 
@@ -36,54 +39,43 @@ class BaseMigrationTracker(ABC, Generic[DriverT]):
         """
         self.version_table = version_table_name
 
-    def _get_create_table_sql(self) ->
-        """Get SQL for creating the tracking table.
+    def _get_create_table_sql(self) -> CreateTable:
+        """Get SQL builder for creating the tracking table.
 
         Returns:
-            SQL object for table creation.
+            SQL builder object for table creation.
         """
-        builder = CreateTable(self.version_table)
-        if not hasattr(builder, "_columns"):
-            builder._columns = []
-        if not hasattr(builder, "_constraints"):
-            builder._constraints = []
-        if not hasattr(builder, "_table_options"):
-            builder._table_options = {}
-
         return (
-
+            sql.create_table(self.version_table)
+            .if_not_exists()
             .column("version_num", "VARCHAR(32)", primary_key=True)
             .column("description", "TEXT")
-            .column("applied_at", "TIMESTAMP",
+            .column("applied_at", "TIMESTAMP", default="CURRENT_TIMESTAMP", not_null=True)
             .column("execution_time_ms", "INTEGER")
             .column("checksum", "VARCHAR(64)")
             .column("applied_by", "VARCHAR(255)")
-        )
+        )
 
-    def _get_current_version_sql(self) ->
-        """Get SQL for retrieving current version.
+    def _get_current_version_sql(self) -> Select:
+        """Get SQL builder for retrieving current version.
 
         Returns:
-            SQL object for version query.
+            SQL builder object for version query.
         """
+        return sql.select("version_num").from_(self.version_table).order_by("version_num DESC").limit(1)
 
-
-
-        ).to_statement()
-
-    def _get_applied_migrations_sql(self) -> SQL:
-        """Get SQL for retrieving all applied migrations.
+    def _get_applied_migrations_sql(self) -> Select:
+        """Get SQL builder for retrieving all applied migrations.
 
         Returns:
-            SQL object for migrations query.
+            SQL builder object for migrations query.
         """
-
-        return (sql.select("*").from_(self.version_table).order_by("version_num")).to_statement()
+        return sql.select("*").from_(self.version_table).order_by("version_num")
 
     def _get_record_migration_sql(
         self, version: str, description: str, execution_time_ms: int, checksum: str, applied_by: str
-    ) ->
-        """Get SQL for recording a migration.
+    ) -> Insert:
+        """Get SQL builder for recording a migration.
 
         Args:
             version: Version number of the migration.
@@ -93,26 +85,24 @@ class BaseMigrationTracker(ABC, Generic[DriverT]):
             applied_by: User who applied the migration.
 
         Returns:
-            SQL object for insert.
+            SQL builder object for insert.
         """
-
         return (
             sql.insert(self.version_table)
             .columns("version_num", "description", "execution_time_ms", "checksum", "applied_by")
             .values(version, description, execution_time_ms, checksum, applied_by)
-        )
+        )
 
-    def _get_remove_migration_sql(self, version: str) ->
-        """Get SQL for removing a migration record.
+    def _get_remove_migration_sql(self, version: str) -> Delete:
+        """Get SQL builder for removing a migration record.
 
         Args:
             version: Version number to remove.
 
         Returns:
-            SQL object for delete.
+            SQL builder object for delete.
         """
-
-        return (sql.delete().from_(self.version_table).where(sql.version_num == version)).to_statement()
+        return sql.delete().from_(self.version_table).where(sql.version_num == version)
 
     @abstractmethod
     def ensure_tracking_table(self, driver: DriverT) -> Any:
@@ -176,7 +166,6 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
         Returns:
             MD5 checksum hex string.
         """
-        import hashlib
 
         return hashlib.md5(content.encode()).hexdigest()  # noqa: S324
 
@@ -226,7 +215,7 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
             has_upgrade, has_downgrade = self.loader.has_query(up_query), self.loader.has_query(down_query)
         else:
             try:
-                has_downgrade = bool(
+                has_downgrade = bool(await_(loader.get_down_sql, raise_sync_error=False)(file_path))
             except Exception:
                 has_downgrade = False
 
@@ -240,7 +229,7 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
             "loader": loader,
         }
 
-    def _get_migration_sql(self, migration: "dict[str, Any]", direction: str) -> Optional[
+    def _get_migration_sql(self, migration: "dict[str, Any]", direction: str) -> "Optional[list[str]]":
         """Get migration SQL for given direction.
 
         Args:
@@ -261,7 +250,7 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
 
         try:
             method = loader.get_up_sql if direction == "up" else loader.get_down_sql
-            sql_statements =
+            sql_statements = await_(method, raise_sync_error=False)(file_path)
 
         except Exception as e:
             if direction == "down":
@@ -271,7 +260,7 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
                 raise ValueError(msg) from e
         else:
             if sql_statements:
-                return
+                return cast("list[str]", sql_statements)
         return None
 
     @abstractmethod
@@ -312,7 +301,7 @@ class BaseMigrationCommands(ABC, Generic[ConfigT, DriverT]):
         self.config = config
         migration_config = getattr(self.config, "migration_config", {}) or {}
 
-        self.version_table = migration_config.get("version_table_name", "
+        self.version_table = migration_config.get("version_table_name", "ddl_migrations")
        self.migrations_path = Path(migration_config.get("script_location", "migrations"))
        self.project_root = Path(migration_config["project_root"]) if "project_root" in migration_config else None
 
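To make the tracker rework concrete, the sketch below reproduces two of the builder chains from the diff outside the class: in 0.19.0 these helpers return builder objects (CreateTable, Select, Insert, Delete) that are compiled later, rather than pre-rendered SQL statements via `.to_statement()`. The chains are copied from `_get_create_table_sql` and `_get_current_version_sql` above; the surrounding variable names are illustrative.

```python
from sqlspec._sql import sql

version_table = "ddl_migrations"  # default tracking table name from the diff

# CreateTable builder for the migration tracking table (no longer wrapped in SQL).
create_tracking_table = (
    sql.create_table(version_table)
    .if_not_exists()
    .column("version_num", "VARCHAR(32)", primary_key=True)
    .column("description", "TEXT")
    .column("applied_at", "TIMESTAMP", default="CURRENT_TIMESTAMP", not_null=True)
    .column("execution_time_ms", "INTEGER")
    .column("checksum", "VARCHAR(64)")
    .column("applied_by", "VARCHAR(255)")
)

# Select builder for the most recently applied version.
latest_version = (
    sql.select("version_num").from_(version_table).order_by("version_num DESC").limit(1)
)
```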
sqlspec/migrations/commands.py
CHANGED
@@ -3,7 +3,7 @@
 This module provides the main command interface for database migrations.
 """
 
-from typing import TYPE_CHECKING, Any, Union, cast
+from typing import TYPE_CHECKING, Any, Optional, Union, cast
 
 from rich.console import Console
 from rich.table import Table
@@ -11,7 +11,6 @@ from rich.table import Table
 from sqlspec._sql import sql
 from sqlspec.migrations.base import BaseMigrationCommands
 from sqlspec.migrations.runner import AsyncMigrationRunner, SyncMigrationRunner
-from sqlspec.migrations.tracker import AsyncMigrationTracker, SyncMigrationTracker
 from sqlspec.migrations.utils import create_migration_file
 from sqlspec.utils.logging import get_logger
 from sqlspec.utils.sync_tools import await_
@@ -26,7 +25,7 @@ console = Console()
 
 
 class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
-    """
+    """Synchronous migration commands."""
 
     def __init__(self, config: "SyncConfigT") -> None:
         """Initialize migration commands.
@@ -35,7 +34,7 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
             config: The SQLSpec configuration.
         """
         super().__init__(config)
-        self.tracker =
+        self.tracker = config.migration_tracker_type(self.version_table)
         self.runner = SyncMigrationRunner(self.migrations_path)
 
     def init(self, directory: str, package: bool = True) -> None:
@@ -47,11 +46,14 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
         """
         self.init_directory(directory, package)
 
-    def current(self, verbose: bool = False) ->
+    def current(self, verbose: bool = False) -> "Optional[str]":
         """Show current migration version.
 
         Args:
             verbose: Whether to show detailed migration history.
+
+        Returns:
+            The current migration version or None if no migrations applied.
         """
         with self.config.provide_session() as driver:
             self.tracker.ensure_tracking_table(driver)
@@ -59,7 +61,7 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
             current = self.tracker.get_current_version(driver)
             if not current:
                 console.print("[yellow]No migrations applied yet[/]")
-                return
+                return None
 
             console.print(f"[green]Current version:[/] {current}")
 
@@ -84,6 +86,8 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
 
                 console.print(table)
 
+            return cast("Optional[str]", current)
+
     def upgrade(self, revision: str = "head") -> None:
         """Upgrade to a target revision.
 
@@ -137,6 +141,8 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
             to_revert = []
             if revision == "-1":
                 to_revert = [applied[-1]]
+            elif revision == "base":
+                to_revert = list(reversed(applied))
             else:
                 for migration in reversed(applied):
                     if migration["version_num"] > revision:
@@ -195,7 +201,7 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
 
 
 class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
-    """
+    """Asynchronous migration commands."""
 
     def __init__(self, sqlspec_config: "AsyncConfigT") -> None:
         """Initialize migration commands.
@@ -204,7 +210,7 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
             sqlspec_config: The SQLSpec configuration.
         """
         super().__init__(sqlspec_config)
-        self.tracker =
+        self.tracker = sqlspec_config.migration_tracker_type(self.version_table)
         self.runner = AsyncMigrationRunner(self.migrations_path)
 
     async def init(self, directory: str, package: bool = True) -> None:
@@ -216,11 +222,14 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
         """
         self.init_directory(directory, package)
 
-    async def current(self, verbose: bool = False) ->
+    async def current(self, verbose: bool = False) -> "Optional[str]":
         """Show current migration version.
 
         Args:
             verbose: Whether to show detailed migration history.
+
+        Returns:
+            The current migration version or None if no migrations applied.
         """
         async with self.config.provide_session() as driver:
             await self.tracker.ensure_tracking_table(driver)
@@ -228,7 +237,7 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
             current = await self.tracker.get_current_version(driver)
             if not current:
                 console.print("[yellow]No migrations applied yet[/]")
-                return
+                return None
 
             console.print(f"[green]Current version:[/] {current}")
             if verbose:
@@ -249,6 +258,8 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
                 )
                 console.print(table)
 
+            return cast("Optional[str]", current)
+
     async def upgrade(self, revision: str = "head") -> None:
         """Upgrade to a target revision.
 
@@ -297,6 +308,8 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
             to_revert = []
             if revision == "-1":
                 to_revert = [applied[-1]]
+            elif revision == "base":
+                to_revert = list(reversed(applied))
             else:
                 for migration in reversed(applied):
                     if migration["version_num"] > revision:
@@ -382,20 +395,26 @@ class MigrationCommands:
             package: Whether to create __init__.py file.
         """
         if self._is_async:
-            await_(cast("AsyncMigrationCommands[Any]", self._impl).init
+            await_(cast("AsyncMigrationCommands[Any]", self._impl).init, raise_sync_error=False)(
+                directory, package=package
+            )
         else:
             cast("SyncMigrationCommands[Any]", self._impl).init(directory, package=package)
 
-    def current(self, verbose: bool = False) ->
+    def current(self, verbose: bool = False) -> "Optional[str]":
         """Show current migration version.
 
         Args:
             verbose: Whether to show detailed migration history.
+
+        Returns:
+            The current migration version or None if no migrations applied.
         """
         if self._is_async:
-            await_(cast("AsyncMigrationCommands[Any]", self._impl).current, raise_sync_error=False)(
-
-
+            return await_(cast("AsyncMigrationCommands[Any]", self._impl).current, raise_sync_error=False)(
+                verbose=verbose
+            )
+        return cast("SyncMigrationCommands[Any]", self._impl).current(verbose=verbose)
 
     def upgrade(self, revision: str = "head") -> None:
         """Upgrade to a target revision.
sqlspec/migrations/loaders.py
CHANGED