sqlspec-0.25.0-py3-none-any.whl → sqlspec-0.27.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. More details are available via the link on the original diff page.

Files changed (199)
  1. sqlspec/__init__.py +7 -15
  2. sqlspec/_serialization.py +256 -24
  3. sqlspec/_typing.py +71 -52
  4. sqlspec/adapters/adbc/_types.py +1 -1
  5. sqlspec/adapters/adbc/adk/__init__.py +5 -0
  6. sqlspec/adapters/adbc/adk/store.py +870 -0
  7. sqlspec/adapters/adbc/config.py +69 -12
  8. sqlspec/adapters/adbc/data_dictionary.py +340 -0
  9. sqlspec/adapters/adbc/driver.py +266 -58
  10. sqlspec/adapters/adbc/litestar/__init__.py +5 -0
  11. sqlspec/adapters/adbc/litestar/store.py +504 -0
  12. sqlspec/adapters/adbc/type_converter.py +153 -0
  13. sqlspec/adapters/aiosqlite/_types.py +1 -1
  14. sqlspec/adapters/aiosqlite/adk/__init__.py +5 -0
  15. sqlspec/adapters/aiosqlite/adk/store.py +527 -0
  16. sqlspec/adapters/aiosqlite/config.py +88 -15
  17. sqlspec/adapters/aiosqlite/data_dictionary.py +149 -0
  18. sqlspec/adapters/aiosqlite/driver.py +143 -40
  19. sqlspec/adapters/aiosqlite/litestar/__init__.py +5 -0
  20. sqlspec/adapters/aiosqlite/litestar/store.py +281 -0
  21. sqlspec/adapters/aiosqlite/pool.py +7 -7
  22. sqlspec/adapters/asyncmy/__init__.py +7 -1
  23. sqlspec/adapters/asyncmy/_types.py +2 -2
  24. sqlspec/adapters/asyncmy/adk/__init__.py +5 -0
  25. sqlspec/adapters/asyncmy/adk/store.py +493 -0
  26. sqlspec/adapters/asyncmy/config.py +68 -23
  27. sqlspec/adapters/asyncmy/data_dictionary.py +161 -0
  28. sqlspec/adapters/asyncmy/driver.py +313 -58
  29. sqlspec/adapters/asyncmy/litestar/__init__.py +5 -0
  30. sqlspec/adapters/asyncmy/litestar/store.py +296 -0
  31. sqlspec/adapters/asyncpg/__init__.py +2 -1
  32. sqlspec/adapters/asyncpg/_type_handlers.py +71 -0
  33. sqlspec/adapters/asyncpg/_types.py +11 -7
  34. sqlspec/adapters/asyncpg/adk/__init__.py +5 -0
  35. sqlspec/adapters/asyncpg/adk/store.py +450 -0
  36. sqlspec/adapters/asyncpg/config.py +59 -35
  37. sqlspec/adapters/asyncpg/data_dictionary.py +173 -0
  38. sqlspec/adapters/asyncpg/driver.py +170 -25
  39. sqlspec/adapters/asyncpg/litestar/__init__.py +5 -0
  40. sqlspec/adapters/asyncpg/litestar/store.py +253 -0
  41. sqlspec/adapters/bigquery/_types.py +1 -1
  42. sqlspec/adapters/bigquery/adk/__init__.py +5 -0
  43. sqlspec/adapters/bigquery/adk/store.py +576 -0
  44. sqlspec/adapters/bigquery/config.py +27 -10
  45. sqlspec/adapters/bigquery/data_dictionary.py +149 -0
  46. sqlspec/adapters/bigquery/driver.py +368 -142
  47. sqlspec/adapters/bigquery/litestar/__init__.py +5 -0
  48. sqlspec/adapters/bigquery/litestar/store.py +327 -0
  49. sqlspec/adapters/bigquery/type_converter.py +125 -0
  50. sqlspec/adapters/duckdb/_types.py +1 -1
  51. sqlspec/adapters/duckdb/adk/__init__.py +14 -0
  52. sqlspec/adapters/duckdb/adk/store.py +553 -0
  53. sqlspec/adapters/duckdb/config.py +80 -20
  54. sqlspec/adapters/duckdb/data_dictionary.py +163 -0
  55. sqlspec/adapters/duckdb/driver.py +167 -45
  56. sqlspec/adapters/duckdb/litestar/__init__.py +5 -0
  57. sqlspec/adapters/duckdb/litestar/store.py +332 -0
  58. sqlspec/adapters/duckdb/pool.py +4 -4
  59. sqlspec/adapters/duckdb/type_converter.py +133 -0
  60. sqlspec/adapters/oracledb/_numpy_handlers.py +133 -0
  61. sqlspec/adapters/oracledb/_types.py +20 -2
  62. sqlspec/adapters/oracledb/adk/__init__.py +5 -0
  63. sqlspec/adapters/oracledb/adk/store.py +1745 -0
  64. sqlspec/adapters/oracledb/config.py +122 -32
  65. sqlspec/adapters/oracledb/data_dictionary.py +509 -0
  66. sqlspec/adapters/oracledb/driver.py +353 -91
  67. sqlspec/adapters/oracledb/litestar/__init__.py +5 -0
  68. sqlspec/adapters/oracledb/litestar/store.py +767 -0
  69. sqlspec/adapters/oracledb/migrations.py +348 -73
  70. sqlspec/adapters/oracledb/type_converter.py +207 -0
  71. sqlspec/adapters/psqlpy/_type_handlers.py +44 -0
  72. sqlspec/adapters/psqlpy/_types.py +2 -1
  73. sqlspec/adapters/psqlpy/adk/__init__.py +5 -0
  74. sqlspec/adapters/psqlpy/adk/store.py +482 -0
  75. sqlspec/adapters/psqlpy/config.py +46 -17
  76. sqlspec/adapters/psqlpy/data_dictionary.py +172 -0
  77. sqlspec/adapters/psqlpy/driver.py +123 -209
  78. sqlspec/adapters/psqlpy/litestar/__init__.py +5 -0
  79. sqlspec/adapters/psqlpy/litestar/store.py +272 -0
  80. sqlspec/adapters/psqlpy/type_converter.py +102 -0
  81. sqlspec/adapters/psycopg/_type_handlers.py +80 -0
  82. sqlspec/adapters/psycopg/_types.py +2 -1
  83. sqlspec/adapters/psycopg/adk/__init__.py +5 -0
  84. sqlspec/adapters/psycopg/adk/store.py +944 -0
  85. sqlspec/adapters/psycopg/config.py +69 -35
  86. sqlspec/adapters/psycopg/data_dictionary.py +331 -0
  87. sqlspec/adapters/psycopg/driver.py +238 -81
  88. sqlspec/adapters/psycopg/litestar/__init__.py +5 -0
  89. sqlspec/adapters/psycopg/litestar/store.py +554 -0
  90. sqlspec/adapters/sqlite/__init__.py +2 -1
  91. sqlspec/adapters/sqlite/_type_handlers.py +86 -0
  92. sqlspec/adapters/sqlite/_types.py +1 -1
  93. sqlspec/adapters/sqlite/adk/__init__.py +5 -0
  94. sqlspec/adapters/sqlite/adk/store.py +572 -0
  95. sqlspec/adapters/sqlite/config.py +87 -15
  96. sqlspec/adapters/sqlite/data_dictionary.py +149 -0
  97. sqlspec/adapters/sqlite/driver.py +137 -54
  98. sqlspec/adapters/sqlite/litestar/__init__.py +5 -0
  99. sqlspec/adapters/sqlite/litestar/store.py +318 -0
  100. sqlspec/adapters/sqlite/pool.py +18 -9
  101. sqlspec/base.py +45 -26
  102. sqlspec/builder/__init__.py +73 -4
  103. sqlspec/builder/_base.py +162 -89
  104. sqlspec/builder/_column.py +62 -29
  105. sqlspec/builder/_ddl.py +180 -121
  106. sqlspec/builder/_delete.py +5 -4
  107. sqlspec/builder/_dml.py +388 -0
  108. sqlspec/{_sql.py → builder/_factory.py} +53 -94
  109. sqlspec/builder/_insert.py +32 -131
  110. sqlspec/builder/_join.py +375 -0
  111. sqlspec/builder/_merge.py +446 -11
  112. sqlspec/builder/_parsing_utils.py +111 -17
  113. sqlspec/builder/_select.py +1457 -24
  114. sqlspec/builder/_update.py +11 -42
  115. sqlspec/cli.py +307 -194
  116. sqlspec/config.py +252 -67
  117. sqlspec/core/__init__.py +5 -4
  118. sqlspec/core/cache.py +17 -17
  119. sqlspec/core/compiler.py +62 -9
  120. sqlspec/core/filters.py +37 -37
  121. sqlspec/core/hashing.py +9 -9
  122. sqlspec/core/parameters.py +83 -48
  123. sqlspec/core/result.py +102 -46
  124. sqlspec/core/splitter.py +16 -17
  125. sqlspec/core/statement.py +36 -30
  126. sqlspec/core/type_conversion.py +235 -0
  127. sqlspec/driver/__init__.py +7 -6
  128. sqlspec/driver/_async.py +188 -151
  129. sqlspec/driver/_common.py +285 -80
  130. sqlspec/driver/_sync.py +188 -152
  131. sqlspec/driver/mixins/_result_tools.py +20 -236
  132. sqlspec/driver/mixins/_sql_translator.py +4 -4
  133. sqlspec/exceptions.py +75 -7
  134. sqlspec/extensions/adk/__init__.py +53 -0
  135. sqlspec/extensions/adk/_types.py +51 -0
  136. sqlspec/extensions/adk/converters.py +172 -0
  137. sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +144 -0
  138. sqlspec/extensions/adk/migrations/__init__.py +0 -0
  139. sqlspec/extensions/adk/service.py +181 -0
  140. sqlspec/extensions/adk/store.py +536 -0
  141. sqlspec/extensions/aiosql/adapter.py +73 -53
  142. sqlspec/extensions/litestar/__init__.py +21 -4
  143. sqlspec/extensions/litestar/cli.py +54 -10
  144. sqlspec/extensions/litestar/config.py +59 -266
  145. sqlspec/extensions/litestar/handlers.py +46 -17
  146. sqlspec/extensions/litestar/migrations/0001_create_session_table.py +137 -0
  147. sqlspec/extensions/litestar/migrations/__init__.py +3 -0
  148. sqlspec/extensions/litestar/plugin.py +324 -223
  149. sqlspec/extensions/litestar/providers.py +25 -25
  150. sqlspec/extensions/litestar/store.py +265 -0
  151. sqlspec/loader.py +30 -49
  152. sqlspec/migrations/__init__.py +4 -3
  153. sqlspec/migrations/base.py +302 -39
  154. sqlspec/migrations/commands.py +611 -144
  155. sqlspec/migrations/context.py +142 -0
  156. sqlspec/migrations/fix.py +199 -0
  157. sqlspec/migrations/loaders.py +68 -23
  158. sqlspec/migrations/runner.py +543 -107
  159. sqlspec/migrations/tracker.py +237 -21
  160. sqlspec/migrations/utils.py +51 -3
  161. sqlspec/migrations/validation.py +177 -0
  162. sqlspec/protocols.py +66 -36
  163. sqlspec/storage/_utils.py +98 -0
  164. sqlspec/storage/backends/fsspec.py +134 -106
  165. sqlspec/storage/backends/local.py +78 -51
  166. sqlspec/storage/backends/obstore.py +278 -162
  167. sqlspec/storage/registry.py +75 -39
  168. sqlspec/typing.py +16 -84
  169. sqlspec/utils/config_resolver.py +153 -0
  170. sqlspec/utils/correlation.py +4 -5
  171. sqlspec/utils/data_transformation.py +3 -2
  172. sqlspec/utils/deprecation.py +9 -8
  173. sqlspec/utils/fixtures.py +4 -4
  174. sqlspec/utils/logging.py +46 -6
  175. sqlspec/utils/module_loader.py +2 -2
  176. sqlspec/utils/schema.py +288 -0
  177. sqlspec/utils/serializers.py +50 -2
  178. sqlspec/utils/sync_tools.py +21 -17
  179. sqlspec/utils/text.py +1 -2
  180. sqlspec/utils/type_guards.py +111 -20
  181. sqlspec/utils/version.py +433 -0
  182. {sqlspec-0.25.0.dist-info → sqlspec-0.27.0.dist-info}/METADATA +40 -21
  183. sqlspec-0.27.0.dist-info/RECORD +207 -0
  184. sqlspec/builder/mixins/__init__.py +0 -55
  185. sqlspec/builder/mixins/_cte_and_set_ops.py +0 -254
  186. sqlspec/builder/mixins/_delete_operations.py +0 -50
  187. sqlspec/builder/mixins/_insert_operations.py +0 -282
  188. sqlspec/builder/mixins/_join_operations.py +0 -389
  189. sqlspec/builder/mixins/_merge_operations.py +0 -592
  190. sqlspec/builder/mixins/_order_limit_operations.py +0 -152
  191. sqlspec/builder/mixins/_pivot_operations.py +0 -157
  192. sqlspec/builder/mixins/_select_operations.py +0 -936
  193. sqlspec/builder/mixins/_update_operations.py +0 -218
  194. sqlspec/builder/mixins/_where_clause.py +0 -1304
  195. sqlspec-0.25.0.dist-info/RECORD +0 -139
  196. sqlspec-0.25.0.dist-info/licenses/NOTICE +0 -29
  197. {sqlspec-0.25.0.dist-info → sqlspec-0.27.0.dist-info}/WHEEL +0 -0
  198. {sqlspec-0.25.0.dist-info → sqlspec-0.27.0.dist-info}/entry_points.txt +0 -0
  199. {sqlspec-0.25.0.dist-info → sqlspec-0.27.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,28 +1,162 @@
1
1
  """Migration execution engine for SQLSpec.
2
2
 
3
- This module handles migration file loading and execution using SQLFileLoader.
3
+ This module provides separate sync and async migration runners with clean separation
4
+ of concerns and proper type safety.
4
5
  """
5
6
 
7
+ import inspect
6
8
  import time
9
+ from abc import ABC, abstractmethod
7
10
  from pathlib import Path
8
- from typing import TYPE_CHECKING, Any, Optional, cast
11
+ from typing import TYPE_CHECKING, Any, Literal, Union, cast, overload
9
12
 
10
13
  from sqlspec.core.statement import SQL
11
- from sqlspec.migrations.base import BaseMigrationRunner
14
+ from sqlspec.migrations.context import MigrationContext
12
15
  from sqlspec.migrations.loaders import get_migration_loader
13
16
  from sqlspec.utils.logging import get_logger
14
- from sqlspec.utils.sync_tools import await_
17
+ from sqlspec.utils.sync_tools import async_, await_
15
18
 
16
19
  if TYPE_CHECKING:
20
+ from collections.abc import Awaitable, Callable, Coroutine
21
+
17
22
  from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
18
23
 
19
- __all__ = ("AsyncMigrationRunner", "SyncMigrationRunner")
24
+ __all__ = ("AsyncMigrationRunner", "SyncMigrationRunner", "create_migration_runner")
20
25
 
21
26
  logger = get_logger("migrations.runner")
22
27
 
23
28
 
24
- class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
25
- """Synchronous migration executor."""
29
+ class BaseMigrationRunner(ABC):
30
+ """Base migration runner with common functionality shared between sync and async implementations."""
31
+
32
+ def __init__(
33
+ self,
34
+ migrations_path: Path,
35
+ extension_migrations: "dict[str, Path] | None" = None,
36
+ context: "MigrationContext | None" = None,
37
+ extension_configs: "dict[str, dict[str, Any]] | None" = None,
38
+ ) -> None:
39
+ """Initialize the migration runner.
40
+
41
+ Args:
42
+ migrations_path: Path to the directory containing migration files.
43
+ extension_migrations: Optional mapping of extension names to their migration paths.
44
+ context: Optional migration context for Python migrations.
45
+ extension_configs: Optional mapping of extension names to their configurations.
46
+ """
47
+ self.migrations_path = migrations_path
48
+ self.extension_migrations = extension_migrations or {}
49
+ from sqlspec.loader import SQLFileLoader
50
+
51
+ self.loader = SQLFileLoader()
52
+ self.project_root: Path | None = None
53
+ self.context = context
54
+ self.extension_configs = extension_configs or {}
55
+
56
+ def _extract_version(self, filename: str) -> "str | None":
57
+ """Extract version from filename.
58
+
59
+ Supports sequential (0001), timestamp (20251011120000), and extension-prefixed
60
+ (ext_litestar_0001) version formats.
61
+
62
+ Args:
63
+ filename: The migration filename.
64
+
65
+ Returns:
66
+ The extracted version string or None.
67
+ """
68
+ extension_version_parts = 3
69
+ timestamp_min_length = 4
70
+
71
+ name_without_ext = filename.rsplit(".", 1)[0]
72
+
73
+ if name_without_ext.startswith("ext_"):
74
+ parts = name_without_ext.split("_", 3)
75
+ if len(parts) >= extension_version_parts:
76
+ return f"{parts[0]}_{parts[1]}_{parts[2]}"
77
+ return None
78
+
79
+ parts = name_without_ext.split("_", 1)
80
+ if parts and parts[0].isdigit():
81
+ return parts[0] if len(parts[0]) > timestamp_min_length else parts[0].zfill(4)
82
+
83
+ return None
84
+
85
+ def _calculate_checksum(self, content: str) -> str:
86
+ """Calculate MD5 checksum of migration content.
87
+
88
+ Canonicalizes content by excluding query name headers that change during
89
+ fix command (migrate-{version}-up/down). This ensures checksums remain
90
+ stable when converting timestamp versions to sequential format.
91
+
92
+ Args:
93
+ content: The migration file content.
94
+
95
+ Returns:
96
+ MD5 checksum hex string.
97
+ """
98
+ import hashlib
99
+ import re
100
+
101
+ canonical_content = re.sub(r"^--\s*name:\s*migrate-[^-]+-(?:up|down)\s*$", "", content, flags=re.MULTILINE)
102
+
103
+ return hashlib.md5(canonical_content.encode()).hexdigest() # noqa: S324
104
+
105
+ @abstractmethod
106
+ def load_migration(self, file_path: Path) -> Union["dict[str, Any]", "Coroutine[Any, Any, dict[str, Any]]"]:
107
+ """Load a migration file and extract its components.
108
+
109
+ Args:
110
+ file_path: Path to the migration file.
111
+
112
+ Returns:
113
+ Dictionary containing migration metadata and queries.
114
+ For async implementations, returns a coroutine.
115
+ """
116
+
117
+ def _get_migration_files_sync(self) -> "list[tuple[str, Path]]":
118
+ """Get all migration files sorted by version.
119
+
120
+ Returns:
121
+ List of tuples containing (version, file_path).
122
+ """
123
+
124
+ migrations = []
125
+
126
+ # Scan primary migration path
127
+ if self.migrations_path.exists():
128
+ for pattern in ("*.sql", "*.py"):
129
+ for file_path in self.migrations_path.glob(pattern):
130
+ if file_path.name.startswith("."):
131
+ continue
132
+ version = self._extract_version(file_path.name)
133
+ if version:
134
+ migrations.append((version, file_path))
135
+
136
+ # Scan extension migration paths
137
+ for ext_name, ext_path in self.extension_migrations.items():
138
+ if ext_path.exists():
139
+ for pattern in ("*.sql", "*.py"):
140
+ for file_path in ext_path.glob(pattern):
141
+ if file_path.name.startswith("."):
142
+ continue
143
+ # Prefix extension migrations to avoid version conflicts
144
+ version = self._extract_version(file_path.name)
145
+ if version:
146
+ # Use ext_ prefix to distinguish extension migrations
147
+ prefixed_version = f"ext_{ext_name}_{version}"
148
+ migrations.append((prefixed_version, file_path))
149
+
150
+ from sqlspec.utils.version import parse_version
151
+
152
+ def version_sort_key(migration_tuple: "tuple[str, Path]") -> "Any":
153
+ version_str = migration_tuple[0]
154
+ try:
155
+ return parse_version(version_str)
156
+ except ValueError:
157
+ return version_str
158
+
159
+ return sorted(migrations, key=version_sort_key)
26
160
 
27
161
  def get_migration_files(self) -> "list[tuple[str, Path]]":
28
162
  """Get all migration files sorted by version.
@@ -32,65 +166,278 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
32
166
  """
33
167
  return self._get_migration_files_sync()
34
168
 
35
- def load_migration(self, file_path: Path) -> "dict[str, Any]":
169
+ def _load_migration_metadata_common(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]":
170
+ """Load common migration metadata that doesn't require async operations.
171
+
172
+ Args:
173
+ file_path: Path to the migration file.
174
+ version: Optional pre-extracted version (preserves prefixes like ext_adk_0001).
175
+
176
+ Returns:
177
+ Partial migration metadata dictionary.
178
+ """
179
+ import re
180
+
181
+ content = file_path.read_text(encoding="utf-8")
182
+ checksum = self._calculate_checksum(content)
183
+ if version is None:
184
+ version = self._extract_version(file_path.name)
185
+ description = file_path.stem.split("_", 1)[1] if "_" in file_path.stem else ""
186
+
187
+ transactional_match = re.search(
188
+ r"^--\s*transactional:\s*(true|false)\s*$", content, re.MULTILINE | re.IGNORECASE
189
+ )
190
+ transactional = None
191
+ if transactional_match:
192
+ transactional = transactional_match.group(1).lower() == "true"
193
+
194
+ return {
195
+ "version": version,
196
+ "description": description,
197
+ "file_path": file_path,
198
+ "checksum": checksum,
199
+ "content": content,
200
+ "transactional": transactional,
201
+ }
202
+
203
+ def _get_context_for_migration(self, file_path: Path) -> "MigrationContext | None":
204
+ """Get the appropriate context for a migration file.
205
+
206
+ Args:
207
+ file_path: Path to the migration file.
208
+
209
+ Returns:
210
+ Migration context to use, or None to use default.
211
+ """
212
+ context_to_use = self.context
213
+ if context_to_use and file_path.name.startswith("ext_"):
214
+ version = self._extract_version(file_path.name)
215
+ if version and version.startswith("ext_"):
216
+ min_extension_version_parts = 3
217
+ parts = version.split("_", 2)
218
+ if len(parts) >= min_extension_version_parts:
219
+ ext_name = parts[1]
220
+ if ext_name in self.extension_configs:
221
+ context_to_use = MigrationContext(
222
+ dialect=self.context.dialect if self.context else None,
223
+ config=self.context.config if self.context else None,
224
+ driver=self.context.driver if self.context else None,
225
+ metadata=self.context.metadata.copy() if self.context and self.context.metadata else {},
226
+ extension_config=self.extension_configs[ext_name],
227
+ )
228
+
229
+ for ext_name, ext_path in self.extension_migrations.items():
230
+ if file_path.parent == ext_path:
231
+ if ext_name in self.extension_configs and self.context:
232
+ context_to_use = MigrationContext(
233
+ config=self.context.config,
234
+ dialect=self.context.dialect,
235
+ driver=self.context.driver,
236
+ metadata=self.context.metadata.copy() if self.context.metadata else {},
237
+ extension_config=self.extension_configs[ext_name],
238
+ )
239
+ break
240
+
241
+ return context_to_use
242
+
243
+ def should_use_transaction(self, migration: "dict[str, Any]", config: Any) -> bool:
244
+ """Determine if migration should run in a transaction.
245
+
246
+ Args:
247
+ migration: Migration metadata dictionary.
248
+ config: The database configuration instance.
249
+
250
+ Returns:
251
+ True if migration should be wrapped in a transaction.
252
+ """
253
+ if not config.supports_transactional_ddl:
254
+ return False
255
+
256
+ if migration.get("transactional") is not None:
257
+ return bool(migration["transactional"])
258
+
259
+ migration_config = getattr(config, "migration_config", {}) or {}
260
+ return bool(migration_config.get("transactional", True))
261
+
262
+
263
+ class SyncMigrationRunner(BaseMigrationRunner):
264
+ """Synchronous migration runner with pure sync methods."""
265
+
266
+ def load_migration(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]":
36
267
  """Load a migration file and extract its components.
37
268
 
38
269
  Args:
39
270
  file_path: Path to the migration file.
271
+ version: Optional pre-extracted version (preserves prefixes like ext_adk_0001).
40
272
 
41
273
  Returns:
42
274
  Dictionary containing migration metadata and queries.
43
275
  """
44
- return self._load_migration_metadata(file_path)
276
+ metadata = self._load_migration_metadata_common(file_path, version)
277
+ context_to_use = self._get_context_for_migration(file_path)
278
+
279
+ loader = get_migration_loader(file_path, self.migrations_path, self.project_root, context_to_use, self.loader)
280
+ loader.validate_migration_file(file_path)
281
+
282
+ has_upgrade, has_downgrade = True, False
283
+
284
+ if file_path.suffix == ".sql":
285
+ version = metadata["version"]
286
+ up_query, down_query = f"migrate-{version}-up", f"migrate-{version}-down"
287
+ has_upgrade, has_downgrade = self.loader.has_query(up_query), self.loader.has_query(down_query)
288
+ else:
289
+ try:
290
+ has_downgrade = bool(self._get_migration_sql_sync({"loader": loader, "file_path": file_path}, "down"))
291
+ except Exception:
292
+ has_downgrade = False
293
+
294
+ metadata.update({"has_upgrade": has_upgrade, "has_downgrade": has_downgrade, "loader": loader})
295
+ return metadata
45
296
 
46
297
  def execute_upgrade(
47
- self, driver: "SyncDriverAdapterBase", migration: "dict[str, Any]"
48
- ) -> "tuple[Optional[str], int]":
298
+ self,
299
+ driver: "SyncDriverAdapterBase",
300
+ migration: "dict[str, Any]",
301
+ *,
302
+ use_transaction: "bool | None" = None,
303
+ on_success: "Callable[[int], None] | None" = None,
304
+ ) -> "tuple[str | None, int]":
49
305
  """Execute an upgrade migration.
50
306
 
51
307
  Args:
52
- driver: The database driver to use.
308
+ driver: The sync database driver to use.
53
309
  migration: Migration metadata dictionary.
310
+ use_transaction: Override transaction behavior. If None, uses should_use_transaction logic.
311
+ on_success: Callback invoked with execution_time_ms before commit (for version tracking).
54
312
 
55
313
  Returns:
56
314
  Tuple of (sql_content, execution_time_ms).
57
315
  """
58
- upgrade_sql_list = self._get_migration_sql(migration, "up")
316
+ upgrade_sql_list = self._get_migration_sql_sync(migration, "up")
59
317
  if upgrade_sql_list is None:
60
318
  return None, 0
61
319
 
320
+ if use_transaction is None:
321
+ config = self.context.config if self.context else None
322
+ use_transaction = self.should_use_transaction(migration, config) if config else False
323
+
62
324
  start_time = time.time()
63
325
 
64
- for sql_statement in upgrade_sql_list:
65
- if sql_statement.strip():
66
- driver.execute_script(sql_statement)
67
- execution_time = int((time.time() - start_time) * 1000)
326
+ if use_transaction:
327
+ try:
328
+ driver.begin()
329
+ for sql_statement in upgrade_sql_list:
330
+ if sql_statement.strip():
331
+ driver.execute_script(sql_statement)
332
+ execution_time = int((time.time() - start_time) * 1000)
333
+ if on_success:
334
+ on_success(execution_time)
335
+ driver.commit()
336
+ except Exception:
337
+ driver.rollback()
338
+ raise
339
+ else:
340
+ for sql_statement in upgrade_sql_list:
341
+ if sql_statement.strip():
342
+ driver.execute_script(sql_statement)
343
+ execution_time = int((time.time() - start_time) * 1000)
344
+ if on_success:
345
+ on_success(execution_time)
346
+
68
347
  return None, execution_time
69
348
 
70
349
  def execute_downgrade(
71
- self, driver: "SyncDriverAdapterBase", migration: "dict[str, Any]"
72
- ) -> "tuple[Optional[str], int]":
350
+ self,
351
+ driver: "SyncDriverAdapterBase",
352
+ migration: "dict[str, Any]",
353
+ *,
354
+ use_transaction: "bool | None" = None,
355
+ on_success: "Callable[[int], None] | None" = None,
356
+ ) -> "tuple[str | None, int]":
73
357
  """Execute a downgrade migration.
74
358
 
75
359
  Args:
76
- driver: The database driver to use.
360
+ driver: The sync database driver to use.
77
361
  migration: Migration metadata dictionary.
362
+ use_transaction: Override transaction behavior. If None, uses should_use_transaction logic.
363
+ on_success: Callback invoked with execution_time_ms before commit (for version tracking).
78
364
 
79
365
  Returns:
80
366
  Tuple of (sql_content, execution_time_ms).
81
367
  """
82
- downgrade_sql_list = self._get_migration_sql(migration, "down")
368
+ downgrade_sql_list = self._get_migration_sql_sync(migration, "down")
83
369
  if downgrade_sql_list is None:
84
370
  return None, 0
85
371
 
372
+ if use_transaction is None:
373
+ config = self.context.config if self.context else None
374
+ use_transaction = self.should_use_transaction(migration, config) if config else False
375
+
86
376
  start_time = time.time()
87
377
 
88
- for sql_statement in downgrade_sql_list:
89
- if sql_statement.strip():
90
- driver.execute_script(sql_statement)
91
- execution_time = int((time.time() - start_time) * 1000)
378
+ if use_transaction:
379
+ try:
380
+ driver.begin()
381
+ for sql_statement in downgrade_sql_list:
382
+ if sql_statement.strip():
383
+ driver.execute_script(sql_statement)
384
+ execution_time = int((time.time() - start_time) * 1000)
385
+ if on_success:
386
+ on_success(execution_time)
387
+ driver.commit()
388
+ except Exception:
389
+ driver.rollback()
390
+ raise
391
+ else:
392
+ for sql_statement in downgrade_sql_list:
393
+ if sql_statement.strip():
394
+ driver.execute_script(sql_statement)
395
+ execution_time = int((time.time() - start_time) * 1000)
396
+ if on_success:
397
+ on_success(execution_time)
398
+
92
399
  return None, execution_time
93
400
 
401
+ def _get_migration_sql_sync(self, migration: "dict[str, Any]", direction: str) -> "list[str] | None":
402
+ """Get migration SQL for given direction (sync version).
403
+
404
+ Args:
405
+ migration: Migration metadata.
406
+ direction: Either 'up' or 'down'.
407
+
408
+ Returns:
409
+ SQL statements for the migration.
410
+ """
411
+ # If this is being called during migration loading (no has_*grade field yet),
412
+ # don't raise/warn - just proceed to check if the method exists
413
+ if f"has_{direction}grade" in migration and not migration.get(f"has_{direction}grade"):
414
+ if direction == "down":
415
+ logger.warning("Migration %s has no downgrade query", migration.get("version"))
416
+ return None
417
+ msg = f"Migration {migration.get('version')} has no upgrade query"
418
+ raise ValueError(msg)
419
+
420
+ file_path, loader = migration["file_path"], migration["loader"]
421
+
422
+ try:
423
+ method = loader.get_up_sql if direction == "up" else loader.get_down_sql
424
+ sql_statements = (
425
+ await_(method, raise_sync_error=False)(file_path)
426
+ if inspect.iscoroutinefunction(method)
427
+ else method(file_path)
428
+ )
429
+
430
+ except Exception as e:
431
+ if direction == "down":
432
+ logger.warning("Failed to load downgrade for migration %s: %s", migration.get("version"), e)
433
+ return None
434
+ msg = f"Failed to load upgrade for migration {migration.get('version')}: {e}"
435
+ raise ValueError(msg) from e
436
+ else:
437
+ if sql_statements:
438
+ return cast("list[str]", sql_statements)
439
+ return None
440
+
94
441
  def load_all_migrations(self) -> "dict[str, SQL]":
95
442
  """Load all migrations into a single namespace for bulk operations.
96
443
 
@@ -106,7 +453,9 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
106
453
  for query_name in self.loader.list_queries():
107
454
  all_queries[query_name] = self.loader.get_sql(query_name)
108
455
  else:
109
- loader = get_migration_loader(file_path, self.migrations_path, self.project_root)
456
+ loader = get_migration_loader(
457
+ file_path, self.migrations_path, self.project_root, self.context, self.loader
458
+ )
110
459
 
111
460
  try:
112
461
  up_sql = await_(loader.get_up_sql, raise_sync_error=False)(file_path)
@@ -123,109 +472,65 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
123
472
  return all_queries
124
473
 
125
474
 
126
- class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
127
- """Asynchronous migration executor."""
475
+ class AsyncMigrationRunner(BaseMigrationRunner):
476
+ """Asynchronous migration runner with pure async methods."""
128
477
 
129
- async def get_migration_files(self) -> "list[tuple[str, Path]]":
478
+ async def get_migration_files(self) -> "list[tuple[str, Path]]": # type: ignore[override]
130
479
  """Get all migration files sorted by version.
131
480
 
132
481
  Returns:
133
- List of tuples containing (version, file_path).
482
+ List of (version, path) tuples sorted by version.
134
483
  """
135
484
  return self._get_migration_files_sync()
136
485
 
137
- async def load_migration(self, file_path: Path) -> "dict[str, Any]":
486
+ async def load_migration(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]":
138
487
  """Load a migration file and extract its components.
139
488
 
140
489
  Args:
141
490
  file_path: Path to the migration file.
491
+ version: Optional pre-extracted version (preserves prefixes like ext_adk_0001).
142
492
 
143
493
  Returns:
144
- Dictionary containing migration metadata.
494
+ Dictionary containing migration metadata and queries.
145
495
  """
146
- return await self._load_migration_metadata_async(file_path)
147
-
148
- async def _load_migration_metadata_async(self, file_path: Path) -> "dict[str, Any]":
149
- """Load migration metadata from file (async version).
150
-
151
- Args:
152
- file_path: Path to the migration file.
496
+ metadata = self._load_migration_metadata_common(file_path, version)
497
+ context_to_use = self._get_context_for_migration(file_path)
153
498
 
154
- Returns:
155
- Migration metadata dictionary.
156
- """
157
- loader = get_migration_loader(file_path, self.migrations_path, self.project_root)
499
+ loader = get_migration_loader(file_path, self.migrations_path, self.project_root, context_to_use, self.loader)
158
500
  loader.validate_migration_file(file_path)
159
- content = file_path.read_text(encoding="utf-8")
160
- checksum = self._calculate_checksum(content)
161
- version = self._extract_version(file_path.name)
162
- description = file_path.stem.split("_", 1)[1] if "_" in file_path.stem else ""
163
501
 
164
502
  has_upgrade, has_downgrade = True, False
165
503
 
166
504
  if file_path.suffix == ".sql":
505
+ version = metadata["version"]
167
506
  up_query, down_query = f"migrate-{version}-up", f"migrate-{version}-down"
168
- self.loader.clear_cache()
169
- self.loader.load_sql(file_path)
170
507
  has_upgrade, has_downgrade = self.loader.has_query(up_query), self.loader.has_query(down_query)
171
508
  else:
172
509
  try:
173
- has_downgrade = bool(await loader.get_down_sql(file_path))
510
+ has_downgrade = bool(
511
+ await self._get_migration_sql_async({"loader": loader, "file_path": file_path}, "down")
512
+ )
174
513
  except Exception:
175
514
  has_downgrade = False
176
515
 
177
- return {
178
- "version": version,
179
- "description": description,
180
- "file_path": file_path,
181
- "checksum": checksum,
182
- "has_upgrade": has_upgrade,
183
- "has_downgrade": has_downgrade,
184
- "loader": loader,
185
- }
186
-
187
- async def _get_migration_sql_async(self, migration: "dict[str, Any]", direction: str) -> "Optional[list[str]]":
188
- """Get migration SQL for given direction (async version).
189
-
190
- Args:
191
- migration: Migration metadata.
192
- direction: Either 'up' or 'down'.
193
-
194
- Returns:
195
- SQL statements for the migration.
196
- """
197
- if not migration.get(f"has_{direction}grade"):
198
- if direction == "down":
199
- logger.warning("Migration %s has no downgrade query", migration["version"])
200
- return None
201
- msg = f"Migration {migration['version']} has no upgrade query"
202
- raise ValueError(msg)
203
-
204
- file_path, loader = migration["file_path"], migration["loader"]
205
-
206
- try:
207
- method = loader.get_up_sql if direction == "up" else loader.get_down_sql
208
- sql_statements = await method(file_path)
209
-
210
- except Exception as e:
211
- if direction == "down":
212
- logger.warning("Failed to load downgrade for migration %s: %s", migration["version"], e)
213
- return None
214
- msg = f"Failed to load upgrade for migration {migration['version']}: {e}"
215
- raise ValueError(msg) from e
216
- else:
217
- if sql_statements:
218
- return cast("list[str]", sql_statements)
219
- return None
516
+ metadata.update({"has_upgrade": has_upgrade, "has_downgrade": has_downgrade, "loader": loader})
517
+ return metadata
220
518
 
221
519
  async def execute_upgrade(
222
- self, driver: "AsyncDriverAdapterBase", migration: "dict[str, Any]"
223
- ) -> "tuple[Optional[str], int]":
520
+ self,
521
+ driver: "AsyncDriverAdapterBase",
522
+ migration: "dict[str, Any]",
523
+ *,
524
+ use_transaction: "bool | None" = None,
525
+ on_success: "Callable[[int], Awaitable[None]] | None" = None,
526
+ ) -> "tuple[str | None, int]":
224
527
  """Execute an upgrade migration.
225
528
 
226
529
  Args:
227
530
  driver: The async database driver to use.
228
531
  migration: Migration metadata dictionary.
532
+ use_transaction: Override transaction behavior. If None, uses should_use_transaction logic.
533
+ on_success: Async callback invoked with execution_time_ms before commit (for version tracking).
229
534
 
230
535
  Returns:
231
536
  Tuple of (sql_content, execution_time_ms).
@@ -234,22 +539,50 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
234
539
  if upgrade_sql_list is None:
235
540
  return None, 0
236
541
 
542
+ if use_transaction is None:
543
+ config = self.context.config if self.context else None
544
+ use_transaction = self.should_use_transaction(migration, config) if config else False
545
+
237
546
  start_time = time.time()
238
547
 
239
- for sql_statement in upgrade_sql_list:
240
- if sql_statement.strip():
241
- await driver.execute_script(sql_statement)
242
- execution_time = int((time.time() - start_time) * 1000)
548
+ if use_transaction:
549
+ try:
550
+ await driver.begin()
551
+ for sql_statement in upgrade_sql_list:
552
+ if sql_statement.strip():
553
+ await driver.execute_script(sql_statement)
554
+ execution_time = int((time.time() - start_time) * 1000)
555
+ if on_success:
556
+ await on_success(execution_time)
557
+ await driver.commit()
558
+ except Exception:
559
+ await driver.rollback()
560
+ raise
561
+ else:
562
+ for sql_statement in upgrade_sql_list:
563
+ if sql_statement.strip():
564
+ await driver.execute_script(sql_statement)
565
+ execution_time = int((time.time() - start_time) * 1000)
566
+ if on_success:
567
+ await on_success(execution_time)
568
+
243
569
  return None, execution_time
244
570
 
245
571
  async def execute_downgrade(
246
- self, driver: "AsyncDriverAdapterBase", migration: "dict[str, Any]"
247
- ) -> "tuple[Optional[str], int]":
572
+ self,
573
+ driver: "AsyncDriverAdapterBase",
574
+ migration: "dict[str, Any]",
575
+ *,
576
+ use_transaction: "bool | None" = None,
577
+ on_success: "Callable[[int], Awaitable[None]] | None" = None,
578
+ ) -> "tuple[str | None, int]":
248
579
  """Execute a downgrade migration.
249
580
 
250
581
  Args:
251
582
  driver: The async database driver to use.
252
583
  migration: Migration metadata dictionary.
584
+ use_transaction: Override transaction behavior. If None, uses should_use_transaction logic.
585
+ on_success: Async callback invoked with execution_time_ms before commit (for version tracking).
253
586
 
254
587
  Returns:
255
588
  Tuple of (sql_content, execution_time_ms).
@@ -258,14 +591,71 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
258
591
  if downgrade_sql_list is None:
259
592
  return None, 0
260
593
 
594
+ if use_transaction is None:
595
+ config = self.context.config if self.context else None
596
+ use_transaction = self.should_use_transaction(migration, config) if config else False
597
+
261
598
  start_time = time.time()
262
599
 
263
- for sql_statement in downgrade_sql_list:
264
- if sql_statement.strip():
265
- await driver.execute_script(sql_statement)
266
- execution_time = int((time.time() - start_time) * 1000)
600
+ if use_transaction:
601
+ try:
602
+ await driver.begin()
603
+ for sql_statement in downgrade_sql_list:
604
+ if sql_statement.strip():
605
+ await driver.execute_script(sql_statement)
606
+ execution_time = int((time.time() - start_time) * 1000)
607
+ if on_success:
608
+ await on_success(execution_time)
609
+ await driver.commit()
610
+ except Exception:
611
+ await driver.rollback()
612
+ raise
613
+ else:
614
+ for sql_statement in downgrade_sql_list:
615
+ if sql_statement.strip():
616
+ await driver.execute_script(sql_statement)
617
+ execution_time = int((time.time() - start_time) * 1000)
618
+ if on_success:
619
+ await on_success(execution_time)
620
+
267
621
  return None, execution_time
268
622
 
623
+ async def _get_migration_sql_async(self, migration: "dict[str, Any]", direction: str) -> "list[str] | None":
624
+ """Get migration SQL for given direction (async version).
625
+
626
+ Args:
627
+ migration: Migration metadata.
628
+ direction: Either 'up' or 'down'.
629
+
630
+ Returns:
631
+ SQL statements for the migration.
632
+ """
633
+ # If this is being called during migration loading (no has_*grade field yet),
634
+ # don't raise/warn - just proceed to check if the method exists
635
+ if f"has_{direction}grade" in migration and not migration.get(f"has_{direction}grade"):
636
+ if direction == "down":
637
+ logger.warning("Migration %s has no downgrade query", migration.get("version"))
638
+ return None
639
+ msg = f"Migration {migration.get('version')} has no upgrade query"
640
+ raise ValueError(msg)
641
+
642
+ file_path, loader = migration["file_path"], migration["loader"]
643
+
644
+ try:
645
+ method = loader.get_up_sql if direction == "up" else loader.get_down_sql
646
+ sql_statements = await method(file_path)
647
+
648
+ except Exception as e:
649
+ if direction == "down":
650
+ logger.warning("Failed to load downgrade for migration %s: %s", migration.get("version"), e)
651
+ return None
652
+ msg = f"Failed to load upgrade for migration {migration.get('version')}: {e}"
653
+ raise ValueError(msg) from e
654
+ else:
655
+ if sql_statements:
656
+ return cast("list[str]", sql_statements)
657
+ return None
658
+
269
659
  async def load_all_migrations(self) -> "dict[str, SQL]":
270
660
  """Load all migrations into a single namespace for bulk operations.
271
661
 
@@ -277,11 +667,13 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
277
667
 
278
668
  for version, file_path in migrations:
279
669
  if file_path.suffix == ".sql":
280
- self.loader.load_sql(file_path)
670
+ await async_(self.loader.load_sql)(file_path)
281
671
  for query_name in self.loader.list_queries():
282
672
  all_queries[query_name] = self.loader.get_sql(query_name)
283
673
  else:
284
- loader = get_migration_loader(file_path, self.migrations_path, self.project_root)
674
+ loader = get_migration_loader(
675
+ file_path, self.migrations_path, self.project_root, self.context, self.loader
676
+ )
285
677
 
286
678
  try:
287
679
  up_sql = await loader.get_up_sql(file_path)
@@ -296,3 +688,47 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
296
688
  logger.debug("Failed to load Python migration %s: %s", file_path, e)
297
689
 
298
690
  return all_queries
691
+
692
+
693
@overload
def create_migration_runner(
    migrations_path: Path,
    extension_migrations: "dict[str, Path]",
    context: "MigrationContext | None",
    extension_configs: "dict[str, Any]",
    is_async: "Literal[False]" = False,
) -> SyncMigrationRunner: ...


@overload
def create_migration_runner(
    migrations_path: Path,
    extension_migrations: "dict[str, Path]",
    context: "MigrationContext | None",
    extension_configs: "dict[str, Any]",
    is_async: "Literal[True]",
) -> AsyncMigrationRunner: ...


def create_migration_runner(
    migrations_path: Path,
    extension_migrations: "dict[str, Path]",
    context: "MigrationContext | None",
    extension_configs: "dict[str, Any]",
    is_async: bool = False,
) -> "SyncMigrationRunner | AsyncMigrationRunner":
    """Build the migration runner that matches the requested execution mode.

    Both runner classes receive the same four arguments, passed positionally.

    Args:
        migrations_path: Path to migrations directory.
        extension_migrations: Extension migration paths.
        context: Migration context.
        extension_configs: Extension configurations.
        is_async: If truthy, build an AsyncMigrationRunner; otherwise build a
            SyncMigrationRunner.

    Returns:
        A newly constructed SyncMigrationRunner or AsyncMigrationRunner.
    """
    runner_args = (migrations_path, extension_migrations, context, extension_configs)
    if not is_async:
        return SyncMigrationRunner(*runner_args)
    return AsyncMigrationRunner(*runner_args)