sqlspec-0.17.1-py3-none-any.whl → sqlspec-0.19.0-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (77)
  1. sqlspec/__init__.py +1 -1
  2. sqlspec/_sql.py +54 -159
  3. sqlspec/adapters/adbc/config.py +24 -30
  4. sqlspec/adapters/adbc/driver.py +42 -61
  5. sqlspec/adapters/aiosqlite/config.py +5 -10
  6. sqlspec/adapters/aiosqlite/driver.py +9 -25
  7. sqlspec/adapters/aiosqlite/pool.py +43 -35
  8. sqlspec/adapters/asyncmy/config.py +10 -7
  9. sqlspec/adapters/asyncmy/driver.py +18 -39
  10. sqlspec/adapters/asyncpg/config.py +4 -0
  11. sqlspec/adapters/asyncpg/driver.py +32 -79
  12. sqlspec/adapters/bigquery/config.py +12 -65
  13. sqlspec/adapters/bigquery/driver.py +39 -133
  14. sqlspec/adapters/duckdb/config.py +11 -15
  15. sqlspec/adapters/duckdb/driver.py +61 -85
  16. sqlspec/adapters/duckdb/pool.py +2 -5
  17. sqlspec/adapters/oracledb/_types.py +8 -1
  18. sqlspec/adapters/oracledb/config.py +55 -38
  19. sqlspec/adapters/oracledb/driver.py +35 -92
  20. sqlspec/adapters/oracledb/migrations.py +257 -0
  21. sqlspec/adapters/psqlpy/config.py +13 -9
  22. sqlspec/adapters/psqlpy/driver.py +28 -103
  23. sqlspec/adapters/psycopg/config.py +9 -5
  24. sqlspec/adapters/psycopg/driver.py +107 -175
  25. sqlspec/adapters/sqlite/config.py +7 -5
  26. sqlspec/adapters/sqlite/driver.py +37 -73
  27. sqlspec/adapters/sqlite/pool.py +3 -12
  28. sqlspec/base.py +19 -22
  29. sqlspec/builder/__init__.py +1 -1
  30. sqlspec/builder/_base.py +34 -20
  31. sqlspec/builder/_ddl.py +407 -183
  32. sqlspec/builder/_insert.py +1 -1
  33. sqlspec/builder/mixins/_insert_operations.py +26 -6
  34. sqlspec/builder/mixins/_merge_operations.py +1 -1
  35. sqlspec/builder/mixins/_select_operations.py +1 -5
  36. sqlspec/cli.py +281 -33
  37. sqlspec/config.py +183 -14
  38. sqlspec/core/__init__.py +89 -14
  39. sqlspec/core/cache.py +57 -104
  40. sqlspec/core/compiler.py +57 -112
  41. sqlspec/core/filters.py +1 -21
  42. sqlspec/core/hashing.py +13 -47
  43. sqlspec/core/parameters.py +272 -261
  44. sqlspec/core/result.py +12 -27
  45. sqlspec/core/splitter.py +17 -21
  46. sqlspec/core/statement.py +150 -159
  47. sqlspec/driver/_async.py +2 -15
  48. sqlspec/driver/_common.py +16 -95
  49. sqlspec/driver/_sync.py +2 -15
  50. sqlspec/driver/mixins/_result_tools.py +8 -29
  51. sqlspec/driver/mixins/_sql_translator.py +6 -8
  52. sqlspec/exceptions.py +1 -2
  53. sqlspec/extensions/litestar/plugin.py +15 -8
  54. sqlspec/loader.py +43 -115
  55. sqlspec/migrations/__init__.py +1 -1
  56. sqlspec/migrations/base.py +34 -45
  57. sqlspec/migrations/commands.py +34 -15
  58. sqlspec/migrations/loaders.py +1 -1
  59. sqlspec/migrations/runner.py +104 -19
  60. sqlspec/migrations/tracker.py +49 -2
  61. sqlspec/protocols.py +3 -6
  62. sqlspec/storage/__init__.py +4 -4
  63. sqlspec/storage/backends/fsspec.py +5 -6
  64. sqlspec/storage/backends/obstore.py +7 -8
  65. sqlspec/storage/registry.py +3 -3
  66. sqlspec/utils/__init__.py +2 -2
  67. sqlspec/utils/logging.py +6 -10
  68. sqlspec/utils/sync_tools.py +27 -4
  69. sqlspec/utils/text.py +6 -1
  70. {sqlspec-0.17.1.dist-info → sqlspec-0.19.0.dist-info}/METADATA +1 -1
  71. sqlspec-0.19.0.dist-info/RECORD +138 -0
  72. sqlspec/builder/_ddl_utils.py +0 -103
  73. sqlspec-0.17.1.dist-info/RECORD +0 -138
  74. {sqlspec-0.17.1.dist-info → sqlspec-0.19.0.dist-info}/WHEEL +0 -0
  75. {sqlspec-0.17.1.dist-info → sqlspec-0.19.0.dist-info}/entry_points.txt +0 -0
  76. {sqlspec-0.17.1.dist-info → sqlspec-0.19.0.dist-info}/licenses/LICENSE +0 -0
  77. {sqlspec-0.17.1.dist-info → sqlspec-0.19.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/loader.py CHANGED
@@ -1,17 +1,15 @@
1
- """SQL file loader module for managing SQL statements from files.
1
+ """SQL file loader for managing SQL statements from files.
2
2
 
3
- This module provides functionality to load, cache, and manage SQL statements
3
+ Provides functionality to load, cache, and manage SQL statements
4
4
  from files using aiosql-style named queries.
5
5
  """
6
6
 
7
7
  import hashlib
8
8
  import re
9
9
  import time
10
- from dataclasses import dataclass, field
11
10
  from datetime import datetime, timezone
12
- from difflib import get_close_matches
13
11
  from pathlib import Path
14
- from typing import Any, Optional, Union
12
+ from typing import TYPE_CHECKING, Any, Final, Optional, Union
15
13
 
16
14
  from sqlspec.core.cache import CacheKey, get_cache_config, get_default_cache
17
15
  from sqlspec.core.statement import SQL
@@ -21,11 +19,13 @@ from sqlspec.exceptions import (
21
19
  SQLFileParseError,
22
20
  StorageOperationFailedError,
23
21
  )
24
- from sqlspec.storage import storage_registry
25
- from sqlspec.storage.registry import StorageRegistry
22
+ from sqlspec.storage.registry import storage_registry as default_storage_registry
26
23
  from sqlspec.utils.correlation import CorrelationContext
27
24
  from sqlspec.utils.logging import get_logger
28
25
 
26
+ if TYPE_CHECKING:
27
+ from sqlspec.storage.registry import StorageRegistry
28
+
29
29
  __all__ = ("CachedSQLFile", "NamedStatement", "SQLFile", "SQLFileLoader")
30
30
 
31
31
  logger = get_logger("loader")
@@ -38,48 +38,8 @@ TRIM_SPECIAL_CHARS = re.compile(r"[^\w.-]")
38
38
  # Matches: -- dialect: dialect_name (optional dialect specification)
39
39
  DIALECT_PATTERN = re.compile(r"^\s*--\s*dialect\s*:\s*(?P<dialect>[a-zA-Z0-9_]+)\s*$", re.IGNORECASE | re.MULTILINE)
40
40
 
41
- # Supported SQL dialects (based on SQLGlot's available dialects)
42
- SUPPORTED_DIALECTS = {
43
- # Core databases
44
- "sqlite",
45
- "postgresql",
46
- "postgres",
47
- "mysql",
48
- "oracle",
49
- "mssql",
50
- "tsql",
51
- # Cloud platforms
52
- "bigquery",
53
- "snowflake",
54
- "redshift",
55
- "athena",
56
- "fabric",
57
- # Analytics engines
58
- "clickhouse",
59
- "duckdb",
60
- "databricks",
61
- "spark",
62
- "spark2",
63
- "trino",
64
- "presto",
65
- # Specialized
66
- "hive",
67
- "drill",
68
- "druid",
69
- "materialize",
70
- "teradata",
71
- "dremio",
72
- "doris",
73
- "risingwave",
74
- "singlestore",
75
- "starrocks",
76
- "tableau",
77
- "exasol",
78
- "dune",
79
- }
80
41
 
81
- # Dialect aliases for common variants
82
- DIALECT_ALIASES = {
42
+ DIALECT_ALIASES: Final = {
83
43
  "postgresql": "postgres",
84
44
  "pg": "postgres",
85
45
  "pgplsql": "postgres",
@@ -88,7 +48,7 @@ DIALECT_ALIASES = {
88
48
  "tsql": "mssql",
89
49
  }
90
50
 
91
- MIN_QUERY_PARTS = 3
51
+ MIN_QUERY_PARTS: Final = 3
92
52
 
93
53
 
94
54
  def _normalize_query_name(name: str) -> str:
@@ -129,19 +89,6 @@ def _normalize_dialect_for_sqlglot(dialect: str) -> str:
129
89
  return DIALECT_ALIASES.get(normalized, normalized)
130
90
 
131
91
 
132
- def _get_dialect_suggestions(invalid_dialect: str) -> "list[str]":
133
- """Get dialect suggestions using fuzzy matching.
134
-
135
- Args:
136
- invalid_dialect: Invalid dialect name that was provided
137
-
138
- Returns:
139
- List of suggested dialect names (up to 3 suggestions)
140
- """
141
-
142
- return get_close_matches(invalid_dialect, SUPPORTED_DIALECTS, n=3, cutoff=0.6)
143
-
144
-
145
92
  class NamedStatement:
146
93
  """Represents a parsed SQL statement with metadata.
147
94
 
@@ -159,7 +106,6 @@ class NamedStatement:
159
106
  self.start_line = start_line
160
107
 
161
108
 
162
- @dataclass
163
109
  class SQLFile:
164
110
  """Represents a loaded SQL file with metadata.
165
111
 
@@ -167,28 +113,32 @@ class SQLFile:
167
113
  timestamps, and content hash.
168
114
  """
169
115
 
170
- content: str
171
- """The raw SQL content from the file."""
116
+ __slots__ = ("checksum", "content", "loaded_at", "metadata", "path")
172
117
 
173
- path: str
174
- """Path where the SQL file was loaded from."""
118
+ def __init__(
119
+ self,
120
+ content: str,
121
+ path: str,
122
+ metadata: "Optional[dict[str, Any]]" = None,
123
+ loaded_at: "Optional[datetime]" = None,
124
+ ) -> None:
125
+ """Initialize SQLFile.
175
126
 
176
- metadata: "dict[str, Any]" = field(default_factory=dict)
177
- """Optional metadata associated with the SQL file."""
178
-
179
- checksum: str = field(init=False)
180
- """MD5 checksum of the SQL content for cache invalidation."""
181
-
182
- loaded_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
183
- """Timestamp when the file was loaded."""
184
-
185
- def __post_init__(self) -> None:
186
- """Calculate checksum after initialization."""
127
+ Args:
128
+ content: The raw SQL content from the file.
129
+ path: Path where the SQL file was loaded from.
130
+ metadata: Optional metadata associated with the SQL file.
131
+ loaded_at: Timestamp when the file was loaded.
132
+ """
133
+ self.content = content
134
+ self.path = path
135
+ self.metadata = metadata or {}
136
+ self.loaded_at = loaded_at or datetime.now(timezone.utc)
187
137
  self.checksum = hashlib.md5(self.content.encode(), usedforsecurity=False).hexdigest()
188
138
 
189
139
 
190
140
  class CachedSQLFile:
191
- """Cached SQL file with parsed statements for efficient reloading.
141
+ """Cached SQL file with parsed statements.
192
142
 
193
143
  Stored in the file cache to avoid re-parsing SQL files when their
194
144
  content hasn't changed.
@@ -205,17 +155,19 @@ class CachedSQLFile:
205
155
  """
206
156
  self.sql_file = sql_file
207
157
  self.parsed_statements = parsed_statements
208
- self.statement_names = list(parsed_statements.keys())
158
+ self.statement_names = tuple(parsed_statements.keys())
209
159
 
210
160
 
211
161
  class SQLFileLoader:
212
162
  """Loads and parses SQL files with aiosql-style named queries.
213
163
 
214
- Provides functionality to load SQL files containing named queries
215
- (using -- name: syntax) and retrieve them by name.
164
+ Loads SQL files containing named queries (using -- name: syntax)
165
+ and retrieves them by name.
216
166
  """
217
167
 
218
- def __init__(self, *, encoding: str = "utf-8", storage_registry: StorageRegistry = storage_registry) -> None:
168
+ __slots__ = ("_files", "_queries", "_query_to_file", "encoding", "storage_registry")
169
+
170
+ def __init__(self, *, encoding: str = "utf-8", storage_registry: "Optional[StorageRegistry]" = None) -> None:
219
171
  """Initialize the SQL file loader.
220
172
 
221
173
  Args:
@@ -223,7 +175,8 @@ class SQLFileLoader:
223
175
  storage_registry: Storage registry for handling file URIs.
224
176
  """
225
177
  self.encoding = encoding
226
- self.storage_registry = storage_registry
178
+
179
+ self.storage_registry = storage_registry or default_storage_registry
227
180
  self._queries: dict[str, NamedStatement] = {}
228
181
  self._files: dict[str, SQLFile] = {}
229
182
  self._query_to_file: dict[str, str] = {}
@@ -309,7 +262,6 @@ class SQLFileLoader:
309
262
  except KeyError as e:
310
263
  raise SQLFileNotFoundError(path_str) from e
311
264
  except MissingDependencyError:
312
- # Fall back to standard file reading when no storage backend is available
313
265
  try:
314
266
  return path.read_text(encoding=self.encoding) # type: ignore[union-attr]
315
267
  except FileNotFoundError as e:
@@ -350,7 +302,6 @@ class SQLFileLoader:
350
302
  or invalid dialect names are specified
351
303
  """
352
304
  statements: dict[str, NamedStatement] = {}
353
- content.splitlines()
354
305
 
355
306
  name_matches = list(QUERY_NAME_PATTERN.finditer(content))
356
307
  if not name_matches:
@@ -379,20 +330,7 @@ class SQLFileLoader:
379
330
  if dialect_match:
380
331
  declared_dialect = dialect_match.group("dialect").lower()
381
332
 
382
- normalized_dialect = _normalize_dialect(declared_dialect)
383
-
384
- if normalized_dialect not in SUPPORTED_DIALECTS:
385
- suggestions = _get_dialect_suggestions(normalized_dialect)
386
- warning_msg = f"Unknown dialect '{declared_dialect}' at line {statement_start_line + 1}"
387
- if suggestions:
388
- warning_msg += f". Did you mean: {', '.join(suggestions)}?"
389
- warning_msg += (
390
- f". Supported dialects: {', '.join(sorted(SUPPORTED_DIALECTS))}. Using dialect as-is."
391
- )
392
- logger.warning(warning_msg)
393
- dialect = declared_dialect.lower()
394
- else:
395
- dialect = normalized_dialect
333
+ dialect = _normalize_dialect(declared_dialect)
396
334
  remaining_lines = section_lines[1:]
397
335
  statement_sql = "\n".join(remaining_lines)
398
336
 
@@ -473,7 +411,7 @@ class SQLFileLoader:
473
411
  raise
474
412
 
475
413
  def _load_directory(self, dir_path: Path) -> int:
476
- """Load all SQL files from a directory with namespacing."""
414
+ """Load all SQL files from a directory."""
477
415
  sql_files = list(dir_path.rglob("*.sql"))
478
416
  if not sql_files:
479
417
  return 0
@@ -486,7 +424,7 @@ class SQLFileLoader:
486
424
  return len(sql_files)
487
425
 
488
426
  def _load_single_file(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
489
- """Load a single SQL file with optional namespace and caching.
427
+ """Load a single SQL file with optional namespace.
490
428
 
491
429
  Args:
492
430
  file_path: Path to the SQL file.
@@ -543,7 +481,7 @@ class SQLFileLoader:
543
481
  unified_cache.put(cache_key, cached_file_data)
544
482
 
545
483
  def _load_file_without_cache(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
546
- """Load a single SQL file without caching.
484
+ """Load a single SQL file without using cache.
547
485
 
548
486
  Args:
549
487
  file_path: Path to the SQL file.
@@ -580,7 +518,7 @@ class SQLFileLoader:
580
518
  Raises:
581
519
  ValueError: If query name already exists.
582
520
  """
583
- # Normalize the name for consistency with file-loaded queries
521
+
584
522
  normalized_name = _normalize_query_name(name)
585
523
 
586
524
  if normalized_name in self._queries:
@@ -589,17 +527,7 @@ class SQLFileLoader:
589
527
  raise ValueError(msg)
590
528
 
591
529
  if dialect is not None:
592
- normalized_dialect = _normalize_dialect(dialect)
593
- if normalized_dialect not in SUPPORTED_DIALECTS:
594
- suggestions = _get_dialect_suggestions(normalized_dialect)
595
- warning_msg = f"Unknown dialect '{dialect}'"
596
- if suggestions:
597
- warning_msg += f". Did you mean: {', '.join(suggestions)}?"
598
- warning_msg += f". Supported dialects: {', '.join(sorted(SUPPORTED_DIALECTS))}. Using dialect as-is."
599
- logger.warning(warning_msg)
600
- dialect = dialect.lower()
601
- else:
602
- dialect = normalized_dialect
530
+ dialect = _normalize_dialect(dialect)
603
531
 
604
532
  statement = NamedStatement(name=normalized_name, sql=sql.strip(), dialect=dialect, start_line=0)
605
533
  self._queries[normalized_name] = statement
sqlspec/migrations/__init__.py CHANGED
@@ -1,7 +1,7 @@
1
1
  """SQLSpec Migration Tool.
2
2
 
3
3
  A native migration system for SQLSpec that leverages the SQLFileLoader
4
- and driver architecture for database versioning.
4
+ and driver system for database versioning.
5
5
  """
6
6
 
7
7
  from sqlspec.migrations.commands import AsyncMigrationCommands, MigrationCommands, SyncMigrationCommands
sqlspec/migrations/base.py CHANGED
@@ -3,18 +3,19 @@
3
3
  This module provides abstract base classes for migration components.
4
4
  """
5
5
 
6
+ import hashlib
6
7
  import operator
7
8
  from abc import ABC, abstractmethod
8
9
  from pathlib import Path
9
- from typing import Any, Generic, Optional, TypeVar
10
+ from typing import Any, Generic, Optional, TypeVar, cast
10
11
 
11
12
  from sqlspec._sql import sql
13
+ from sqlspec.builder import Delete, Insert, Select
12
14
  from sqlspec.builder._ddl import CreateTable
13
- from sqlspec.core.statement import SQL
14
15
  from sqlspec.loader import SQLFileLoader
15
16
  from sqlspec.migrations.loaders import get_migration_loader
16
17
  from sqlspec.utils.logging import get_logger
17
- from sqlspec.utils.sync_tools import run_
18
+ from sqlspec.utils.sync_tools import await_
18
19
 
19
20
  __all__ = ("BaseMigrationCommands", "BaseMigrationRunner", "BaseMigrationTracker")
20
21
 
@@ -28,6 +29,8 @@ ConfigT = TypeVar("ConfigT")
28
29
  class BaseMigrationTracker(ABC, Generic[DriverT]):
29
30
  """Base class for migration version tracking."""
30
31
 
32
+ __slots__ = ("version_table",)
33
+
31
34
  def __init__(self, version_table_name: str = "ddl_migrations") -> None:
32
35
  """Initialize the migration tracker.
33
36
 
@@ -36,54 +39,43 @@ class BaseMigrationTracker(ABC, Generic[DriverT]):
36
39
  """
37
40
  self.version_table = version_table_name
38
41
 
39
- def _get_create_table_sql(self) -> SQL:
40
- """Get SQL for creating the tracking table.
42
+ def _get_create_table_sql(self) -> CreateTable:
43
+ """Get SQL builder for creating the tracking table.
41
44
 
42
45
  Returns:
43
- SQL object for table creation.
46
+ SQL builder object for table creation.
44
47
  """
45
- builder = CreateTable(self.version_table)
46
- if not hasattr(builder, "_columns"):
47
- builder._columns = []
48
- if not hasattr(builder, "_constraints"):
49
- builder._constraints = []
50
- if not hasattr(builder, "_table_options"):
51
- builder._table_options = {}
52
-
53
48
  return (
54
- builder.if_not_exists()
49
+ sql.create_table(self.version_table)
50
+ .if_not_exists()
55
51
  .column("version_num", "VARCHAR(32)", primary_key=True)
56
52
  .column("description", "TEXT")
57
- .column("applied_at", "TIMESTAMP", not_null=True, default="CURRENT_TIMESTAMP")
53
+ .column("applied_at", "TIMESTAMP", default="CURRENT_TIMESTAMP", not_null=True)
58
54
  .column("execution_time_ms", "INTEGER")
59
55
  .column("checksum", "VARCHAR(64)")
60
56
  .column("applied_by", "VARCHAR(255)")
61
- ).to_statement()
57
+ )
62
58
 
63
- def _get_current_version_sql(self) -> SQL:
64
- """Get SQL for retrieving current version.
59
+ def _get_current_version_sql(self) -> Select:
60
+ """Get SQL builder for retrieving current version.
65
61
 
66
62
  Returns:
67
- SQL object for version query.
63
+ SQL builder object for version query.
68
64
  """
65
+ return sql.select("version_num").from_(self.version_table).order_by("version_num DESC").limit(1)
69
66
 
70
- return (
71
- sql.select("version_num").from_(self.version_table).order_by("version_num DESC").limit(1)
72
- ).to_statement()
73
-
74
- def _get_applied_migrations_sql(self) -> SQL:
75
- """Get SQL for retrieving all applied migrations.
67
+ def _get_applied_migrations_sql(self) -> Select:
68
+ """Get SQL builder for retrieving all applied migrations.
76
69
 
77
70
  Returns:
78
- SQL object for migrations query.
71
+ SQL builder object for migrations query.
79
72
  """
80
-
81
- return (sql.select("*").from_(self.version_table).order_by("version_num")).to_statement()
73
+ return sql.select("*").from_(self.version_table).order_by("version_num")
82
74
 
83
75
  def _get_record_migration_sql(
84
76
  self, version: str, description: str, execution_time_ms: int, checksum: str, applied_by: str
85
- ) -> SQL:
86
- """Get SQL for recording a migration.
77
+ ) -> Insert:
78
+ """Get SQL builder for recording a migration.
87
79
 
88
80
  Args:
89
81
  version: Version number of the migration.
@@ -93,26 +85,24 @@ class BaseMigrationTracker(ABC, Generic[DriverT]):
93
85
  applied_by: User who applied the migration.
94
86
 
95
87
  Returns:
96
- SQL object for insert.
88
+ SQL builder object for insert.
97
89
  """
98
-
99
90
  return (
100
91
  sql.insert(self.version_table)
101
92
  .columns("version_num", "description", "execution_time_ms", "checksum", "applied_by")
102
93
  .values(version, description, execution_time_ms, checksum, applied_by)
103
- ).to_statement()
94
+ )
104
95
 
105
- def _get_remove_migration_sql(self, version: str) -> SQL:
106
- """Get SQL for removing a migration record.
96
+ def _get_remove_migration_sql(self, version: str) -> Delete:
97
+ """Get SQL builder for removing a migration record.
107
98
 
108
99
  Args:
109
100
  version: Version number to remove.
110
101
 
111
102
  Returns:
112
- SQL object for delete.
103
+ SQL builder object for delete.
113
104
  """
114
-
115
- return (sql.delete().from_(self.version_table).where(sql.version_num == version)).to_statement()
105
+ return sql.delete().from_(self.version_table).where(sql.version_num == version)
116
106
 
117
107
  @abstractmethod
118
108
  def ensure_tracking_table(self, driver: DriverT) -> Any:
@@ -176,7 +166,6 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
176
166
  Returns:
177
167
  MD5 checksum hex string.
178
168
  """
179
- import hashlib
180
169
 
181
170
  return hashlib.md5(content.encode()).hexdigest() # noqa: S324
182
171
 
@@ -226,7 +215,7 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
226
215
  has_upgrade, has_downgrade = self.loader.has_query(up_query), self.loader.has_query(down_query)
227
216
  else:
228
217
  try:
229
- has_downgrade = bool(run_(loader.get_down_sql)(file_path))
218
+ has_downgrade = bool(await_(loader.get_down_sql, raise_sync_error=False)(file_path))
230
219
  except Exception:
231
220
  has_downgrade = False
232
221
 
@@ -240,7 +229,7 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
240
229
  "loader": loader,
241
230
  }
242
231
 
243
- def _get_migration_sql(self, migration: "dict[str, Any]", direction: str) -> Optional[SQL]:
232
+ def _get_migration_sql(self, migration: "dict[str, Any]", direction: str) -> "Optional[list[str]]":
244
233
  """Get migration SQL for given direction.
245
234
 
246
235
  Args:
@@ -261,7 +250,7 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
261
250
 
262
251
  try:
263
252
  method = loader.get_up_sql if direction == "up" else loader.get_down_sql
264
- sql_statements = run_(method)(file_path)
253
+ sql_statements = await_(method, raise_sync_error=False)(file_path)
265
254
 
266
255
  except Exception as e:
267
256
  if direction == "down":
@@ -271,7 +260,7 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
271
260
  raise ValueError(msg) from e
272
261
  else:
273
262
  if sql_statements:
274
- return SQL(sql_statements[0])
263
+ return cast("list[str]", sql_statements)
275
264
  return None
276
265
 
277
266
  @abstractmethod
@@ -312,7 +301,7 @@ class BaseMigrationCommands(ABC, Generic[ConfigT, DriverT]):
312
301
  self.config = config
313
302
  migration_config = getattr(self.config, "migration_config", {}) or {}
314
303
 
315
- self.version_table = migration_config.get("version_table_name", "sqlspec_migrations")
304
+ self.version_table = migration_config.get("version_table_name", "ddl_migrations")
316
305
  self.migrations_path = Path(migration_config.get("script_location", "migrations"))
317
306
  self.project_root = Path(migration_config["project_root"]) if "project_root" in migration_config else None
318
307
 
sqlspec/migrations/commands.py CHANGED
@@ -3,7 +3,7 @@
3
3
  This module provides the main command interface for database migrations.
4
4
  """
5
5
 
6
- from typing import TYPE_CHECKING, Any, Union, cast
6
+ from typing import TYPE_CHECKING, Any, Optional, Union, cast
7
7
 
8
8
  from rich.console import Console
9
9
  from rich.table import Table
@@ -11,7 +11,6 @@ from rich.table import Table
11
11
  from sqlspec._sql import sql
12
12
  from sqlspec.migrations.base import BaseMigrationCommands
13
13
  from sqlspec.migrations.runner import AsyncMigrationRunner, SyncMigrationRunner
14
- from sqlspec.migrations.tracker import AsyncMigrationTracker, SyncMigrationTracker
15
14
  from sqlspec.migrations.utils import create_migration_file
16
15
  from sqlspec.utils.logging import get_logger
17
16
  from sqlspec.utils.sync_tools import await_
@@ -26,7 +25,7 @@ console = Console()
26
25
 
27
26
 
28
27
  class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
29
- """SQLSpec native migration commands."""
28
+ """Synchronous migration commands."""
30
29
 
31
30
  def __init__(self, config: "SyncConfigT") -> None:
32
31
  """Initialize migration commands.
@@ -35,7 +34,7 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
35
34
  config: The SQLSpec configuration.
36
35
  """
37
36
  super().__init__(config)
38
- self.tracker = SyncMigrationTracker(self.version_table)
37
+ self.tracker = config.migration_tracker_type(self.version_table)
39
38
  self.runner = SyncMigrationRunner(self.migrations_path)
40
39
 
41
40
  def init(self, directory: str, package: bool = True) -> None:
@@ -47,11 +46,14 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
47
46
  """
48
47
  self.init_directory(directory, package)
49
48
 
50
- def current(self, verbose: bool = False) -> None:
49
+ def current(self, verbose: bool = False) -> "Optional[str]":
51
50
  """Show current migration version.
52
51
 
53
52
  Args:
54
53
  verbose: Whether to show detailed migration history.
54
+
55
+ Returns:
56
+ The current migration version or None if no migrations applied.
55
57
  """
56
58
  with self.config.provide_session() as driver:
57
59
  self.tracker.ensure_tracking_table(driver)
@@ -59,7 +61,7 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
59
61
  current = self.tracker.get_current_version(driver)
60
62
  if not current:
61
63
  console.print("[yellow]No migrations applied yet[/]")
62
- return
64
+ return None
63
65
 
64
66
  console.print(f"[green]Current version:[/] {current}")
65
67
 
@@ -84,6 +86,8 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
84
86
 
85
87
  console.print(table)
86
88
 
89
+ return cast("Optional[str]", current)
90
+
87
91
  def upgrade(self, revision: str = "head") -> None:
88
92
  """Upgrade to a target revision.
89
93
 
@@ -137,6 +141,8 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
137
141
  to_revert = []
138
142
  if revision == "-1":
139
143
  to_revert = [applied[-1]]
144
+ elif revision == "base":
145
+ to_revert = list(reversed(applied))
140
146
  else:
141
147
  for migration in reversed(applied):
142
148
  if migration["version_num"] > revision:
@@ -195,7 +201,7 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
195
201
 
196
202
 
197
203
  class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
198
- """SQLSpec native migration commands."""
204
+ """Asynchronous migration commands."""
199
205
 
200
206
  def __init__(self, sqlspec_config: "AsyncConfigT") -> None:
201
207
  """Initialize migration commands.
@@ -204,7 +210,7 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
204
210
  sqlspec_config: The SQLSpec configuration.
205
211
  """
206
212
  super().__init__(sqlspec_config)
207
- self.tracker = AsyncMigrationTracker(self.version_table)
213
+ self.tracker = sqlspec_config.migration_tracker_type(self.version_table)
208
214
  self.runner = AsyncMigrationRunner(self.migrations_path)
209
215
 
210
216
  async def init(self, directory: str, package: bool = True) -> None:
@@ -216,11 +222,14 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
216
222
  """
217
223
  self.init_directory(directory, package)
218
224
 
219
- async def current(self, verbose: bool = False) -> None:
225
+ async def current(self, verbose: bool = False) -> "Optional[str]":
220
226
  """Show current migration version.
221
227
 
222
228
  Args:
223
229
  verbose: Whether to show detailed migration history.
230
+
231
+ Returns:
232
+ The current migration version or None if no migrations applied.
224
233
  """
225
234
  async with self.config.provide_session() as driver:
226
235
  await self.tracker.ensure_tracking_table(driver)
@@ -228,7 +237,7 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
228
237
  current = await self.tracker.get_current_version(driver)
229
238
  if not current:
230
239
  console.print("[yellow]No migrations applied yet[/]")
231
- return
240
+ return None
232
241
 
233
242
  console.print(f"[green]Current version:[/] {current}")
234
243
  if verbose:
@@ -249,6 +258,8 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
249
258
  )
250
259
  console.print(table)
251
260
 
261
+ return cast("Optional[str]", current)
262
+
252
263
  async def upgrade(self, revision: str = "head") -> None:
253
264
  """Upgrade to a target revision.
254
265
 
@@ -297,6 +308,8 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
297
308
  to_revert = []
298
309
  if revision == "-1":
299
310
  to_revert = [applied[-1]]
311
+ elif revision == "base":
312
+ to_revert = list(reversed(applied))
300
313
  else:
301
314
  for migration in reversed(applied):
302
315
  if migration["version_num"] > revision:
@@ -382,20 +395,26 @@ class MigrationCommands:
382
395
  package: Whether to create __init__.py file.
383
396
  """
384
397
  if self._is_async:
385
- await_(cast("AsyncMigrationCommands[Any]", self._impl).init)(directory, package=package)
398
+ await_(cast("AsyncMigrationCommands[Any]", self._impl).init, raise_sync_error=False)(
399
+ directory, package=package
400
+ )
386
401
  else:
387
402
  cast("SyncMigrationCommands[Any]", self._impl).init(directory, package=package)
388
403
 
389
- def current(self, verbose: bool = False) -> None:
404
+ def current(self, verbose: bool = False) -> "Optional[str]":
390
405
  """Show current migration version.
391
406
 
392
407
  Args:
393
408
  verbose: Whether to show detailed migration history.
409
+
410
+ Returns:
411
+ The current migration version or None if no migrations applied.
394
412
  """
395
413
  if self._is_async:
396
- await_(cast("AsyncMigrationCommands[Any]", self._impl).current, raise_sync_error=False)(verbose=verbose)
397
- else:
398
- cast("SyncMigrationCommands[Any]", self._impl).current(verbose=verbose)
414
+ return await_(cast("AsyncMigrationCommands[Any]", self._impl).current, raise_sync_error=False)(
415
+ verbose=verbose
416
+ )
417
+ return cast("SyncMigrationCommands[Any]", self._impl).current(verbose=verbose)
399
418
 
400
419
  def upgrade(self, revision: str = "head") -> None:
401
420
  """Upgrade to a target revision.
sqlspec/migrations/loaders.py CHANGED
@@ -70,7 +70,7 @@ class BaseMigrationLoader(abc.ABC):
70
70
 
71
71
 
72
72
  class SQLFileLoader(BaseMigrationLoader):
73
- """Loader for SQL migration files using SQLFileLoader."""
73
+ """Loader for SQL migration files."""
74
74
 
75
75
  __slots__ = ("sql_loader",)
76
76