sqlspec 0.17.0__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (80)
  1. sqlspec/__init__.py +1 -1
  2. sqlspec/_sql.py +188 -234
  3. sqlspec/adapters/adbc/config.py +24 -30
  4. sqlspec/adapters/adbc/driver.py +42 -61
  5. sqlspec/adapters/aiosqlite/config.py +5 -10
  6. sqlspec/adapters/aiosqlite/driver.py +9 -25
  7. sqlspec/adapters/aiosqlite/pool.py +43 -35
  8. sqlspec/adapters/asyncmy/config.py +10 -7
  9. sqlspec/adapters/asyncmy/driver.py +18 -39
  10. sqlspec/adapters/asyncpg/config.py +4 -0
  11. sqlspec/adapters/asyncpg/driver.py +32 -79
  12. sqlspec/adapters/bigquery/config.py +12 -65
  13. sqlspec/adapters/bigquery/driver.py +39 -133
  14. sqlspec/adapters/duckdb/config.py +11 -15
  15. sqlspec/adapters/duckdb/driver.py +61 -85
  16. sqlspec/adapters/duckdb/pool.py +2 -5
  17. sqlspec/adapters/oracledb/_types.py +8 -1
  18. sqlspec/adapters/oracledb/config.py +55 -38
  19. sqlspec/adapters/oracledb/driver.py +35 -92
  20. sqlspec/adapters/oracledb/migrations.py +257 -0
  21. sqlspec/adapters/psqlpy/config.py +13 -9
  22. sqlspec/adapters/psqlpy/driver.py +28 -103
  23. sqlspec/adapters/psycopg/config.py +9 -5
  24. sqlspec/adapters/psycopg/driver.py +107 -175
  25. sqlspec/adapters/sqlite/config.py +7 -5
  26. sqlspec/adapters/sqlite/driver.py +37 -73
  27. sqlspec/adapters/sqlite/pool.py +3 -12
  28. sqlspec/base.py +1 -8
  29. sqlspec/builder/__init__.py +1 -1
  30. sqlspec/builder/_base.py +34 -20
  31. sqlspec/builder/_column.py +5 -1
  32. sqlspec/builder/_ddl.py +407 -183
  33. sqlspec/builder/_expression_wrappers.py +46 -0
  34. sqlspec/builder/_insert.py +2 -4
  35. sqlspec/builder/_update.py +5 -5
  36. sqlspec/builder/mixins/_insert_operations.py +26 -6
  37. sqlspec/builder/mixins/_merge_operations.py +1 -1
  38. sqlspec/builder/mixins/_order_limit_operations.py +16 -4
  39. sqlspec/builder/mixins/_select_operations.py +3 -7
  40. sqlspec/builder/mixins/_update_operations.py +4 -4
  41. sqlspec/config.py +32 -13
  42. sqlspec/core/__init__.py +89 -14
  43. sqlspec/core/cache.py +57 -104
  44. sqlspec/core/compiler.py +57 -112
  45. sqlspec/core/filters.py +1 -21
  46. sqlspec/core/hashing.py +13 -47
  47. sqlspec/core/parameters.py +272 -261
  48. sqlspec/core/result.py +12 -27
  49. sqlspec/core/splitter.py +17 -21
  50. sqlspec/core/statement.py +150 -159
  51. sqlspec/driver/_async.py +2 -15
  52. sqlspec/driver/_common.py +16 -95
  53. sqlspec/driver/_sync.py +2 -15
  54. sqlspec/driver/mixins/_result_tools.py +8 -29
  55. sqlspec/driver/mixins/_sql_translator.py +6 -8
  56. sqlspec/exceptions.py +1 -2
  57. sqlspec/loader.py +43 -115
  58. sqlspec/migrations/__init__.py +1 -1
  59. sqlspec/migrations/base.py +34 -45
  60. sqlspec/migrations/commands.py +34 -15
  61. sqlspec/migrations/loaders.py +1 -1
  62. sqlspec/migrations/runner.py +104 -19
  63. sqlspec/migrations/tracker.py +49 -2
  64. sqlspec/protocols.py +13 -6
  65. sqlspec/storage/__init__.py +4 -4
  66. sqlspec/storage/backends/fsspec.py +5 -6
  67. sqlspec/storage/backends/obstore.py +7 -8
  68. sqlspec/storage/registry.py +3 -3
  69. sqlspec/utils/__init__.py +2 -2
  70. sqlspec/utils/logging.py +6 -10
  71. sqlspec/utils/sync_tools.py +27 -4
  72. sqlspec/utils/text.py +6 -1
  73. {sqlspec-0.17.0.dist-info → sqlspec-0.18.0.dist-info}/METADATA +1 -1
  74. sqlspec-0.18.0.dist-info/RECORD +138 -0
  75. sqlspec/builder/_ddl_utils.py +0 -103
  76. sqlspec-0.17.0.dist-info/RECORD +0 -137
  77. {sqlspec-0.17.0.dist-info → sqlspec-0.18.0.dist-info}/WHEEL +0 -0
  78. {sqlspec-0.17.0.dist-info → sqlspec-0.18.0.dist-info}/entry_points.txt +0 -0
  79. {sqlspec-0.17.0.dist-info → sqlspec-0.18.0.dist-info}/licenses/LICENSE +0 -0
  80. {sqlspec-0.17.0.dist-info → sqlspec-0.18.0.dist-info}/licenses/NOTICE +0 -0
@@ -1,7 +1,7 @@
1
1
  """SQLSpec Migration Tool.
2
2
 
3
3
  A native migration system for SQLSpec that leverages the SQLFileLoader
4
- and driver architecture for database versioning.
4
+ and driver system for database versioning.
5
5
  """
6
6
 
7
7
  from sqlspec.migrations.commands import AsyncMigrationCommands, MigrationCommands, SyncMigrationCommands
@@ -3,18 +3,19 @@
3
3
  This module provides abstract base classes for migration components.
4
4
  """
5
5
 
6
+ import hashlib
6
7
  import operator
7
8
  from abc import ABC, abstractmethod
8
9
  from pathlib import Path
9
- from typing import Any, Generic, Optional, TypeVar
10
+ from typing import Any, Generic, Optional, TypeVar, cast
10
11
 
11
12
  from sqlspec._sql import sql
13
+ from sqlspec.builder import Delete, Insert, Select
12
14
  from sqlspec.builder._ddl import CreateTable
13
- from sqlspec.core.statement import SQL
14
15
  from sqlspec.loader import SQLFileLoader
15
16
  from sqlspec.migrations.loaders import get_migration_loader
16
17
  from sqlspec.utils.logging import get_logger
17
- from sqlspec.utils.sync_tools import run_
18
+ from sqlspec.utils.sync_tools import await_
18
19
 
19
20
  __all__ = ("BaseMigrationCommands", "BaseMigrationRunner", "BaseMigrationTracker")
20
21
 
@@ -28,6 +29,8 @@ ConfigT = TypeVar("ConfigT")
28
29
  class BaseMigrationTracker(ABC, Generic[DriverT]):
29
30
  """Base class for migration version tracking."""
30
31
 
32
+ __slots__ = ("version_table",)
33
+
31
34
  def __init__(self, version_table_name: str = "ddl_migrations") -> None:
32
35
  """Initialize the migration tracker.
33
36
 
@@ -36,54 +39,43 @@ class BaseMigrationTracker(ABC, Generic[DriverT]):
36
39
  """
37
40
  self.version_table = version_table_name
38
41
 
39
- def _get_create_table_sql(self) -> SQL:
40
- """Get SQL for creating the tracking table.
42
+ def _get_create_table_sql(self) -> CreateTable:
43
+ """Get SQL builder for creating the tracking table.
41
44
 
42
45
  Returns:
43
- SQL object for table creation.
46
+ SQL builder object for table creation.
44
47
  """
45
- builder = CreateTable(self.version_table)
46
- if not hasattr(builder, "_columns"):
47
- builder._columns = []
48
- if not hasattr(builder, "_constraints"):
49
- builder._constraints = []
50
- if not hasattr(builder, "_table_options"):
51
- builder._table_options = {}
52
-
53
48
  return (
54
- builder.if_not_exists()
49
+ sql.create_table(self.version_table)
50
+ .if_not_exists()
55
51
  .column("version_num", "VARCHAR(32)", primary_key=True)
56
52
  .column("description", "TEXT")
57
- .column("applied_at", "TIMESTAMP", not_null=True, default="CURRENT_TIMESTAMP")
53
+ .column("applied_at", "TIMESTAMP", default="CURRENT_TIMESTAMP", not_null=True)
58
54
  .column("execution_time_ms", "INTEGER")
59
55
  .column("checksum", "VARCHAR(64)")
60
56
  .column("applied_by", "VARCHAR(255)")
61
- ).to_statement()
57
+ )
62
58
 
63
- def _get_current_version_sql(self) -> SQL:
64
- """Get SQL for retrieving current version.
59
+ def _get_current_version_sql(self) -> Select:
60
+ """Get SQL builder for retrieving current version.
65
61
 
66
62
  Returns:
67
- SQL object for version query.
63
+ SQL builder object for version query.
68
64
  """
65
+ return sql.select("version_num").from_(self.version_table).order_by("version_num DESC").limit(1)
69
66
 
70
- return (
71
- sql.select("version_num").from_(self.version_table).order_by("version_num DESC").limit(1)
72
- ).to_statement()
73
-
74
- def _get_applied_migrations_sql(self) -> SQL:
75
- """Get SQL for retrieving all applied migrations.
67
+ def _get_applied_migrations_sql(self) -> Select:
68
+ """Get SQL builder for retrieving all applied migrations.
76
69
 
77
70
  Returns:
78
- SQL object for migrations query.
71
+ SQL builder object for migrations query.
79
72
  """
80
-
81
- return (sql.select("*").from_(self.version_table).order_by("version_num")).to_statement()
73
+ return sql.select("*").from_(self.version_table).order_by("version_num")
82
74
 
83
75
  def _get_record_migration_sql(
84
76
  self, version: str, description: str, execution_time_ms: int, checksum: str, applied_by: str
85
- ) -> SQL:
86
- """Get SQL for recording a migration.
77
+ ) -> Insert:
78
+ """Get SQL builder for recording a migration.
87
79
 
88
80
  Args:
89
81
  version: Version number of the migration.
@@ -93,26 +85,24 @@ class BaseMigrationTracker(ABC, Generic[DriverT]):
93
85
  applied_by: User who applied the migration.
94
86
 
95
87
  Returns:
96
- SQL object for insert.
88
+ SQL builder object for insert.
97
89
  """
98
-
99
90
  return (
100
91
  sql.insert(self.version_table)
101
92
  .columns("version_num", "description", "execution_time_ms", "checksum", "applied_by")
102
93
  .values(version, description, execution_time_ms, checksum, applied_by)
103
- ).to_statement()
94
+ )
104
95
 
105
- def _get_remove_migration_sql(self, version: str) -> SQL:
106
- """Get SQL for removing a migration record.
96
+ def _get_remove_migration_sql(self, version: str) -> Delete:
97
+ """Get SQL builder for removing a migration record.
107
98
 
108
99
  Args:
109
100
  version: Version number to remove.
110
101
 
111
102
  Returns:
112
- SQL object for delete.
103
+ SQL builder object for delete.
113
104
  """
114
-
115
- return (sql.delete().from_(self.version_table).where(sql.version_num == version)).to_statement()
105
+ return sql.delete().from_(self.version_table).where(sql.version_num == version)
116
106
 
117
107
  @abstractmethod
118
108
  def ensure_tracking_table(self, driver: DriverT) -> Any:
@@ -176,7 +166,6 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
176
166
  Returns:
177
167
  MD5 checksum hex string.
178
168
  """
179
- import hashlib
180
169
 
181
170
  return hashlib.md5(content.encode()).hexdigest() # noqa: S324
182
171
 
@@ -226,7 +215,7 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
226
215
  has_upgrade, has_downgrade = self.loader.has_query(up_query), self.loader.has_query(down_query)
227
216
  else:
228
217
  try:
229
- has_downgrade = bool(run_(loader.get_down_sql)(file_path))
218
+ has_downgrade = bool(await_(loader.get_down_sql, raise_sync_error=False)(file_path))
230
219
  except Exception:
231
220
  has_downgrade = False
232
221
 
@@ -240,7 +229,7 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
240
229
  "loader": loader,
241
230
  }
242
231
 
243
- def _get_migration_sql(self, migration: "dict[str, Any]", direction: str) -> Optional[SQL]:
232
+ def _get_migration_sql(self, migration: "dict[str, Any]", direction: str) -> "Optional[list[str]]":
244
233
  """Get migration SQL for given direction.
245
234
 
246
235
  Args:
@@ -261,7 +250,7 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
261
250
 
262
251
  try:
263
252
  method = loader.get_up_sql if direction == "up" else loader.get_down_sql
264
- sql_statements = run_(method)(file_path)
253
+ sql_statements = await_(method, raise_sync_error=False)(file_path)
265
254
 
266
255
  except Exception as e:
267
256
  if direction == "down":
@@ -271,7 +260,7 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
271
260
  raise ValueError(msg) from e
272
261
  else:
273
262
  if sql_statements:
274
- return SQL(sql_statements[0])
263
+ return cast("list[str]", sql_statements)
275
264
  return None
276
265
 
277
266
  @abstractmethod
@@ -312,7 +301,7 @@ class BaseMigrationCommands(ABC, Generic[ConfigT, DriverT]):
312
301
  self.config = config
313
302
  migration_config = getattr(self.config, "migration_config", {}) or {}
314
303
 
315
- self.version_table = migration_config.get("version_table_name", "sqlspec_migrations")
304
+ self.version_table = migration_config.get("version_table_name", "ddl_migrations")
316
305
  self.migrations_path = Path(migration_config.get("script_location", "migrations"))
317
306
  self.project_root = Path(migration_config["project_root"]) if "project_root" in migration_config else None
318
307
 
@@ -3,7 +3,7 @@
3
3
  This module provides the main command interface for database migrations.
4
4
  """
5
5
 
6
- from typing import TYPE_CHECKING, Any, Union, cast
6
+ from typing import TYPE_CHECKING, Any, Optional, Union, cast
7
7
 
8
8
  from rich.console import Console
9
9
  from rich.table import Table
@@ -11,7 +11,6 @@ from rich.table import Table
11
11
  from sqlspec._sql import sql
12
12
  from sqlspec.migrations.base import BaseMigrationCommands
13
13
  from sqlspec.migrations.runner import AsyncMigrationRunner, SyncMigrationRunner
14
- from sqlspec.migrations.tracker import AsyncMigrationTracker, SyncMigrationTracker
15
14
  from sqlspec.migrations.utils import create_migration_file
16
15
  from sqlspec.utils.logging import get_logger
17
16
  from sqlspec.utils.sync_tools import await_
@@ -26,7 +25,7 @@ console = Console()
26
25
 
27
26
 
28
27
  class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
29
- """SQLSpec native migration commands."""
28
+ """Synchronous migration commands."""
30
29
 
31
30
  def __init__(self, config: "SyncConfigT") -> None:
32
31
  """Initialize migration commands.
@@ -35,7 +34,7 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
35
34
  config: The SQLSpec configuration.
36
35
  """
37
36
  super().__init__(config)
38
- self.tracker = SyncMigrationTracker(self.version_table)
37
+ self.tracker = config.migration_tracker_type(self.version_table)
39
38
  self.runner = SyncMigrationRunner(self.migrations_path)
40
39
 
41
40
  def init(self, directory: str, package: bool = True) -> None:
@@ -47,11 +46,14 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
47
46
  """
48
47
  self.init_directory(directory, package)
49
48
 
50
- def current(self, verbose: bool = False) -> None:
49
+ def current(self, verbose: bool = False) -> "Optional[str]":
51
50
  """Show current migration version.
52
51
 
53
52
  Args:
54
53
  verbose: Whether to show detailed migration history.
54
+
55
+ Returns:
56
+ The current migration version or None if no migrations applied.
55
57
  """
56
58
  with self.config.provide_session() as driver:
57
59
  self.tracker.ensure_tracking_table(driver)
@@ -59,7 +61,7 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
59
61
  current = self.tracker.get_current_version(driver)
60
62
  if not current:
61
63
  console.print("[yellow]No migrations applied yet[/]")
62
- return
64
+ return None
63
65
 
64
66
  console.print(f"[green]Current version:[/] {current}")
65
67
 
@@ -84,6 +86,8 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
84
86
 
85
87
  console.print(table)
86
88
 
89
+ return cast("Optional[str]", current)
90
+
87
91
  def upgrade(self, revision: str = "head") -> None:
88
92
  """Upgrade to a target revision.
89
93
 
@@ -137,6 +141,8 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
137
141
  to_revert = []
138
142
  if revision == "-1":
139
143
  to_revert = [applied[-1]]
144
+ elif revision == "base":
145
+ to_revert = list(reversed(applied))
140
146
  else:
141
147
  for migration in reversed(applied):
142
148
  if migration["version_num"] > revision:
@@ -195,7 +201,7 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
195
201
 
196
202
 
197
203
  class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
198
- """SQLSpec native migration commands."""
204
+ """Asynchronous migration commands."""
199
205
 
200
206
  def __init__(self, sqlspec_config: "AsyncConfigT") -> None:
201
207
  """Initialize migration commands.
@@ -204,7 +210,7 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
204
210
  sqlspec_config: The SQLSpec configuration.
205
211
  """
206
212
  super().__init__(sqlspec_config)
207
- self.tracker = AsyncMigrationTracker(self.version_table)
213
+ self.tracker = sqlspec_config.migration_tracker_type(self.version_table)
208
214
  self.runner = AsyncMigrationRunner(self.migrations_path)
209
215
 
210
216
  async def init(self, directory: str, package: bool = True) -> None:
@@ -216,11 +222,14 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
216
222
  """
217
223
  self.init_directory(directory, package)
218
224
 
219
- async def current(self, verbose: bool = False) -> None:
225
+ async def current(self, verbose: bool = False) -> "Optional[str]":
220
226
  """Show current migration version.
221
227
 
222
228
  Args:
223
229
  verbose: Whether to show detailed migration history.
230
+
231
+ Returns:
232
+ The current migration version or None if no migrations applied.
224
233
  """
225
234
  async with self.config.provide_session() as driver:
226
235
  await self.tracker.ensure_tracking_table(driver)
@@ -228,7 +237,7 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
228
237
  current = await self.tracker.get_current_version(driver)
229
238
  if not current:
230
239
  console.print("[yellow]No migrations applied yet[/]")
231
- return
240
+ return None
232
241
 
233
242
  console.print(f"[green]Current version:[/] {current}")
234
243
  if verbose:
@@ -249,6 +258,8 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
249
258
  )
250
259
  console.print(table)
251
260
 
261
+ return cast("Optional[str]", current)
262
+
252
263
  async def upgrade(self, revision: str = "head") -> None:
253
264
  """Upgrade to a target revision.
254
265
 
@@ -297,6 +308,8 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
297
308
  to_revert = []
298
309
  if revision == "-1":
299
310
  to_revert = [applied[-1]]
311
+ elif revision == "base":
312
+ to_revert = list(reversed(applied))
300
313
  else:
301
314
  for migration in reversed(applied):
302
315
  if migration["version_num"] > revision:
@@ -382,20 +395,26 @@ class MigrationCommands:
382
395
  package: Whether to create __init__.py file.
383
396
  """
384
397
  if self._is_async:
385
- await_(cast("AsyncMigrationCommands[Any]", self._impl).init)(directory, package=package)
398
+ await_(cast("AsyncMigrationCommands[Any]", self._impl).init, raise_sync_error=False)(
399
+ directory, package=package
400
+ )
386
401
  else:
387
402
  cast("SyncMigrationCommands[Any]", self._impl).init(directory, package=package)
388
403
 
389
- def current(self, verbose: bool = False) -> None:
404
+ def current(self, verbose: bool = False) -> "Optional[str]":
390
405
  """Show current migration version.
391
406
 
392
407
  Args:
393
408
  verbose: Whether to show detailed migration history.
409
+
410
+ Returns:
411
+ The current migration version or None if no migrations applied.
394
412
  """
395
413
  if self._is_async:
396
- await_(cast("AsyncMigrationCommands[Any]", self._impl).current, raise_sync_error=False)(verbose=verbose)
397
- else:
398
- cast("SyncMigrationCommands[Any]", self._impl).current(verbose=verbose)
414
+ return await_(cast("AsyncMigrationCommands[Any]", self._impl).current, raise_sync_error=False)(
415
+ verbose=verbose
416
+ )
417
+ return cast("SyncMigrationCommands[Any]", self._impl).current(verbose=verbose)
399
418
 
400
419
  def upgrade(self, revision: str = "head") -> None:
401
420
  """Upgrade to a target revision.
@@ -70,7 +70,7 @@ class BaseMigrationLoader(abc.ABC):
70
70
 
71
71
 
72
72
  class SQLFileLoader(BaseMigrationLoader):
73
- """Loader for SQL migration files using SQLFileLoader."""
73
+ """Loader for SQL migration files."""
74
74
 
75
75
  __slots__ = ("sql_loader",)
76
76
 
@@ -5,13 +5,13 @@ This module handles migration file loading and execution using SQLFileLoader.
5
5
 
6
6
  import time
7
7
  from pathlib import Path
8
- from typing import TYPE_CHECKING, Any, Optional
8
+ from typing import TYPE_CHECKING, Any, Optional, cast
9
9
 
10
10
  from sqlspec.core.statement import SQL
11
11
  from sqlspec.migrations.base import BaseMigrationRunner
12
12
  from sqlspec.migrations.loaders import get_migration_loader
13
13
  from sqlspec.utils.logging import get_logger
14
- from sqlspec.utils.sync_tools import run_
14
+ from sqlspec.utils.sync_tools import await_
15
15
 
16
16
  if TYPE_CHECKING:
17
17
  from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
@@ -22,7 +22,7 @@ logger = get_logger("migrations.runner")
22
22
 
23
23
 
24
24
  class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
25
- """Executes migrations using SQLFileLoader."""
25
+ """Synchronous migration executor."""
26
26
 
27
27
  def get_migration_files(self) -> "list[tuple[str, Path]]":
28
28
  """Get all migration files sorted by version.
@@ -55,12 +55,15 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
55
55
  Returns:
56
56
  Tuple of (sql_content, execution_time_ms).
57
57
  """
58
- upgrade_sql = self._get_migration_sql(migration, "up")
59
- if upgrade_sql is None:
58
+ upgrade_sql_list = self._get_migration_sql(migration, "up")
59
+ if upgrade_sql_list is None:
60
60
  return None, 0
61
61
 
62
62
  start_time = time.time()
63
- driver.execute(upgrade_sql)
63
+
64
+ for sql_statement in upgrade_sql_list:
65
+ if sql_statement.strip():
66
+ driver.execute_script(sql_statement)
64
67
  execution_time = int((time.time() - start_time) * 1000)
65
68
  return None, execution_time
66
69
 
@@ -76,12 +79,15 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
76
79
  Returns:
77
80
  Tuple of (sql_content, execution_time_ms).
78
81
  """
79
- downgrade_sql = self._get_migration_sql(migration, "down")
80
- if downgrade_sql is None:
82
+ downgrade_sql_list = self._get_migration_sql(migration, "down")
83
+ if downgrade_sql_list is None:
81
84
  return None, 0
82
85
 
83
86
  start_time = time.time()
84
- driver.execute(downgrade_sql)
87
+
88
+ for sql_statement in downgrade_sql_list:
89
+ if sql_statement.strip():
90
+ driver.execute_script(sql_statement)
85
91
  execution_time = int((time.time() - start_time) * 1000)
86
92
  return None, execution_time
87
93
 
@@ -103,8 +109,8 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
103
109
  loader = get_migration_loader(file_path, self.migrations_path, self.project_root)
104
110
 
105
111
  try:
106
- up_sql = run_(loader.get_up_sql)(file_path)
107
- down_sql = run_(loader.get_down_sql)(file_path)
112
+ up_sql = await_(loader.get_up_sql, raise_sync_error=False)(file_path)
113
+ down_sql = await_(loader.get_down_sql, raise_sync_error=False)(file_path)
108
114
 
109
115
  if up_sql:
110
116
  all_queries[f"migrate-{version}-up"] = SQL(up_sql[0])
@@ -118,7 +124,7 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
118
124
 
119
125
 
120
126
  class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
121
- """Executes migrations using SQLFileLoader."""
127
+ """Asynchronous migration executor."""
122
128
 
123
129
  async def get_migration_files(self) -> "list[tuple[str, Path]]":
124
130
  """Get all migration files sorted by version.
@@ -137,7 +143,80 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
137
143
  Returns:
138
144
  Dictionary containing migration metadata.
139
145
  """
140
- return self._load_migration_metadata(file_path)
146
+ return await self._load_migration_metadata_async(file_path)
147
+
148
+ async def _load_migration_metadata_async(self, file_path: Path) -> "dict[str, Any]":
149
+ """Load migration metadata from file (async version).
150
+
151
+ Args:
152
+ file_path: Path to the migration file.
153
+
154
+ Returns:
155
+ Migration metadata dictionary.
156
+ """
157
+ loader = get_migration_loader(file_path, self.migrations_path, self.project_root)
158
+ loader.validate_migration_file(file_path)
159
+ content = file_path.read_text(encoding="utf-8")
160
+ checksum = self._calculate_checksum(content)
161
+ version = self._extract_version(file_path.name)
162
+ description = file_path.stem.split("_", 1)[1] if "_" in file_path.stem else ""
163
+
164
+ has_upgrade, has_downgrade = True, False
165
+
166
+ if file_path.suffix == ".sql":
167
+ up_query, down_query = f"migrate-{version}-up", f"migrate-{version}-down"
168
+ self.loader.clear_cache()
169
+ self.loader.load_sql(file_path)
170
+ has_upgrade, has_downgrade = self.loader.has_query(up_query), self.loader.has_query(down_query)
171
+ else:
172
+ try:
173
+ has_downgrade = bool(await loader.get_down_sql(file_path))
174
+ except Exception:
175
+ has_downgrade = False
176
+
177
+ return {
178
+ "version": version,
179
+ "description": description,
180
+ "file_path": file_path,
181
+ "checksum": checksum,
182
+ "has_upgrade": has_upgrade,
183
+ "has_downgrade": has_downgrade,
184
+ "loader": loader,
185
+ }
186
+
187
+ async def _get_migration_sql_async(self, migration: "dict[str, Any]", direction: str) -> "Optional[list[str]]":
188
+ """Get migration SQL for given direction (async version).
189
+
190
+ Args:
191
+ migration: Migration metadata.
192
+ direction: Either 'up' or 'down'.
193
+
194
+ Returns:
195
+ SQL statements for the migration.
196
+ """
197
+ if not migration.get(f"has_{direction}grade"):
198
+ if direction == "down":
199
+ logger.warning("Migration %s has no downgrade query", migration["version"])
200
+ return None
201
+ msg = f"Migration {migration['version']} has no upgrade query"
202
+ raise ValueError(msg)
203
+
204
+ file_path, loader = migration["file_path"], migration["loader"]
205
+
206
+ try:
207
+ method = loader.get_up_sql if direction == "up" else loader.get_down_sql
208
+ sql_statements = await method(file_path)
209
+
210
+ except Exception as e:
211
+ if direction == "down":
212
+ logger.warning("Failed to load downgrade for migration %s: %s", migration["version"], e)
213
+ return None
214
+ msg = f"Failed to load upgrade for migration {migration['version']}: {e}"
215
+ raise ValueError(msg) from e
216
+ else:
217
+ if sql_statements:
218
+ return cast("list[str]", sql_statements)
219
+ return None
141
220
 
142
221
  async def execute_upgrade(
143
222
  self, driver: "AsyncDriverAdapterBase", migration: "dict[str, Any]"
@@ -151,12 +230,15 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
151
230
  Returns:
152
231
  Tuple of (sql_content, execution_time_ms).
153
232
  """
154
- upgrade_sql = self._get_migration_sql(migration, "up")
155
- if upgrade_sql is None:
233
+ upgrade_sql_list = await self._get_migration_sql_async(migration, "up")
234
+ if upgrade_sql_list is None:
156
235
  return None, 0
157
236
 
158
237
  start_time = time.time()
159
- await driver.execute(upgrade_sql)
238
+
239
+ for sql_statement in upgrade_sql_list:
240
+ if sql_statement.strip():
241
+ await driver.execute_script(sql_statement)
160
242
  execution_time = int((time.time() - start_time) * 1000)
161
243
  return None, execution_time
162
244
 
@@ -172,12 +254,15 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
172
254
  Returns:
173
255
  Tuple of (sql_content, execution_time_ms).
174
256
  """
175
- downgrade_sql = self._get_migration_sql(migration, "down")
176
- if downgrade_sql is None:
257
+ downgrade_sql_list = await self._get_migration_sql_async(migration, "down")
258
+ if downgrade_sql_list is None:
177
259
  return None, 0
178
260
 
179
261
  start_time = time.time()
180
- await driver.execute(downgrade_sql)
262
+
263
+ for sql_statement in downgrade_sql_list:
264
+ if sql_statement.strip():
265
+ await driver.execute_script(sql_statement)
181
266
  execution_time = int((time.time() - start_time) * 1000)
182
267
  return None, execution_time
183
268