sqlspec 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (212) hide show
  1. sqlspec/__init__.py +7 -15
  2. sqlspec/_serialization.py +55 -25
  3. sqlspec/_typing.py +155 -52
  4. sqlspec/adapters/adbc/_types.py +1 -1
  5. sqlspec/adapters/adbc/adk/__init__.py +5 -0
  6. sqlspec/adapters/adbc/adk/store.py +880 -0
  7. sqlspec/adapters/adbc/config.py +62 -12
  8. sqlspec/adapters/adbc/data_dictionary.py +74 -2
  9. sqlspec/adapters/adbc/driver.py +226 -58
  10. sqlspec/adapters/adbc/litestar/__init__.py +5 -0
  11. sqlspec/adapters/adbc/litestar/store.py +504 -0
  12. sqlspec/adapters/adbc/type_converter.py +44 -50
  13. sqlspec/adapters/aiosqlite/_types.py +1 -1
  14. sqlspec/adapters/aiosqlite/adk/__init__.py +5 -0
  15. sqlspec/adapters/aiosqlite/adk/store.py +536 -0
  16. sqlspec/adapters/aiosqlite/config.py +86 -16
  17. sqlspec/adapters/aiosqlite/data_dictionary.py +34 -2
  18. sqlspec/adapters/aiosqlite/driver.py +127 -38
  19. sqlspec/adapters/aiosqlite/litestar/__init__.py +5 -0
  20. sqlspec/adapters/aiosqlite/litestar/store.py +281 -0
  21. sqlspec/adapters/aiosqlite/pool.py +7 -7
  22. sqlspec/adapters/asyncmy/__init__.py +7 -1
  23. sqlspec/adapters/asyncmy/_types.py +1 -1
  24. sqlspec/adapters/asyncmy/adk/__init__.py +5 -0
  25. sqlspec/adapters/asyncmy/adk/store.py +503 -0
  26. sqlspec/adapters/asyncmy/config.py +59 -17
  27. sqlspec/adapters/asyncmy/data_dictionary.py +41 -2
  28. sqlspec/adapters/asyncmy/driver.py +293 -62
  29. sqlspec/adapters/asyncmy/litestar/__init__.py +5 -0
  30. sqlspec/adapters/asyncmy/litestar/store.py +296 -0
  31. sqlspec/adapters/asyncpg/__init__.py +2 -1
  32. sqlspec/adapters/asyncpg/_type_handlers.py +71 -0
  33. sqlspec/adapters/asyncpg/_types.py +11 -7
  34. sqlspec/adapters/asyncpg/adk/__init__.py +5 -0
  35. sqlspec/adapters/asyncpg/adk/store.py +460 -0
  36. sqlspec/adapters/asyncpg/config.py +57 -36
  37. sqlspec/adapters/asyncpg/data_dictionary.py +48 -2
  38. sqlspec/adapters/asyncpg/driver.py +153 -23
  39. sqlspec/adapters/asyncpg/litestar/__init__.py +5 -0
  40. sqlspec/adapters/asyncpg/litestar/store.py +253 -0
  41. sqlspec/adapters/bigquery/_types.py +1 -1
  42. sqlspec/adapters/bigquery/adk/__init__.py +5 -0
  43. sqlspec/adapters/bigquery/adk/store.py +585 -0
  44. sqlspec/adapters/bigquery/config.py +36 -11
  45. sqlspec/adapters/bigquery/data_dictionary.py +42 -2
  46. sqlspec/adapters/bigquery/driver.py +489 -144
  47. sqlspec/adapters/bigquery/litestar/__init__.py +5 -0
  48. sqlspec/adapters/bigquery/litestar/store.py +327 -0
  49. sqlspec/adapters/bigquery/type_converter.py +55 -23
  50. sqlspec/adapters/duckdb/_types.py +2 -2
  51. sqlspec/adapters/duckdb/adk/__init__.py +14 -0
  52. sqlspec/adapters/duckdb/adk/store.py +563 -0
  53. sqlspec/adapters/duckdb/config.py +79 -21
  54. sqlspec/adapters/duckdb/data_dictionary.py +41 -2
  55. sqlspec/adapters/duckdb/driver.py +225 -44
  56. sqlspec/adapters/duckdb/litestar/__init__.py +5 -0
  57. sqlspec/adapters/duckdb/litestar/store.py +332 -0
  58. sqlspec/adapters/duckdb/pool.py +5 -5
  59. sqlspec/adapters/duckdb/type_converter.py +51 -21
  60. sqlspec/adapters/oracledb/_numpy_handlers.py +133 -0
  61. sqlspec/adapters/oracledb/_types.py +20 -2
  62. sqlspec/adapters/oracledb/adk/__init__.py +5 -0
  63. sqlspec/adapters/oracledb/adk/store.py +1628 -0
  64. sqlspec/adapters/oracledb/config.py +120 -36
  65. sqlspec/adapters/oracledb/data_dictionary.py +87 -20
  66. sqlspec/adapters/oracledb/driver.py +475 -86
  67. sqlspec/adapters/oracledb/litestar/__init__.py +5 -0
  68. sqlspec/adapters/oracledb/litestar/store.py +765 -0
  69. sqlspec/adapters/oracledb/migrations.py +316 -25
  70. sqlspec/adapters/oracledb/type_converter.py +91 -16
  71. sqlspec/adapters/psqlpy/_type_handlers.py +44 -0
  72. sqlspec/adapters/psqlpy/_types.py +2 -1
  73. sqlspec/adapters/psqlpy/adk/__init__.py +5 -0
  74. sqlspec/adapters/psqlpy/adk/store.py +483 -0
  75. sqlspec/adapters/psqlpy/config.py +45 -19
  76. sqlspec/adapters/psqlpy/data_dictionary.py +48 -2
  77. sqlspec/adapters/psqlpy/driver.py +108 -41
  78. sqlspec/adapters/psqlpy/litestar/__init__.py +5 -0
  79. sqlspec/adapters/psqlpy/litestar/store.py +272 -0
  80. sqlspec/adapters/psqlpy/type_converter.py +40 -11
  81. sqlspec/adapters/psycopg/_type_handlers.py +80 -0
  82. sqlspec/adapters/psycopg/_types.py +2 -1
  83. sqlspec/adapters/psycopg/adk/__init__.py +5 -0
  84. sqlspec/adapters/psycopg/adk/store.py +962 -0
  85. sqlspec/adapters/psycopg/config.py +65 -37
  86. sqlspec/adapters/psycopg/data_dictionary.py +91 -3
  87. sqlspec/adapters/psycopg/driver.py +200 -78
  88. sqlspec/adapters/psycopg/litestar/__init__.py +5 -0
  89. sqlspec/adapters/psycopg/litestar/store.py +554 -0
  90. sqlspec/adapters/sqlite/__init__.py +2 -1
  91. sqlspec/adapters/sqlite/_type_handlers.py +86 -0
  92. sqlspec/adapters/sqlite/_types.py +1 -1
  93. sqlspec/adapters/sqlite/adk/__init__.py +5 -0
  94. sqlspec/adapters/sqlite/adk/store.py +582 -0
  95. sqlspec/adapters/sqlite/config.py +85 -16
  96. sqlspec/adapters/sqlite/data_dictionary.py +34 -2
  97. sqlspec/adapters/sqlite/driver.py +120 -52
  98. sqlspec/adapters/sqlite/litestar/__init__.py +5 -0
  99. sqlspec/adapters/sqlite/litestar/store.py +318 -0
  100. sqlspec/adapters/sqlite/pool.py +5 -5
  101. sqlspec/base.py +45 -26
  102. sqlspec/builder/__init__.py +73 -4
  103. sqlspec/builder/_base.py +91 -58
  104. sqlspec/builder/_column.py +5 -5
  105. sqlspec/builder/_ddl.py +98 -89
  106. sqlspec/builder/_delete.py +5 -4
  107. sqlspec/builder/_dml.py +388 -0
  108. sqlspec/{_sql.py → builder/_factory.py} +41 -44
  109. sqlspec/builder/_insert.py +5 -82
  110. sqlspec/builder/{mixins/_join_operations.py → _join.py} +145 -143
  111. sqlspec/builder/_merge.py +446 -11
  112. sqlspec/builder/_parsing_utils.py +9 -11
  113. sqlspec/builder/_select.py +1313 -25
  114. sqlspec/builder/_update.py +11 -42
  115. sqlspec/cli.py +76 -69
  116. sqlspec/config.py +331 -62
  117. sqlspec/core/__init__.py +5 -4
  118. sqlspec/core/cache.py +18 -18
  119. sqlspec/core/compiler.py +6 -8
  120. sqlspec/core/filters.py +55 -47
  121. sqlspec/core/hashing.py +9 -9
  122. sqlspec/core/parameters.py +76 -45
  123. sqlspec/core/result.py +234 -47
  124. sqlspec/core/splitter.py +16 -17
  125. sqlspec/core/statement.py +32 -31
  126. sqlspec/core/type_conversion.py +3 -2
  127. sqlspec/driver/__init__.py +1 -3
  128. sqlspec/driver/_async.py +183 -160
  129. sqlspec/driver/_common.py +197 -109
  130. sqlspec/driver/_sync.py +189 -161
  131. sqlspec/driver/mixins/_result_tools.py +20 -236
  132. sqlspec/driver/mixins/_sql_translator.py +4 -4
  133. sqlspec/exceptions.py +70 -7
  134. sqlspec/extensions/adk/__init__.py +53 -0
  135. sqlspec/extensions/adk/_types.py +51 -0
  136. sqlspec/extensions/adk/converters.py +172 -0
  137. sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +144 -0
  138. sqlspec/extensions/adk/migrations/__init__.py +0 -0
  139. sqlspec/extensions/adk/service.py +181 -0
  140. sqlspec/extensions/adk/store.py +536 -0
  141. sqlspec/extensions/aiosql/adapter.py +69 -61
  142. sqlspec/extensions/fastapi/__init__.py +21 -0
  143. sqlspec/extensions/fastapi/extension.py +331 -0
  144. sqlspec/extensions/fastapi/providers.py +543 -0
  145. sqlspec/extensions/flask/__init__.py +36 -0
  146. sqlspec/extensions/flask/_state.py +71 -0
  147. sqlspec/extensions/flask/_utils.py +40 -0
  148. sqlspec/extensions/flask/extension.py +389 -0
  149. sqlspec/extensions/litestar/__init__.py +21 -4
  150. sqlspec/extensions/litestar/cli.py +54 -10
  151. sqlspec/extensions/litestar/config.py +56 -266
  152. sqlspec/extensions/litestar/handlers.py +46 -17
  153. sqlspec/extensions/litestar/migrations/0001_create_session_table.py +137 -0
  154. sqlspec/extensions/litestar/migrations/__init__.py +3 -0
  155. sqlspec/extensions/litestar/plugin.py +349 -224
  156. sqlspec/extensions/litestar/providers.py +25 -25
  157. sqlspec/extensions/litestar/store.py +265 -0
  158. sqlspec/extensions/starlette/__init__.py +10 -0
  159. sqlspec/extensions/starlette/_state.py +25 -0
  160. sqlspec/extensions/starlette/_utils.py +52 -0
  161. sqlspec/extensions/starlette/extension.py +254 -0
  162. sqlspec/extensions/starlette/middleware.py +154 -0
  163. sqlspec/loader.py +30 -49
  164. sqlspec/migrations/base.py +200 -76
  165. sqlspec/migrations/commands.py +591 -62
  166. sqlspec/migrations/context.py +6 -9
  167. sqlspec/migrations/fix.py +199 -0
  168. sqlspec/migrations/loaders.py +47 -19
  169. sqlspec/migrations/runner.py +241 -75
  170. sqlspec/migrations/tracker.py +237 -21
  171. sqlspec/migrations/utils.py +51 -3
  172. sqlspec/migrations/validation.py +177 -0
  173. sqlspec/protocols.py +106 -36
  174. sqlspec/storage/_utils.py +85 -0
  175. sqlspec/storage/backends/fsspec.py +133 -107
  176. sqlspec/storage/backends/local.py +78 -51
  177. sqlspec/storage/backends/obstore.py +276 -168
  178. sqlspec/storage/registry.py +75 -39
  179. sqlspec/typing.py +30 -84
  180. sqlspec/utils/__init__.py +25 -4
  181. sqlspec/utils/arrow_helpers.py +81 -0
  182. sqlspec/utils/config_resolver.py +6 -6
  183. sqlspec/utils/correlation.py +4 -5
  184. sqlspec/utils/data_transformation.py +3 -2
  185. sqlspec/utils/deprecation.py +9 -8
  186. sqlspec/utils/fixtures.py +4 -4
  187. sqlspec/utils/logging.py +46 -6
  188. sqlspec/utils/module_loader.py +205 -5
  189. sqlspec/utils/portal.py +311 -0
  190. sqlspec/utils/schema.py +288 -0
  191. sqlspec/utils/serializers.py +113 -4
  192. sqlspec/utils/sync_tools.py +36 -22
  193. sqlspec/utils/text.py +1 -2
  194. sqlspec/utils/type_guards.py +136 -20
  195. sqlspec/utils/version.py +433 -0
  196. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/METADATA +41 -22
  197. sqlspec-0.28.0.dist-info/RECORD +221 -0
  198. sqlspec/builder/mixins/__init__.py +0 -55
  199. sqlspec/builder/mixins/_cte_and_set_ops.py +0 -253
  200. sqlspec/builder/mixins/_delete_operations.py +0 -50
  201. sqlspec/builder/mixins/_insert_operations.py +0 -282
  202. sqlspec/builder/mixins/_merge_operations.py +0 -698
  203. sqlspec/builder/mixins/_order_limit_operations.py +0 -145
  204. sqlspec/builder/mixins/_pivot_operations.py +0 -157
  205. sqlspec/builder/mixins/_select_operations.py +0 -930
  206. sqlspec/builder/mixins/_update_operations.py +0 -199
  207. sqlspec/builder/mixins/_where_clause.py +0 -1298
  208. sqlspec-0.26.0.dist-info/RECORD +0 -157
  209. sqlspec-0.26.0.dist-info/licenses/NOTICE +0 -29
  210. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/WHEEL +0 -0
  211. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/entry_points.txt +0 -0
  212. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/licenses/LICENSE +0 -0
@@ -4,11 +4,11 @@ This module provides separate sync and async migration runners with clean separa
4
4
  of concerns and proper type safety.
5
5
  """
6
6
 
7
- import operator
7
+ import inspect
8
8
  import time
9
9
  from abc import ABC, abstractmethod
10
10
  from pathlib import Path
11
- from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast, overload
11
+ from typing import TYPE_CHECKING, Any, Literal, Union, cast, overload
12
12
 
13
13
  from sqlspec.core.statement import SQL
14
14
  from sqlspec.migrations.context import MigrationContext
@@ -17,7 +17,7 @@ from sqlspec.utils.logging import get_logger
17
17
  from sqlspec.utils.sync_tools import async_, await_
18
18
 
19
19
  if TYPE_CHECKING:
20
- from collections.abc import Coroutine
20
+ from collections.abc import Awaitable, Callable, Coroutine
21
21
 
22
22
  from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
23
23
 
@@ -32,9 +32,9 @@ class BaseMigrationRunner(ABC):
32
32
  def __init__(
33
33
  self,
34
34
  migrations_path: Path,
35
- extension_migrations: "Optional[dict[str, Path]]" = None,
36
- context: "Optional[MigrationContext]" = None,
37
- extension_configs: "Optional[dict[str, dict[str, Any]]]" = None,
35
+ extension_migrations: "dict[str, Path] | None" = None,
36
+ context: "MigrationContext | None" = None,
37
+ extension_configs: "dict[str, dict[str, Any]] | None" = None,
38
38
  ) -> None:
39
39
  """Initialize the migration runner.
40
40
 
@@ -49,31 +49,46 @@ class BaseMigrationRunner(ABC):
49
49
  from sqlspec.loader import SQLFileLoader
50
50
 
51
51
  self.loader = SQLFileLoader()
52
- self.project_root: Optional[Path] = None
52
+ self.project_root: Path | None = None
53
53
  self.context = context
54
54
  self.extension_configs = extension_configs or {}
55
55
 
56
- def _extract_version(self, filename: str) -> "Optional[str]":
56
+ def _extract_version(self, filename: str) -> "str | None":
57
57
  """Extract version from filename.
58
58
 
59
+ Supports sequential (0001), timestamp (20251011120000), and extension-prefixed
60
+ (ext_litestar_0001) version formats.
61
+
59
62
  Args:
60
63
  filename: The migration filename.
61
64
 
62
65
  Returns:
63
66
  The extracted version string or None.
64
67
  """
65
- # Handle extension-prefixed versions (e.g., "ext_litestar_0001")
66
- if filename.startswith("ext_"):
67
- # This is already a prefixed version, return as-is
68
- return filename
68
+ extension_version_parts = 3
69
+ timestamp_min_length = 4
70
+
71
+ name_without_ext = filename.rsplit(".", 1)[0]
72
+
73
+ if name_without_ext.startswith("ext_"):
74
+ parts = name_without_ext.split("_", 3)
75
+ if len(parts) >= extension_version_parts:
76
+ return f"{parts[0]}_{parts[1]}_{parts[2]}"
77
+ return None
69
78
 
70
- # Regular version extraction
71
- parts = filename.split("_", 1)
72
- return parts[0].zfill(4) if parts and parts[0].isdigit() else None
79
+ parts = name_without_ext.split("_", 1)
80
+ if parts and parts[0].isdigit():
81
+ return parts[0] if len(parts[0]) > timestamp_min_length else parts[0].zfill(4)
82
+
83
+ return None
73
84
 
74
85
  def _calculate_checksum(self, content: str) -> str:
75
86
  """Calculate MD5 checksum of migration content.
76
87
 
88
+ Canonicalizes content by excluding query name headers that change during
89
+ fix command (migrate-{version}-up/down). This ensures checksums remain
90
+ stable when converting timestamp versions to sequential format.
91
+
77
92
  Args:
78
93
  content: The migration file content.
79
94
 
@@ -81,8 +96,11 @@ class BaseMigrationRunner(ABC):
81
96
  MD5 checksum hex string.
82
97
  """
83
98
  import hashlib
99
+ import re
100
+
101
+ canonical_content = re.sub(r"^--\s*name:\s*migrate-[^-]+-(?:up|down)\s*$", "", content, flags=re.MULTILINE)
84
102
 
85
- return hashlib.md5(content.encode()).hexdigest() # noqa: S324
103
+ return hashlib.md5(canonical_content.encode()).hexdigest() # noqa: S324
86
104
 
87
105
  @abstractmethod
88
106
  def load_migration(self, file_path: Path) -> Union["dict[str, Any]", "Coroutine[Any, Any, dict[str, Any]]"]:
@@ -129,7 +147,16 @@ class BaseMigrationRunner(ABC):
129
147
  prefixed_version = f"ext_{ext_name}_{version}"
130
148
  migrations.append((prefixed_version, file_path))
131
149
 
132
- return sorted(migrations, key=operator.itemgetter(0))
150
+ from sqlspec.utils.version import parse_version
151
+
152
+ def version_sort_key(migration_tuple: "tuple[str, Path]") -> "Any":
153
+ version_str = migration_tuple[0]
154
+ try:
155
+ return parse_version(version_str)
156
+ except ValueError:
157
+ return version_str
158
+
159
+ return sorted(migrations, key=version_sort_key)
133
160
 
134
161
  def get_migration_files(self) -> "list[tuple[str, Path]]":
135
162
  """Get all migration files sorted by version.
@@ -139,29 +166,41 @@ class BaseMigrationRunner(ABC):
139
166
  """
140
167
  return self._get_migration_files_sync()
141
168
 
142
- def _load_migration_metadata_common(self, file_path: Path) -> "dict[str, Any]":
169
+ def _load_migration_metadata_common(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]":
143
170
  """Load common migration metadata that doesn't require async operations.
144
171
 
145
172
  Args:
146
173
  file_path: Path to the migration file.
174
+ version: Optional pre-extracted version (preserves prefixes like ext_adk_0001).
147
175
 
148
176
  Returns:
149
177
  Partial migration metadata dictionary.
150
178
  """
179
+ import re
180
+
151
181
  content = file_path.read_text(encoding="utf-8")
152
182
  checksum = self._calculate_checksum(content)
153
- version = self._extract_version(file_path.name)
183
+ if version is None:
184
+ version = self._extract_version(file_path.name)
154
185
  description = file_path.stem.split("_", 1)[1] if "_" in file_path.stem else ""
155
186
 
187
+ transactional_match = re.search(
188
+ r"^--\s*transactional:\s*(true|false)\s*$", content, re.MULTILINE | re.IGNORECASE
189
+ )
190
+ transactional = None
191
+ if transactional_match:
192
+ transactional = transactional_match.group(1).lower() == "true"
193
+
156
194
  return {
157
195
  "version": version,
158
196
  "description": description,
159
197
  "file_path": file_path,
160
198
  "checksum": checksum,
161
199
  "content": content,
200
+ "transactional": transactional,
162
201
  }
163
202
 
164
- def _get_context_for_migration(self, file_path: Path) -> "Optional[MigrationContext]":
203
+ def _get_context_for_migration(self, file_path: Path) -> "MigrationContext | None":
165
204
  """Get the appropriate context for a migration file.
166
205
 
167
206
  Args:
@@ -201,24 +240,43 @@ class BaseMigrationRunner(ABC):
201
240
 
202
241
  return context_to_use
203
242
 
243
+ def should_use_transaction(self, migration: "dict[str, Any]", config: Any) -> bool:
244
+ """Determine if migration should run in a transaction.
245
+
246
+ Args:
247
+ migration: Migration metadata dictionary.
248
+ config: The database configuration instance.
249
+
250
+ Returns:
251
+ True if migration should be wrapped in a transaction.
252
+ """
253
+ if not config.supports_transactional_ddl:
254
+ return False
255
+
256
+ if migration.get("transactional") is not None:
257
+ return bool(migration["transactional"])
258
+
259
+ migration_config = getattr(config, "migration_config", {}) or {}
260
+ return bool(migration_config.get("transactional", True))
261
+
204
262
 
205
263
  class SyncMigrationRunner(BaseMigrationRunner):
206
264
  """Synchronous migration runner with pure sync methods."""
207
265
 
208
- def load_migration(self, file_path: Path) -> "dict[str, Any]":
266
+ def load_migration(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]":
209
267
  """Load a migration file and extract its components.
210
268
 
211
269
  Args:
212
270
  file_path: Path to the migration file.
271
+ version: Optional pre-extracted version (preserves prefixes like ext_adk_0001).
213
272
 
214
273
  Returns:
215
274
  Dictionary containing migration metadata and queries.
216
275
  """
217
- # Get common metadata
218
- metadata = self._load_migration_metadata_common(file_path)
276
+ metadata = self._load_migration_metadata_common(file_path, version)
219
277
  context_to_use = self._get_context_for_migration(file_path)
220
278
 
221
- loader = get_migration_loader(file_path, self.migrations_path, self.project_root, context_to_use)
279
+ loader = get_migration_loader(file_path, self.migrations_path, self.project_root, context_to_use, self.loader)
222
280
  loader.validate_migration_file(file_path)
223
281
 
224
282
  has_upgrade, has_downgrade = True, False
@@ -226,8 +284,6 @@ class SyncMigrationRunner(BaseMigrationRunner):
226
284
  if file_path.suffix == ".sql":
227
285
  version = metadata["version"]
228
286
  up_query, down_query = f"migrate-{version}-up", f"migrate-{version}-down"
229
- self.loader.clear_cache()
230
- self.loader.load_sql(file_path)
231
287
  has_upgrade, has_downgrade = self.loader.has_query(up_query), self.loader.has_query(down_query)
232
288
  else:
233
289
  try:
@@ -239,13 +295,20 @@ class SyncMigrationRunner(BaseMigrationRunner):
239
295
  return metadata
240
296
 
241
297
  def execute_upgrade(
242
- self, driver: "SyncDriverAdapterBase", migration: "dict[str, Any]"
243
- ) -> "tuple[Optional[str], int]":
298
+ self,
299
+ driver: "SyncDriverAdapterBase",
300
+ migration: "dict[str, Any]",
301
+ *,
302
+ use_transaction: "bool | None" = None,
303
+ on_success: "Callable[[int], None] | None" = None,
304
+ ) -> "tuple[str | None, int]":
244
305
  """Execute an upgrade migration.
245
306
 
246
307
  Args:
247
308
  driver: The sync database driver to use.
248
309
  migration: Migration metadata dictionary.
310
+ use_transaction: Override transaction behavior. If None, uses should_use_transaction logic.
311
+ on_success: Callback invoked with execution_time_ms before commit (for version tracking).
249
312
 
250
313
  Returns:
251
314
  Tuple of (sql_content, execution_time_ms).
@@ -254,22 +317,50 @@ class SyncMigrationRunner(BaseMigrationRunner):
254
317
  if upgrade_sql_list is None:
255
318
  return None, 0
256
319
 
320
+ if use_transaction is None:
321
+ config = self.context.config if self.context else None
322
+ use_transaction = self.should_use_transaction(migration, config) if config else False
323
+
257
324
  start_time = time.time()
258
325
 
259
- for sql_statement in upgrade_sql_list:
260
- if sql_statement.strip():
261
- driver.execute_script(sql_statement)
262
- execution_time = int((time.time() - start_time) * 1000)
326
+ if use_transaction:
327
+ try:
328
+ driver.begin()
329
+ for sql_statement in upgrade_sql_list:
330
+ if sql_statement.strip():
331
+ driver.execute_script(sql_statement)
332
+ execution_time = int((time.time() - start_time) * 1000)
333
+ if on_success:
334
+ on_success(execution_time)
335
+ driver.commit()
336
+ except Exception:
337
+ driver.rollback()
338
+ raise
339
+ else:
340
+ for sql_statement in upgrade_sql_list:
341
+ if sql_statement.strip():
342
+ driver.execute_script(sql_statement)
343
+ execution_time = int((time.time() - start_time) * 1000)
344
+ if on_success:
345
+ on_success(execution_time)
346
+
263
347
  return None, execution_time
264
348
 
265
349
  def execute_downgrade(
266
- self, driver: "SyncDriverAdapterBase", migration: "dict[str, Any]"
267
- ) -> "tuple[Optional[str], int]":
350
+ self,
351
+ driver: "SyncDriverAdapterBase",
352
+ migration: "dict[str, Any]",
353
+ *,
354
+ use_transaction: "bool | None" = None,
355
+ on_success: "Callable[[int], None] | None" = None,
356
+ ) -> "tuple[str | None, int]":
268
357
  """Execute a downgrade migration.
269
358
 
270
359
  Args:
271
360
  driver: The sync database driver to use.
272
361
  migration: Migration metadata dictionary.
362
+ use_transaction: Override transaction behavior. If None, uses should_use_transaction logic.
363
+ on_success: Callback invoked with execution_time_ms before commit (for version tracking).
273
364
 
274
365
  Returns:
275
366
  Tuple of (sql_content, execution_time_ms).
@@ -278,15 +369,36 @@ class SyncMigrationRunner(BaseMigrationRunner):
278
369
  if downgrade_sql_list is None:
279
370
  return None, 0
280
371
 
372
+ if use_transaction is None:
373
+ config = self.context.config if self.context else None
374
+ use_transaction = self.should_use_transaction(migration, config) if config else False
375
+
281
376
  start_time = time.time()
282
377
 
283
- for sql_statement in downgrade_sql_list:
284
- if sql_statement.strip():
285
- driver.execute_script(sql_statement)
286
- execution_time = int((time.time() - start_time) * 1000)
378
+ if use_transaction:
379
+ try:
380
+ driver.begin()
381
+ for sql_statement in downgrade_sql_list:
382
+ if sql_statement.strip():
383
+ driver.execute_script(sql_statement)
384
+ execution_time = int((time.time() - start_time) * 1000)
385
+ if on_success:
386
+ on_success(execution_time)
387
+ driver.commit()
388
+ except Exception:
389
+ driver.rollback()
390
+ raise
391
+ else:
392
+ for sql_statement in downgrade_sql_list:
393
+ if sql_statement.strip():
394
+ driver.execute_script(sql_statement)
395
+ execution_time = int((time.time() - start_time) * 1000)
396
+ if on_success:
397
+ on_success(execution_time)
398
+
287
399
  return None, execution_time
288
400
 
289
- def _get_migration_sql_sync(self, migration: "dict[str, Any]", direction: str) -> "Optional[list[str]]":
401
+ def _get_migration_sql_sync(self, migration: "dict[str, Any]", direction: str) -> "list[str] | None":
290
402
  """Get migration SQL for given direction (sync version).
291
403
 
292
404
  Args:
@@ -309,15 +421,11 @@ class SyncMigrationRunner(BaseMigrationRunner):
309
421
 
310
422
  try:
311
423
  method = loader.get_up_sql if direction == "up" else loader.get_down_sql
312
- # Check if the method is async and handle appropriately
313
- import inspect
314
-
315
- if inspect.iscoroutinefunction(method):
316
- # For async methods, use await_ to run in sync context
317
- sql_statements = await_(method, raise_sync_error=False)(file_path)
318
- else:
319
- # For sync methods, call directly
320
- sql_statements = method(file_path)
424
+ sql_statements = (
425
+ await_(method, raise_sync_error=False)(file_path)
426
+ if inspect.iscoroutinefunction(method)
427
+ else method(file_path)
428
+ )
321
429
 
322
430
  except Exception as e:
323
431
  if direction == "down":
@@ -345,11 +453,13 @@ class SyncMigrationRunner(BaseMigrationRunner):
345
453
  for query_name in self.loader.list_queries():
346
454
  all_queries[query_name] = self.loader.get_sql(query_name)
347
455
  else:
348
- loader = get_migration_loader(file_path, self.migrations_path, self.project_root, self.context)
456
+ loader = get_migration_loader(
457
+ file_path, self.migrations_path, self.project_root, self.context, self.loader
458
+ )
349
459
 
350
460
  try:
351
- up_sql = await_(loader.get_up_sql)(file_path)
352
- down_sql = await_(loader.get_down_sql)(file_path)
461
+ up_sql = await_(loader.get_up_sql, raise_sync_error=False)(file_path)
462
+ down_sql = await_(loader.get_down_sql, raise_sync_error=False)(file_path)
353
463
 
354
464
  if up_sql:
355
465
  all_queries[f"migrate-{version}-up"] = SQL(up_sql[0])
@@ -373,20 +483,20 @@ class AsyncMigrationRunner(BaseMigrationRunner):
373
483
  """
374
484
  return self._get_migration_files_sync()
375
485
 
376
- async def load_migration(self, file_path: Path) -> "dict[str, Any]":
486
+ async def load_migration(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]":
377
487
  """Load a migration file and extract its components.
378
488
 
379
489
  Args:
380
490
  file_path: Path to the migration file.
491
+ version: Optional pre-extracted version (preserves prefixes like ext_adk_0001).
381
492
 
382
493
  Returns:
383
494
  Dictionary containing migration metadata and queries.
384
495
  """
385
- # Get common metadata
386
- metadata = self._load_migration_metadata_common(file_path)
496
+ metadata = self._load_migration_metadata_common(file_path, version)
387
497
  context_to_use = self._get_context_for_migration(file_path)
388
498
 
389
- loader = get_migration_loader(file_path, self.migrations_path, self.project_root, context_to_use)
499
+ loader = get_migration_loader(file_path, self.migrations_path, self.project_root, context_to_use, self.loader)
390
500
  loader.validate_migration_file(file_path)
391
501
 
392
502
  has_upgrade, has_downgrade = True, False
@@ -394,8 +504,6 @@ class AsyncMigrationRunner(BaseMigrationRunner):
394
504
  if file_path.suffix == ".sql":
395
505
  version = metadata["version"]
396
506
  up_query, down_query = f"migrate-{version}-up", f"migrate-{version}-down"
397
- self.loader.clear_cache()
398
- await async_(self.loader.load_sql)(file_path)
399
507
  has_upgrade, has_downgrade = self.loader.has_query(up_query), self.loader.has_query(down_query)
400
508
  else:
401
509
  try:
@@ -409,13 +517,20 @@ class AsyncMigrationRunner(BaseMigrationRunner):
409
517
  return metadata
410
518
 
411
519
  async def execute_upgrade(
412
- self, driver: "AsyncDriverAdapterBase", migration: "dict[str, Any]"
413
- ) -> "tuple[Optional[str], int]":
520
+ self,
521
+ driver: "AsyncDriverAdapterBase",
522
+ migration: "dict[str, Any]",
523
+ *,
524
+ use_transaction: "bool | None" = None,
525
+ on_success: "Callable[[int], Awaitable[None]] | None" = None,
526
+ ) -> "tuple[str | None, int]":
414
527
  """Execute an upgrade migration.
415
528
 
416
529
  Args:
417
530
  driver: The async database driver to use.
418
531
  migration: Migration metadata dictionary.
532
+ use_transaction: Override transaction behavior. If None, uses should_use_transaction logic.
533
+ on_success: Async callback invoked with execution_time_ms before commit (for version tracking).
419
534
 
420
535
  Returns:
421
536
  Tuple of (sql_content, execution_time_ms).
@@ -424,22 +539,50 @@ class AsyncMigrationRunner(BaseMigrationRunner):
424
539
  if upgrade_sql_list is None:
425
540
  return None, 0
426
541
 
542
+ if use_transaction is None:
543
+ config = self.context.config if self.context else None
544
+ use_transaction = self.should_use_transaction(migration, config) if config else False
545
+
427
546
  start_time = time.time()
428
547
 
429
- for sql_statement in upgrade_sql_list:
430
- if sql_statement.strip():
431
- await driver.execute_script(sql_statement)
432
- execution_time = int((time.time() - start_time) * 1000)
548
+ if use_transaction:
549
+ try:
550
+ await driver.begin()
551
+ for sql_statement in upgrade_sql_list:
552
+ if sql_statement.strip():
553
+ await driver.execute_script(sql_statement)
554
+ execution_time = int((time.time() - start_time) * 1000)
555
+ if on_success:
556
+ await on_success(execution_time)
557
+ await driver.commit()
558
+ except Exception:
559
+ await driver.rollback()
560
+ raise
561
+ else:
562
+ for sql_statement in upgrade_sql_list:
563
+ if sql_statement.strip():
564
+ await driver.execute_script(sql_statement)
565
+ execution_time = int((time.time() - start_time) * 1000)
566
+ if on_success:
567
+ await on_success(execution_time)
568
+
433
569
  return None, execution_time
434
570
 
435
571
  async def execute_downgrade(
436
- self, driver: "AsyncDriverAdapterBase", migration: "dict[str, Any]"
437
- ) -> "tuple[Optional[str], int]":
572
+ self,
573
+ driver: "AsyncDriverAdapterBase",
574
+ migration: "dict[str, Any]",
575
+ *,
576
+ use_transaction: "bool | None" = None,
577
+ on_success: "Callable[[int], Awaitable[None]] | None" = None,
578
+ ) -> "tuple[str | None, int]":
438
579
  """Execute a downgrade migration.
439
580
 
440
581
  Args:
441
582
  driver: The async database driver to use.
442
583
  migration: Migration metadata dictionary.
584
+ use_transaction: Override transaction behavior. If None, uses should_use_transaction logic.
585
+ on_success: Async callback invoked with execution_time_ms before commit (for version tracking).
443
586
 
444
587
  Returns:
445
588
  Tuple of (sql_content, execution_time_ms).
@@ -448,15 +591,36 @@ class AsyncMigrationRunner(BaseMigrationRunner):
448
591
  if downgrade_sql_list is None:
449
592
  return None, 0
450
593
 
594
+ if use_transaction is None:
595
+ config = self.context.config if self.context else None
596
+ use_transaction = self.should_use_transaction(migration, config) if config else False
597
+
451
598
  start_time = time.time()
452
599
 
453
- for sql_statement in downgrade_sql_list:
454
- if sql_statement.strip():
455
- await driver.execute_script(sql_statement)
456
- execution_time = int((time.time() - start_time) * 1000)
600
+ if use_transaction:
601
+ try:
602
+ await driver.begin()
603
+ for sql_statement in downgrade_sql_list:
604
+ if sql_statement.strip():
605
+ await driver.execute_script(sql_statement)
606
+ execution_time = int((time.time() - start_time) * 1000)
607
+ if on_success:
608
+ await on_success(execution_time)
609
+ await driver.commit()
610
+ except Exception:
611
+ await driver.rollback()
612
+ raise
613
+ else:
614
+ for sql_statement in downgrade_sql_list:
615
+ if sql_statement.strip():
616
+ await driver.execute_script(sql_statement)
617
+ execution_time = int((time.time() - start_time) * 1000)
618
+ if on_success:
619
+ await on_success(execution_time)
620
+
457
621
  return None, execution_time
458
622
 
459
- async def _get_migration_sql_async(self, migration: "dict[str, Any]", direction: str) -> "Optional[list[str]]":
623
+ async def _get_migration_sql_async(self, migration: "dict[str, Any]", direction: str) -> "list[str] | None":
460
624
  """Get migration SQL for given direction (async version).
461
625
 
462
626
  Args:
@@ -507,7 +671,9 @@ class AsyncMigrationRunner(BaseMigrationRunner):
507
671
  for query_name in self.loader.list_queries():
508
672
  all_queries[query_name] = self.loader.get_sql(query_name)
509
673
  else:
510
- loader = get_migration_loader(file_path, self.migrations_path, self.project_root, self.context)
674
+ loader = get_migration_loader(
675
+ file_path, self.migrations_path, self.project_root, self.context, self.loader
676
+ )
511
677
 
512
678
  try:
513
679
  up_sql = await loader.get_up_sql(file_path)
@@ -528,7 +694,7 @@ class AsyncMigrationRunner(BaseMigrationRunner):
528
694
  def create_migration_runner(
529
695
  migrations_path: Path,
530
696
  extension_migrations: "dict[str, Path]",
531
- context: "Optional[MigrationContext]",
697
+ context: "MigrationContext | None",
532
698
  extension_configs: "dict[str, Any]",
533
699
  is_async: "Literal[False]" = False,
534
700
  ) -> SyncMigrationRunner: ...
@@ -538,7 +704,7 @@ def create_migration_runner(
538
704
  def create_migration_runner(
539
705
  migrations_path: Path,
540
706
  extension_migrations: "dict[str, Path]",
541
- context: "Optional[MigrationContext]",
707
+ context: "MigrationContext | None",
542
708
  extension_configs: "dict[str, Any]",
543
709
  is_async: "Literal[True]",
544
710
  ) -> AsyncMigrationRunner: ...
@@ -547,10 +713,10 @@ def create_migration_runner(
547
713
  def create_migration_runner(
548
714
  migrations_path: Path,
549
715
  extension_migrations: "dict[str, Path]",
550
- context: "Optional[MigrationContext]",
716
+ context: "MigrationContext | None",
551
717
  extension_configs: "dict[str, Any]",
552
718
  is_async: bool = False,
553
- ) -> "Union[SyncMigrationRunner, AsyncMigrationRunner]":
719
+ ) -> "SyncMigrationRunner | AsyncMigrationRunner":
554
720
  """Factory function to create the appropriate migration runner.
555
721
 
556
722
  Args: