sqlspec 0.13.1__py3-none-any.whl → 0.14.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (112) hide show
  1. sqlspec/__init__.py +39 -1
  2. sqlspec/__main__.py +12 -0
  3. sqlspec/adapters/adbc/config.py +16 -40
  4. sqlspec/adapters/adbc/driver.py +43 -16
  5. sqlspec/adapters/adbc/transformers.py +108 -0
  6. sqlspec/adapters/aiosqlite/config.py +2 -20
  7. sqlspec/adapters/aiosqlite/driver.py +36 -18
  8. sqlspec/adapters/asyncmy/config.py +2 -33
  9. sqlspec/adapters/asyncmy/driver.py +23 -16
  10. sqlspec/adapters/asyncpg/config.py +5 -39
  11. sqlspec/adapters/asyncpg/driver.py +41 -18
  12. sqlspec/adapters/bigquery/config.py +2 -43
  13. sqlspec/adapters/bigquery/driver.py +26 -14
  14. sqlspec/adapters/duckdb/config.py +2 -49
  15. sqlspec/adapters/duckdb/driver.py +35 -16
  16. sqlspec/adapters/oracledb/config.py +4 -83
  17. sqlspec/adapters/oracledb/driver.py +54 -27
  18. sqlspec/adapters/psqlpy/config.py +2 -55
  19. sqlspec/adapters/psqlpy/driver.py +28 -8
  20. sqlspec/adapters/psycopg/config.py +4 -73
  21. sqlspec/adapters/psycopg/driver.py +69 -24
  22. sqlspec/adapters/sqlite/config.py +3 -21
  23. sqlspec/adapters/sqlite/driver.py +50 -26
  24. sqlspec/cli.py +248 -0
  25. sqlspec/config.py +18 -20
  26. sqlspec/driver/_async.py +28 -10
  27. sqlspec/driver/_common.py +5 -4
  28. sqlspec/driver/_sync.py +28 -10
  29. sqlspec/driver/mixins/__init__.py +6 -0
  30. sqlspec/driver/mixins/_cache.py +114 -0
  31. sqlspec/driver/mixins/_pipeline.py +0 -4
  32. sqlspec/{service/base.py → driver/mixins/_query_tools.py} +86 -421
  33. sqlspec/driver/mixins/_result_utils.py +0 -2
  34. sqlspec/driver/mixins/_sql_translator.py +0 -2
  35. sqlspec/driver/mixins/_storage.py +4 -18
  36. sqlspec/driver/mixins/_type_coercion.py +0 -2
  37. sqlspec/driver/parameters.py +4 -4
  38. sqlspec/extensions/aiosql/adapter.py +4 -4
  39. sqlspec/extensions/litestar/__init__.py +2 -1
  40. sqlspec/extensions/litestar/cli.py +48 -0
  41. sqlspec/extensions/litestar/plugin.py +3 -0
  42. sqlspec/loader.py +1 -1
  43. sqlspec/migrations/__init__.py +23 -0
  44. sqlspec/migrations/base.py +390 -0
  45. sqlspec/migrations/commands.py +525 -0
  46. sqlspec/migrations/runner.py +215 -0
  47. sqlspec/migrations/tracker.py +153 -0
  48. sqlspec/migrations/utils.py +89 -0
  49. sqlspec/protocols.py +37 -3
  50. sqlspec/statement/builder/__init__.py +8 -8
  51. sqlspec/statement/builder/{column.py → _column.py} +82 -52
  52. sqlspec/statement/builder/{ddl.py → _ddl.py} +5 -5
  53. sqlspec/statement/builder/_ddl_utils.py +1 -1
  54. sqlspec/statement/builder/{delete.py → _delete.py} +1 -1
  55. sqlspec/statement/builder/{insert.py → _insert.py} +1 -1
  56. sqlspec/statement/builder/{merge.py → _merge.py} +1 -1
  57. sqlspec/statement/builder/_parsing_utils.py +5 -3
  58. sqlspec/statement/builder/{select.py → _select.py} +59 -61
  59. sqlspec/statement/builder/{update.py → _update.py} +2 -2
  60. sqlspec/statement/builder/mixins/__init__.py +24 -30
  61. sqlspec/statement/builder/mixins/{_set_ops.py → _cte_and_set_ops.py} +86 -2
  62. sqlspec/statement/builder/mixins/{_delete_from.py → _delete_operations.py} +2 -0
  63. sqlspec/statement/builder/mixins/{_insert_values.py → _insert_operations.py} +70 -1
  64. sqlspec/statement/builder/mixins/{_merge_clauses.py → _merge_operations.py} +2 -0
  65. sqlspec/statement/builder/mixins/_order_limit_operations.py +123 -0
  66. sqlspec/statement/builder/mixins/{_pivot.py → _pivot_operations.py} +71 -2
  67. sqlspec/statement/builder/mixins/_select_operations.py +612 -0
  68. sqlspec/statement/builder/mixins/{_update_set.py → _update_operations.py} +73 -2
  69. sqlspec/statement/builder/mixins/_where_clause.py +536 -0
  70. sqlspec/statement/cache.py +50 -0
  71. sqlspec/statement/filters.py +37 -8
  72. sqlspec/statement/parameters.py +143 -54
  73. sqlspec/statement/pipelines/__init__.py +1 -1
  74. sqlspec/statement/pipelines/context.py +4 -10
  75. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +3 -3
  76. sqlspec/statement/pipelines/validators/_parameter_style.py +22 -22
  77. sqlspec/statement/pipelines/validators/_performance.py +1 -5
  78. sqlspec/statement/sql.py +246 -176
  79. sqlspec/utils/__init__.py +2 -1
  80. sqlspec/utils/statement_hashing.py +203 -0
  81. sqlspec/utils/type_guards.py +32 -0
  82. {sqlspec-0.13.1.dist-info → sqlspec-0.14.1.dist-info}/METADATA +1 -1
  83. sqlspec-0.14.1.dist-info/RECORD +145 -0
  84. sqlspec-0.14.1.dist-info/entry_points.txt +2 -0
  85. sqlspec/service/__init__.py +0 -4
  86. sqlspec/service/_util.py +0 -147
  87. sqlspec/service/pagination.py +0 -26
  88. sqlspec/statement/builder/mixins/_aggregate_functions.py +0 -250
  89. sqlspec/statement/builder/mixins/_case_builder.py +0 -91
  90. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -90
  91. sqlspec/statement/builder/mixins/_from.py +0 -63
  92. sqlspec/statement/builder/mixins/_group_by.py +0 -118
  93. sqlspec/statement/builder/mixins/_having.py +0 -35
  94. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -47
  95. sqlspec/statement/builder/mixins/_insert_into.py +0 -36
  96. sqlspec/statement/builder/mixins/_limit_offset.py +0 -53
  97. sqlspec/statement/builder/mixins/_order_by.py +0 -46
  98. sqlspec/statement/builder/mixins/_returning.py +0 -37
  99. sqlspec/statement/builder/mixins/_select_columns.py +0 -61
  100. sqlspec/statement/builder/mixins/_unpivot.py +0 -77
  101. sqlspec/statement/builder/mixins/_update_from.py +0 -55
  102. sqlspec/statement/builder/mixins/_update_table.py +0 -29
  103. sqlspec/statement/builder/mixins/_where.py +0 -401
  104. sqlspec/statement/builder/mixins/_window_functions.py +0 -86
  105. sqlspec/statement/parameter_manager.py +0 -220
  106. sqlspec/statement/sql_compiler.py +0 -140
  107. sqlspec-0.13.1.dist-info/RECORD +0 -150
  108. /sqlspec/statement/builder/{base.py → _base.py} +0 -0
  109. /sqlspec/statement/builder/mixins/{_join.py → _join_operations.py} +0 -0
  110. {sqlspec-0.13.1.dist-info → sqlspec-0.14.1.dist-info}/WHEEL +0 -0
  111. {sqlspec-0.13.1.dist-info → sqlspec-0.14.1.dist-info}/licenses/LICENSE +0 -0
  112. {sqlspec-0.13.1.dist-info → sqlspec-0.14.1.dist-info}/licenses/NOTICE +0 -0
@@ -52,8 +52,6 @@ def _default_msgspec_deserializer(
52
52
 
53
53
 
54
54
  class ToSchemaMixin:
55
- __slots__ = ()
56
-
57
55
  @staticmethod
58
56
  def _determine_operation_type(statement: "Any") -> OperationType:
59
57
  """Determine operation type from SQL statement expression.
@@ -10,8 +10,6 @@ __all__ = ("SQLTranslatorMixin",)
10
10
  class SQLTranslatorMixin:
11
11
  """Mixin for drivers supporting SQL translation."""
12
12
 
13
- __slots__ = ()
14
-
15
13
  def convert_to_dialect(self, statement: "Statement", to_dialect: DialectType = None, pretty: bool = True) -> str:
16
14
  parsed_expression: exp.Expression
17
15
  if statement is not None and isinstance(statement, SQL):
@@ -44,8 +44,6 @@ WINDOWS_PATH_MIN_LENGTH = 3
44
44
  class StorageMixinBase(ABC):
45
45
  """Base class with common storage functionality."""
46
46
 
47
- __slots__ = ()
48
-
49
47
  config: Any
50
48
  _connection: Any
51
49
  dialect: "DialectType"
@@ -144,8 +142,6 @@ class StorageMixinBase(ABC):
144
142
  class SyncStorageMixin(StorageMixinBase):
145
143
  """Unified storage operations for synchronous drivers."""
146
144
 
147
- __slots__ = ()
148
-
149
145
  def ingest_arrow_table(self, table: "ArrowTable", table_name: str, mode: str = "create", **options: Any) -> int:
150
146
  """Ingest an Arrow table into the database.
151
147
 
@@ -211,7 +207,7 @@ class SyncStorageMixin(StorageMixinBase):
211
207
  # disable parameter validation entirely to allow transformer-added parameters
212
208
  if params is None and _config and _config.enable_transformations:
213
209
  # Disable validation entirely for transformer-generated parameters
214
- _config = replace(_config, strict_mode=False, enable_validation=False)
210
+ _config = replace(_config, enable_validation=False)
215
211
 
216
212
  # Only pass params if it's not None to avoid adding None as a parameter
217
213
  if params is not None:
@@ -294,12 +290,8 @@ class SyncStorageMixin(StorageMixinBase):
294
290
  _config = self.config
295
291
  if _config and not _config.dialect:
296
292
  _config = replace(_config, dialect=self.dialect)
297
- if _config and _config.enable_transformations:
298
- _config = replace(_config, enable_transformations=False)
299
293
 
300
- sql = (
301
- SQL(statement, parameters=params, config=_config) if params is not None else SQL(statement, config=_config)
302
- )
294
+ sql = SQL(statement, *params, config=_config) if params else SQL(statement, config=_config)
303
295
  for filter_ in filters:
304
296
  sql = sql.filter(filter_)
305
297
 
@@ -576,8 +568,6 @@ class SyncStorageMixin(StorageMixinBase):
576
568
  class AsyncStorageMixin(StorageMixinBase):
577
569
  """Unified storage operations for asynchronous drivers."""
578
570
 
579
- __slots__ = ()
580
-
581
571
  async def ingest_arrow_table(
582
572
  self, table: "ArrowTable", table_name: str, mode: str = "create", **options: Any
583
573
  ) -> int:
@@ -650,7 +640,7 @@ class AsyncStorageMixin(StorageMixinBase):
650
640
  # disable parameter validation entirely to allow transformer-added parameters
651
641
  if params is None and _config and _config.enable_transformations:
652
642
  # Disable validation entirely for transformer-generated parameters
653
- _config = replace(_config, strict_mode=False, enable_validation=False)
643
+ _config = replace(_config, enable_validation=False)
654
644
 
655
645
  # Only pass params if it's not None to avoid adding None as a parameter
656
646
  if params is not None:
@@ -703,12 +693,8 @@ class AsyncStorageMixin(StorageMixinBase):
703
693
  _config = self.config
704
694
  if _config and not _config.dialect:
705
695
  _config = replace(_config, dialect=self.dialect)
706
- if _config and _config.enable_transformations:
707
- _config = replace(_config, enable_transformations=False)
708
696
 
709
- sql = (
710
- SQL(statement, parameters=params, config=_config) if params is not None else SQL(statement, config=_config)
711
- )
697
+ sql = SQL(statement, *params, config=_config) if params else SQL(statement, config=_config)
712
698
  for filter_ in filters:
713
699
  sql = sql.filter(filter_)
714
700
 
@@ -22,8 +22,6 @@ class TypeCoercionMixin:
22
22
  and convert values to database-specific types.
23
23
  """
24
24
 
25
- __slots__ = ()
26
-
27
25
  def _process_parameters(self, parameters: "SQLParameterType") -> "SQLParameterType":
28
26
  """Process parameters, extracting values from TypedParameter objects.
29
27
 
@@ -13,8 +13,8 @@ if TYPE_CHECKING:
13
13
  from sqlspec.typing import StatementParameters
14
14
 
15
15
  __all__ = (
16
+ "convert_parameter_sequence",
16
17
  "convert_parameters_to_positional",
17
- "normalize_parameter_sequence",
18
18
  "process_execute_many_parameters",
19
19
  "separate_filters_and_parameters",
20
20
  "should_use_transaction",
@@ -62,19 +62,19 @@ def process_execute_many_parameters(
62
62
  param_sequence = param_values[0] if param_values else None
63
63
 
64
64
  # Normalize the parameter sequence
65
- param_sequence = normalize_parameter_sequence(param_sequence)
65
+ param_sequence = convert_parameter_sequence(param_sequence)
66
66
 
67
67
  return filters, param_sequence
68
68
 
69
69
 
70
- def normalize_parameter_sequence(params: Any) -> Optional[list[Any]]:
70
+ def convert_parameter_sequence(params: Any) -> Optional[list[Any]]:
71
71
  """Normalize a parameter sequence to a list format.
72
72
 
73
73
  Args:
74
74
  params: Parameter sequence in various formats
75
75
 
76
76
  Returns:
77
- Normalized list of parameters or None
77
+ converted list of parameters or None
78
78
  """
79
79
  if params is None:
80
80
  return None
@@ -38,7 +38,7 @@ def _normalize_dialect(dialect: "Union[str, Any, None]") -> str:
38
38
  dialect: Original dialect name (can be str, Dialect, type[Dialect], or None)
39
39
 
40
40
  Returns:
41
- Normalized dialect name
41
+ converted dialect name
42
42
  """
43
43
  if dialect is None:
44
44
  return "sql"
@@ -84,9 +84,9 @@ class _AiosqlAdapterBase:
84
84
 
85
85
  def _create_sql_object(self, sql: str, parameters: "Any" = None) -> SQL:
86
86
  """Create SQL object with proper configuration."""
87
- config = SQLConfig(strict_mode=False, enable_validation=False)
88
- normalized_dialect = _normalize_dialect(self.driver.dialect)
89
- return SQL(sql, parameters, config=config, dialect=normalized_dialect)
87
+ config = SQLConfig(enable_validation=False)
88
+ converted_dialect = _normalize_dialect(self.driver.dialect)
89
+ return SQL(sql, parameters, config=config, dialect=converted_dialect)
90
90
 
91
91
 
92
92
  class AiosqlSyncAdapter(_AiosqlAdapterBase):
@@ -1,5 +1,6 @@
1
1
  from sqlspec.extensions.litestar import handlers, providers
2
+ from sqlspec.extensions.litestar.cli import database_group
2
3
  from sqlspec.extensions.litestar.config import DatabaseConfig
3
4
  from sqlspec.extensions.litestar.plugin import SQLSpec
4
5
 
5
- __all__ = ("DatabaseConfig", "SQLSpec", "handlers", "providers")
6
+ __all__ = ("DatabaseConfig", "SQLSpec", "database_group", "handlers", "providers")
@@ -0,0 +1,48 @@
1
+ """Litestar CLI integration for SQLSpec migrations."""
2
+
3
+ from contextlib import suppress
4
+ from typing import TYPE_CHECKING
5
+
6
+ from litestar.cli._utils import LitestarGroup
7
+
8
+ from sqlspec.cli import add_migration_commands
9
+
10
+ try:
11
+ import rich_click as click
12
+ except ImportError:
13
+ import click # type: ignore[no-redef]
14
+
15
+ if TYPE_CHECKING:
16
+ from litestar import Litestar
17
+
18
+ from sqlspec.extensions.litestar.plugin import SQLSpec
19
+
20
+
21
+ def get_database_migration_plugin(app: "Litestar") -> "SQLSpec":
22
+ """Retrieve the SQLSpec plugin from the Litestar application's plugins.
23
+
24
+ Args:
25
+ app: The Litestar application
26
+
27
+ Returns:
28
+ The SQLSpec plugin
29
+
30
+ Raises:
31
+ ImproperConfigurationError: If the SQLSpec plugin is not found
32
+ """
33
+ from sqlspec.exceptions import ImproperConfigurationError
34
+ from sqlspec.extensions.litestar.plugin import SQLSpec
35
+
36
+ with suppress(KeyError):
37
+ return app.plugins.get(SQLSpec)
38
+ msg = "Failed to initialize database migrations. The required SQLSpec plugin is missing."
39
+ raise ImproperConfigurationError(msg)
40
+
41
+
42
+ @click.group(cls=LitestarGroup, name="database")
43
+ def database_group(ctx: "click.Context") -> None:
44
+ """Manage SQLSpec database components."""
45
+ ctx.obj = {"app": ctx.obj, "configs": get_database_migration_plugin(ctx.obj.app).config}
46
+
47
+
48
+ add_migration_commands(database_group)
@@ -47,6 +47,9 @@ class SQLSpec(InitPluginProtocol, SQLSpecBase):
47
47
 
48
48
  def on_cli_init(self, cli: "Group") -> None:
49
49
  """Configure the CLI for use with SQLSpec."""
50
+ from sqlspec.extensions.litestar.cli import database_group
51
+
52
+ cli.add_command(database_group)
50
53
 
51
54
  def on_app_init(self, app_config: "AppConfig") -> "AppConfig":
52
55
  """Configure application for use with SQLSpec.
sqlspec/loader.py CHANGED
@@ -40,7 +40,7 @@ def _normalize_query_name(name: str) -> str:
40
40
  name: Raw query name from SQL file
41
41
 
42
42
  Returns:
43
- Normalized query name suitable as Python identifier
43
+ converted query name suitable as Python identifier
44
44
  """
45
45
  # Strip trailing non-alphanumeric characters (excluding underscore) and replace hyphens
46
46
  return TRIM_TRAILING_SPECIAL_CHARS.sub("", name).replace("-", "_")
@@ -0,0 +1,23 @@
1
+ """SQLSpec Migration Tool.
2
+
3
+ A native migration system for SQLSpec that leverages the SQLFileLoader
4
+ and driver architecture for database versioning.
5
+ """
6
+
7
+ from sqlspec.migrations.commands import AsyncMigrationCommands, MigrationCommands, SyncMigrationCommands
8
+ from sqlspec.migrations.runner import AsyncMigrationRunner, SyncMigrationRunner
9
+ from sqlspec.migrations.tracker import AsyncMigrationTracker, SyncMigrationTracker
10
+ from sqlspec.migrations.utils import create_migration_file, drop_all, get_author
11
+
12
+ __all__ = (
13
+ "AsyncMigrationCommands",
14
+ "AsyncMigrationRunner",
15
+ "AsyncMigrationTracker",
16
+ "MigrationCommands",
17
+ "SyncMigrationCommands",
18
+ "SyncMigrationRunner",
19
+ "SyncMigrationTracker",
20
+ "create_migration_file",
21
+ "drop_all",
22
+ "get_author",
23
+ )
@@ -0,0 +1,390 @@
1
+ """Base classes for SQLSpec migrations.
2
+
3
+ This module provides abstract base classes for migration components.
4
+ """
5
+
6
+ from abc import ABC, abstractmethod
7
+ from pathlib import Path
8
+ from typing import Any, Generic, Optional, TypeVar
9
+
10
+ from sqlspec.loader import SQLFileLoader
11
+ from sqlspec.statement.sql import SQL
12
+ from sqlspec.utils.logging import get_logger
13
+
14
+ __all__ = ("BaseMigrationCommands", "BaseMigrationRunner", "BaseMigrationTracker")
15
+
16
+
17
+ logger = get_logger("migrations.base")
18
+
19
+ # Type variables for generic driver and config types
20
+ DriverT = TypeVar("DriverT")
21
+ ConfigT = TypeVar("ConfigT")
22
+
23
+
24
+ class BaseMigrationTracker(ABC, Generic[DriverT]):
25
+ """Base class for migration version tracking."""
26
+
27
+ def __init__(self, version_table_name: str = "ddl_migrations") -> None:
28
+ """Initialize the migration tracker.
29
+
30
+ Args:
31
+ version_table_name: Name of the table to track migrations.
32
+ """
33
+ self.version_table = version_table_name
34
+
35
+ def _get_create_table_sql(self) -> SQL:
36
+ """Get SQL for creating the tracking table.
37
+
38
+ Returns:
39
+ SQL object for table creation.
40
+ """
41
+ return SQL(
42
+ f"""
43
+ CREATE TABLE IF NOT EXISTS {self.version_table} (
44
+ version_num VARCHAR(32) PRIMARY KEY,
45
+ description TEXT,
46
+ applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
47
+ execution_time_ms INTEGER,
48
+ checksum VARCHAR(64),
49
+ applied_by VARCHAR(255)
50
+ )
51
+ """
52
+ )
53
+
54
+ def _get_current_version_sql(self) -> SQL:
55
+ """Get SQL for retrieving current version.
56
+
57
+ Returns:
58
+ SQL object for version query.
59
+ """
60
+ return SQL(f"SELECT version_num FROM {self.version_table} ORDER BY version_num DESC LIMIT 1")
61
+
62
+ def _get_applied_migrations_sql(self) -> SQL:
63
+ """Get SQL for retrieving all applied migrations.
64
+
65
+ Returns:
66
+ SQL object for migrations query.
67
+ """
68
+ return SQL(f"SELECT * FROM {self.version_table} ORDER BY version_num")
69
+
70
+ def _get_record_migration_sql(
71
+ self, version: str, description: str, execution_time_ms: int, checksum: str, applied_by: str
72
+ ) -> SQL:
73
+ """Get SQL for recording a migration.
74
+
75
+ Args:
76
+ version: Version number of the migration.
77
+ description: Description of the migration.
78
+ execution_time_ms: Execution time in milliseconds.
79
+ checksum: MD5 checksum of the migration content.
80
+ applied_by: User who applied the migration.
81
+
82
+ Returns:
83
+ SQL object for insert.
84
+ """
85
+ return SQL(
86
+ f"INSERT INTO {self.version_table} (version_num, description, execution_time_ms, checksum, applied_by) VALUES (?, ?, ?, ?, ?)",
87
+ version,
88
+ description,
89
+ execution_time_ms,
90
+ checksum,
91
+ applied_by,
92
+ )
93
+
94
+ def _get_remove_migration_sql(self, version: str) -> SQL:
95
+ """Get SQL for removing a migration record.
96
+
97
+ Args:
98
+ version: Version number to remove.
99
+
100
+ Returns:
101
+ SQL object for delete.
102
+ """
103
+ return SQL(f"DELETE FROM {self.version_table} WHERE version_num = ?", version)
104
+
105
+ @abstractmethod
106
+ def ensure_tracking_table(self, driver: DriverT) -> Any:
107
+ """Create the migration tracking table if it doesn't exist."""
108
+ ...
109
+
110
+ @abstractmethod
111
+ def get_current_version(self, driver: DriverT) -> Any:
112
+ """Get the latest applied migration version."""
113
+ ...
114
+
115
+ @abstractmethod
116
+ def get_applied_migrations(self, driver: DriverT) -> Any:
117
+ """Get all applied migrations in order."""
118
+ ...
119
+
120
+ @abstractmethod
121
+ def record_migration(
122
+ self, driver: DriverT, version: str, description: str, execution_time_ms: int, checksum: str
123
+ ) -> Any:
124
+ """Record a successfully applied migration."""
125
+ ...
126
+
127
+ @abstractmethod
128
+ def remove_migration(self, driver: DriverT, version: str) -> Any:
129
+ """Remove a migration record."""
130
+ ...
131
+
132
+
133
+ class BaseMigrationRunner(ABC, Generic[DriverT]):
134
+ """Base class for migration execution."""
135
+
136
+ def __init__(self, migrations_path: Path) -> None:
137
+ """Initialize the migration runner.
138
+
139
+ Args:
140
+ migrations_path: Path to the directory containing migration files.
141
+ """
142
+ self.migrations_path = migrations_path
143
+ self.loader = SQLFileLoader()
144
+
145
+ def _extract_version(self, filename: str) -> Optional[str]:
146
+ """Extract version from filename (e.g., '0001_initial.sql' -> '0001').
147
+
148
+ Args:
149
+ filename: The migration filename.
150
+
151
+ Returns:
152
+ The extracted version string or None.
153
+ """
154
+ parts = filename.split("_", 1)
155
+ if parts and parts[0].isdigit():
156
+ return parts[0].zfill(4)
157
+ return None
158
+
159
+ def _calculate_checksum(self, content: str) -> str:
160
+ """Calculate MD5 checksum of migration content.
161
+
162
+ Args:
163
+ content: The migration file content.
164
+
165
+ Returns:
166
+ The MD5 checksum hex string.
167
+ """
168
+ import hashlib
169
+
170
+ return hashlib.md5(content.encode()).hexdigest() # noqa: S324
171
+
172
+ def _get_migration_files_sync(self) -> "list[tuple[str, Path]]":
173
+ """Get all migration files sorted by version (sync version).
174
+
175
+ Returns:
176
+ List of tuples containing (version, file_path).
177
+ """
178
+ if not self.migrations_path.exists():
179
+ return []
180
+
181
+ migrations = []
182
+ for file_path in self.migrations_path.glob("*.sql"):
183
+ if file_path.name.startswith("."):
184
+ continue
185
+ version = self._extract_version(file_path.name)
186
+ if version:
187
+ migrations.append((version, file_path))
188
+
189
+ return sorted(migrations, key=lambda x: x[0])
190
+
191
+ def _load_migration_metadata(self, file_path: Path) -> "dict[str, Any]":
192
+ """Load migration metadata from file.
193
+
194
+ Args:
195
+ file_path: Path to the migration file.
196
+
197
+ Returns:
198
+ Dictionary containing migration metadata.
199
+ """
200
+ self.loader.clear_cache()
201
+ self.loader.load_sql(file_path)
202
+
203
+ # Read raw content for checksum
204
+ content = file_path.read_text()
205
+ checksum = self._calculate_checksum(content)
206
+
207
+ # Extract metadata
208
+ version = self._extract_version(file_path.name)
209
+ description = file_path.stem.split("_", 1)[1] if "_" in file_path.stem else ""
210
+
211
+ # Query names use versioned pattern
212
+ up_query = f"migrate-{version}-up"
213
+ down_query = f"migrate-{version}-down"
214
+
215
+ return {
216
+ "version": version,
217
+ "description": description,
218
+ "file_path": file_path,
219
+ "checksum": checksum,
220
+ "up_query": up_query,
221
+ "down_query": down_query,
222
+ "has_upgrade": self.loader.has_query(up_query),
223
+ "has_downgrade": self.loader.has_query(down_query),
224
+ }
225
+
226
+ def _get_migration_sql(self, migration: "dict[str, Any]", direction: str) -> Optional[SQL]:
227
+ """Get migration SQL for given direction.
228
+
229
+ Args:
230
+ migration: Migration metadata.
231
+ direction: Either 'up' or 'down'.
232
+
233
+ Returns:
234
+ SQL object for the migration.
235
+ """
236
+ query_key = f"{direction}_query"
237
+ has_key = f"has_{direction}grade"
238
+
239
+ if not migration.get(has_key):
240
+ if direction == "down":
241
+ logger.warning("Migration %s has no downgrade query", migration["version"])
242
+ return None
243
+ msg = f"Migration {migration['version']} has no upgrade query"
244
+ raise ValueError(msg)
245
+
246
+ return self.loader.get_sql(migration[query_key])
247
+
248
+ @abstractmethod
249
+ def get_migration_files(self) -> Any:
250
+ """Get all migration files sorted by version."""
251
+ ...
252
+
253
+ @abstractmethod
254
+ def load_migration(self, file_path: Path) -> Any:
255
+ """Load a migration file and extract its components."""
256
+ ...
257
+
258
+ @abstractmethod
259
+ def execute_upgrade(self, driver: DriverT, migration: "dict[str, Any]") -> Any:
260
+ """Execute an upgrade migration."""
261
+ ...
262
+
263
+ @abstractmethod
264
+ def execute_downgrade(self, driver: DriverT, migration: "dict[str, Any]") -> Any:
265
+ """Execute a downgrade migration."""
266
+ ...
267
+
268
+ @abstractmethod
269
+ def load_all_migrations(self) -> Any:
270
+ """Load all migrations into a single namespace for bulk operations."""
271
+ ...
272
+
273
+
274
+ class BaseMigrationCommands(ABC, Generic[ConfigT, DriverT]):
275
+ """Base class for migration commands."""
276
+
277
+ def __init__(self, config: ConfigT) -> None:
278
+ """Initialize migration commands.
279
+
280
+ Args:
281
+ config: The SQLSpec configuration.
282
+ """
283
+ self.config = config
284
+
285
+ # Get migration settings from config
286
+ migration_config = getattr(self.config, "migration_config", {})
287
+ if migration_config is None:
288
+ migration_config = {}
289
+
290
+ self.version_table = migration_config.get("version_table_name", "sqlspec_migrations")
291
+ self.migrations_path = Path(migration_config.get("script_location", "migrations"))
292
+
293
+ def _get_init_readme_content(self) -> str:
294
+ """Get the README content for migration directory initialization.
295
+
296
+ Returns:
297
+ The README markdown content.
298
+ """
299
+ return """# SQLSpec Migrations
300
+
301
+ This directory contains database migration files.
302
+
303
+ ## File Format
304
+
305
+ Migration files use SQLFileLoader's named query syntax with versioned names:
306
+
307
+ ```sql
308
+ -- name: migrate-0001-up
309
+ CREATE TABLE example (
310
+ id INTEGER PRIMARY KEY,
311
+ name TEXT NOT NULL
312
+ );
313
+
314
+ -- name: migrate-0001-down
315
+ DROP TABLE example;
316
+ ```
317
+
318
+ ## Naming Conventions
319
+
320
+ ### File Names
321
+
322
+ Format: `{version}_{description}.sql`
323
+
324
+ - Version: Zero-padded 4-digit number (0001, 0002, etc.)
325
+ - Description: Brief description using underscores
326
+ - Example: `0001_create_users_table.sql`
327
+
328
+ ### Query Names
329
+
330
+ - Upgrade: `migrate-{version}-up`
331
+ - Downgrade: `migrate-{version}-down`
332
+
333
+ This naming ensures proper sorting and avoids conflicts when loading multiple files.
334
+ """
335
+
336
+ def init_directory(self, directory: str, package: bool = True) -> None:
337
+ """Initialize migration directory structure (sync implementation).
338
+
339
+ Args:
340
+ directory: Directory to initialize migrations in.
341
+ package: Whether to create __init__.py file.
342
+ """
343
+ from rich.console import Console
344
+
345
+ console = Console()
346
+
347
+ migrations_dir = Path(directory)
348
+ migrations_dir.mkdir(parents=True, exist_ok=True)
349
+
350
+ if package:
351
+ (migrations_dir / "__init__.py").touch()
352
+
353
+ # Create README
354
+ readme = migrations_dir / "README.md"
355
+ readme.write_text(self._get_init_readme_content())
356
+
357
+ # Create .gitkeep for empty directory
358
+ (migrations_dir / ".gitkeep").touch()
359
+
360
+ console.print(f"[green]Initialized migrations in {directory}[/]")
361
+
362
+ @abstractmethod
363
+ def init(self, directory: str, package: bool = True) -> Any:
364
+ """Initialize migration directory structure."""
365
+ ...
366
+
367
+ @abstractmethod
368
+ def current(self, verbose: bool = False) -> Any:
369
+ """Show current migration version."""
370
+ ...
371
+
372
+ @abstractmethod
373
+ def upgrade(self, revision: str = "head") -> Any:
374
+ """Upgrade to a target revision."""
375
+ ...
376
+
377
+ @abstractmethod
378
+ def downgrade(self, revision: str = "-1") -> Any:
379
+ """Downgrade to a target revision."""
380
+ ...
381
+
382
+ @abstractmethod
383
+ def stamp(self, revision: str) -> Any:
384
+ """Mark database as being at a specific revision without running migrations."""
385
+ ...
386
+
387
+ @abstractmethod
388
+ def revision(self, message: str) -> Any:
389
+ """Create a new migration file."""
390
+ ...