sqlspec 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (199) hide show
  1. sqlspec/__init__.py +7 -15
  2. sqlspec/_serialization.py +256 -24
  3. sqlspec/_typing.py +71 -52
  4. sqlspec/adapters/adbc/_types.py +1 -1
  5. sqlspec/adapters/adbc/adk/__init__.py +5 -0
  6. sqlspec/adapters/adbc/adk/store.py +870 -0
  7. sqlspec/adapters/adbc/config.py +69 -12
  8. sqlspec/adapters/adbc/data_dictionary.py +340 -0
  9. sqlspec/adapters/adbc/driver.py +266 -58
  10. sqlspec/adapters/adbc/litestar/__init__.py +5 -0
  11. sqlspec/adapters/adbc/litestar/store.py +504 -0
  12. sqlspec/adapters/adbc/type_converter.py +153 -0
  13. sqlspec/adapters/aiosqlite/_types.py +1 -1
  14. sqlspec/adapters/aiosqlite/adk/__init__.py +5 -0
  15. sqlspec/adapters/aiosqlite/adk/store.py +527 -0
  16. sqlspec/adapters/aiosqlite/config.py +88 -15
  17. sqlspec/adapters/aiosqlite/data_dictionary.py +149 -0
  18. sqlspec/adapters/aiosqlite/driver.py +143 -40
  19. sqlspec/adapters/aiosqlite/litestar/__init__.py +5 -0
  20. sqlspec/adapters/aiosqlite/litestar/store.py +281 -0
  21. sqlspec/adapters/aiosqlite/pool.py +7 -7
  22. sqlspec/adapters/asyncmy/__init__.py +7 -1
  23. sqlspec/adapters/asyncmy/_types.py +2 -2
  24. sqlspec/adapters/asyncmy/adk/__init__.py +5 -0
  25. sqlspec/adapters/asyncmy/adk/store.py +493 -0
  26. sqlspec/adapters/asyncmy/config.py +68 -23
  27. sqlspec/adapters/asyncmy/data_dictionary.py +161 -0
  28. sqlspec/adapters/asyncmy/driver.py +313 -58
  29. sqlspec/adapters/asyncmy/litestar/__init__.py +5 -0
  30. sqlspec/adapters/asyncmy/litestar/store.py +296 -0
  31. sqlspec/adapters/asyncpg/__init__.py +2 -1
  32. sqlspec/adapters/asyncpg/_type_handlers.py +71 -0
  33. sqlspec/adapters/asyncpg/_types.py +11 -7
  34. sqlspec/adapters/asyncpg/adk/__init__.py +5 -0
  35. sqlspec/adapters/asyncpg/adk/store.py +450 -0
  36. sqlspec/adapters/asyncpg/config.py +59 -35
  37. sqlspec/adapters/asyncpg/data_dictionary.py +173 -0
  38. sqlspec/adapters/asyncpg/driver.py +170 -25
  39. sqlspec/adapters/asyncpg/litestar/__init__.py +5 -0
  40. sqlspec/adapters/asyncpg/litestar/store.py +253 -0
  41. sqlspec/adapters/bigquery/_types.py +1 -1
  42. sqlspec/adapters/bigquery/adk/__init__.py +5 -0
  43. sqlspec/adapters/bigquery/adk/store.py +576 -0
  44. sqlspec/adapters/bigquery/config.py +27 -10
  45. sqlspec/adapters/bigquery/data_dictionary.py +149 -0
  46. sqlspec/adapters/bigquery/driver.py +368 -142
  47. sqlspec/adapters/bigquery/litestar/__init__.py +5 -0
  48. sqlspec/adapters/bigquery/litestar/store.py +327 -0
  49. sqlspec/adapters/bigquery/type_converter.py +125 -0
  50. sqlspec/adapters/duckdb/_types.py +1 -1
  51. sqlspec/adapters/duckdb/adk/__init__.py +14 -0
  52. sqlspec/adapters/duckdb/adk/store.py +553 -0
  53. sqlspec/adapters/duckdb/config.py +80 -20
  54. sqlspec/adapters/duckdb/data_dictionary.py +163 -0
  55. sqlspec/adapters/duckdb/driver.py +167 -45
  56. sqlspec/adapters/duckdb/litestar/__init__.py +5 -0
  57. sqlspec/adapters/duckdb/litestar/store.py +332 -0
  58. sqlspec/adapters/duckdb/pool.py +4 -4
  59. sqlspec/adapters/duckdb/type_converter.py +133 -0
  60. sqlspec/adapters/oracledb/_numpy_handlers.py +133 -0
  61. sqlspec/adapters/oracledb/_types.py +20 -2
  62. sqlspec/adapters/oracledb/adk/__init__.py +5 -0
  63. sqlspec/adapters/oracledb/adk/store.py +1745 -0
  64. sqlspec/adapters/oracledb/config.py +122 -32
  65. sqlspec/adapters/oracledb/data_dictionary.py +509 -0
  66. sqlspec/adapters/oracledb/driver.py +353 -91
  67. sqlspec/adapters/oracledb/litestar/__init__.py +5 -0
  68. sqlspec/adapters/oracledb/litestar/store.py +767 -0
  69. sqlspec/adapters/oracledb/migrations.py +348 -73
  70. sqlspec/adapters/oracledb/type_converter.py +207 -0
  71. sqlspec/adapters/psqlpy/_type_handlers.py +44 -0
  72. sqlspec/adapters/psqlpy/_types.py +2 -1
  73. sqlspec/adapters/psqlpy/adk/__init__.py +5 -0
  74. sqlspec/adapters/psqlpy/adk/store.py +482 -0
  75. sqlspec/adapters/psqlpy/config.py +46 -17
  76. sqlspec/adapters/psqlpy/data_dictionary.py +172 -0
  77. sqlspec/adapters/psqlpy/driver.py +123 -209
  78. sqlspec/adapters/psqlpy/litestar/__init__.py +5 -0
  79. sqlspec/adapters/psqlpy/litestar/store.py +272 -0
  80. sqlspec/adapters/psqlpy/type_converter.py +102 -0
  81. sqlspec/adapters/psycopg/_type_handlers.py +80 -0
  82. sqlspec/adapters/psycopg/_types.py +2 -1
  83. sqlspec/adapters/psycopg/adk/__init__.py +5 -0
  84. sqlspec/adapters/psycopg/adk/store.py +944 -0
  85. sqlspec/adapters/psycopg/config.py +69 -35
  86. sqlspec/adapters/psycopg/data_dictionary.py +331 -0
  87. sqlspec/adapters/psycopg/driver.py +238 -81
  88. sqlspec/adapters/psycopg/litestar/__init__.py +5 -0
  89. sqlspec/adapters/psycopg/litestar/store.py +554 -0
  90. sqlspec/adapters/sqlite/__init__.py +2 -1
  91. sqlspec/adapters/sqlite/_type_handlers.py +86 -0
  92. sqlspec/adapters/sqlite/_types.py +1 -1
  93. sqlspec/adapters/sqlite/adk/__init__.py +5 -0
  94. sqlspec/adapters/sqlite/adk/store.py +572 -0
  95. sqlspec/adapters/sqlite/config.py +87 -15
  96. sqlspec/adapters/sqlite/data_dictionary.py +149 -0
  97. sqlspec/adapters/sqlite/driver.py +137 -54
  98. sqlspec/adapters/sqlite/litestar/__init__.py +5 -0
  99. sqlspec/adapters/sqlite/litestar/store.py +318 -0
  100. sqlspec/adapters/sqlite/pool.py +18 -9
  101. sqlspec/base.py +45 -26
  102. sqlspec/builder/__init__.py +73 -4
  103. sqlspec/builder/_base.py +162 -89
  104. sqlspec/builder/_column.py +62 -29
  105. sqlspec/builder/_ddl.py +180 -121
  106. sqlspec/builder/_delete.py +5 -4
  107. sqlspec/builder/_dml.py +388 -0
  108. sqlspec/{_sql.py → builder/_factory.py} +53 -94
  109. sqlspec/builder/_insert.py +32 -131
  110. sqlspec/builder/_join.py +375 -0
  111. sqlspec/builder/_merge.py +446 -11
  112. sqlspec/builder/_parsing_utils.py +111 -17
  113. sqlspec/builder/_select.py +1457 -24
  114. sqlspec/builder/_update.py +11 -42
  115. sqlspec/cli.py +307 -194
  116. sqlspec/config.py +252 -67
  117. sqlspec/core/__init__.py +5 -4
  118. sqlspec/core/cache.py +17 -17
  119. sqlspec/core/compiler.py +62 -9
  120. sqlspec/core/filters.py +37 -37
  121. sqlspec/core/hashing.py +9 -9
  122. sqlspec/core/parameters.py +83 -48
  123. sqlspec/core/result.py +102 -46
  124. sqlspec/core/splitter.py +16 -17
  125. sqlspec/core/statement.py +36 -30
  126. sqlspec/core/type_conversion.py +235 -0
  127. sqlspec/driver/__init__.py +7 -6
  128. sqlspec/driver/_async.py +188 -151
  129. sqlspec/driver/_common.py +285 -80
  130. sqlspec/driver/_sync.py +188 -152
  131. sqlspec/driver/mixins/_result_tools.py +20 -236
  132. sqlspec/driver/mixins/_sql_translator.py +4 -4
  133. sqlspec/exceptions.py +75 -7
  134. sqlspec/extensions/adk/__init__.py +53 -0
  135. sqlspec/extensions/adk/_types.py +51 -0
  136. sqlspec/extensions/adk/converters.py +172 -0
  137. sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +144 -0
  138. sqlspec/extensions/adk/migrations/__init__.py +0 -0
  139. sqlspec/extensions/adk/service.py +181 -0
  140. sqlspec/extensions/adk/store.py +536 -0
  141. sqlspec/extensions/aiosql/adapter.py +73 -53
  142. sqlspec/extensions/litestar/__init__.py +21 -4
  143. sqlspec/extensions/litestar/cli.py +54 -10
  144. sqlspec/extensions/litestar/config.py +59 -266
  145. sqlspec/extensions/litestar/handlers.py +46 -17
  146. sqlspec/extensions/litestar/migrations/0001_create_session_table.py +137 -0
  147. sqlspec/extensions/litestar/migrations/__init__.py +3 -0
  148. sqlspec/extensions/litestar/plugin.py +324 -223
  149. sqlspec/extensions/litestar/providers.py +25 -25
  150. sqlspec/extensions/litestar/store.py +265 -0
  151. sqlspec/loader.py +30 -49
  152. sqlspec/migrations/__init__.py +4 -3
  153. sqlspec/migrations/base.py +302 -39
  154. sqlspec/migrations/commands.py +611 -144
  155. sqlspec/migrations/context.py +142 -0
  156. sqlspec/migrations/fix.py +199 -0
  157. sqlspec/migrations/loaders.py +68 -23
  158. sqlspec/migrations/runner.py +543 -107
  159. sqlspec/migrations/tracker.py +237 -21
  160. sqlspec/migrations/utils.py +51 -3
  161. sqlspec/migrations/validation.py +177 -0
  162. sqlspec/protocols.py +66 -36
  163. sqlspec/storage/_utils.py +98 -0
  164. sqlspec/storage/backends/fsspec.py +134 -106
  165. sqlspec/storage/backends/local.py +78 -51
  166. sqlspec/storage/backends/obstore.py +278 -162
  167. sqlspec/storage/registry.py +75 -39
  168. sqlspec/typing.py +16 -84
  169. sqlspec/utils/config_resolver.py +153 -0
  170. sqlspec/utils/correlation.py +4 -5
  171. sqlspec/utils/data_transformation.py +3 -2
  172. sqlspec/utils/deprecation.py +9 -8
  173. sqlspec/utils/fixtures.py +4 -4
  174. sqlspec/utils/logging.py +46 -6
  175. sqlspec/utils/module_loader.py +2 -2
  176. sqlspec/utils/schema.py +288 -0
  177. sqlspec/utils/serializers.py +50 -2
  178. sqlspec/utils/sync_tools.py +21 -17
  179. sqlspec/utils/text.py +1 -2
  180. sqlspec/utils/type_guards.py +111 -20
  181. sqlspec/utils/version.py +433 -0
  182. {sqlspec-0.25.0.dist-info → sqlspec-0.27.0.dist-info}/METADATA +40 -21
  183. sqlspec-0.27.0.dist-info/RECORD +207 -0
  184. sqlspec/builder/mixins/__init__.py +0 -55
  185. sqlspec/builder/mixins/_cte_and_set_ops.py +0 -254
  186. sqlspec/builder/mixins/_delete_operations.py +0 -50
  187. sqlspec/builder/mixins/_insert_operations.py +0 -282
  188. sqlspec/builder/mixins/_join_operations.py +0 -389
  189. sqlspec/builder/mixins/_merge_operations.py +0 -592
  190. sqlspec/builder/mixins/_order_limit_operations.py +0 -152
  191. sqlspec/builder/mixins/_pivot_operations.py +0 -157
  192. sqlspec/builder/mixins/_select_operations.py +0 -936
  193. sqlspec/builder/mixins/_update_operations.py +0 -218
  194. sqlspec/builder/mixins/_where_clause.py +0 -1304
  195. sqlspec-0.25.0.dist-info/RECORD +0 -139
  196. sqlspec-0.25.0.dist-info/licenses/NOTICE +0 -29
  197. {sqlspec-0.25.0.dist-info → sqlspec-0.27.0.dist-info}/WHEEL +0 -0
  198. {sqlspec-0.25.0.dist-info → sqlspec-0.27.0.dist-info}/entry_points.txt +0 -0
  199. {sqlspec-0.25.0.dist-info → sqlspec-0.27.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,142 @@
1
+ """Migration context for passing runtime information to migrations."""
2
+
3
+ import asyncio
4
+ import inspect
5
+ from dataclasses import dataclass, field
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ from sqlspec.utils.logging import get_logger
9
+
10
+ if TYPE_CHECKING:
11
+ from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
12
+
13
+ logger = get_logger("migrations.context")
14
+
15
+ __all__ = ("MigrationContext",)
16
+
17
+
18
@dataclass
class MigrationContext:
    """Context object passed to migration functions.

    Provides runtime information about the database environment
    to migration functions, allowing them to generate dialect-specific SQL.
    """

    config: "Any | None" = None
    """Database configuration object."""
    dialect: "str | None" = None
    """Database dialect (e.g., 'postgres', 'mysql', 'sqlite')."""
    metadata: "dict[str, Any] | None" = None
    """Additional metadata for the migration."""
    extension_config: "dict[str, Any] | None" = None
    """Extension-specific configuration options."""

    driver: "SyncDriverAdapterBase | AsyncDriverAdapterBase | None" = None
    """Database driver instance (available during execution)."""

    _execution_metadata: "dict[str, Any]" = field(default_factory=dict)
    """Internal execution metadata for tracking async operations."""

    def __post_init__(self) -> None:
        """Default ``metadata`` and ``extension_config`` to empty dicts.

        Only ``None`` is replaced. A dict supplied by the caller, including an
        empty one, is kept as-is so the caller's reference keeps observing any
        entries added through the context.
        """
        if self.metadata is None:
            self.metadata = {}
        if self.extension_config is None:
            self.extension_config = {}

    @classmethod
    def from_config(cls, config: Any) -> "MigrationContext":
        """Create context from database configuration.

        The dialect is read from ``config.statement_config.dialect`` when a
        truthy ``statement_config`` attribute exists; otherwise from the object
        returned by ``config._create_statement_config()`` when that callable
        exists. Extraction is best-effort: any exception is logged at debug
        level and the context is created with ``dialect=None``.

        Args:
            config: Database configuration object.

        Returns:
            Migration context with dialect information.
        """
        dialect = None
        try:
            if hasattr(config, "statement_config") and config.statement_config:
                dialect = getattr(config.statement_config, "dialect", None)
            elif hasattr(config, "_create_statement_config") and callable(config._create_statement_config):
                statement_config = config._create_statement_config()
                dialect = getattr(statement_config, "dialect", None)
        except Exception:
            # Dialect detection must never block context creation; keep the
            # traceback in the debug log so the failure stays diagnosable.
            logger.debug("Unable to extract dialect from config", exc_info=True)

        return cls(dialect=dialect, config=config)

    @property
    def is_async_execution(self) -> bool:
        """Check if migrations are running in an async execution context.

        Returns:
            True if an event loop is running in the current thread.
        """
        try:
            # Raises RuntimeError when no event loop is running in this thread.
            asyncio.get_running_loop()
        except RuntimeError:
            return False
        return True

    @property
    def is_async_driver(self) -> bool:
        """Check if the current driver is async.

        Returns:
            True if a driver is set and its ``execute_script`` attribute is a
            coroutine function; False otherwise.
        """
        if self.driver is None:
            return False

        execute_method = getattr(self.driver, "execute_script", None)
        return execute_method is not None and inspect.iscoroutinefunction(execute_method)

    @property
    def execution_mode(self) -> str:
        """Get the current execution mode.

        Returns:
            'async' if in async context, 'sync' otherwise.
        """
        return "async" if self.is_async_execution else "sync"

    def set_execution_metadata(self, key: str, value: Any) -> None:
        """Set execution metadata for tracking migration state.

        Args:
            key: Metadata key.
            value: Metadata value.
        """
        self._execution_metadata[key] = value

    def get_execution_metadata(self, key: str, default: Any = None) -> Any:
        """Get execution metadata.

        Args:
            key: Metadata key.
            default: Default value if key not found.

        Returns:
            Metadata value or default.
        """
        return self._execution_metadata.get(key, default)

    def validate_async_usage(self, migration_func: Any) -> None:
        """Validate proper usage of async functions in migration context.

        Logs a warning when an async migration function is used while neither
        the execution context nor the driver is async. Records
        ``mixed_execution=True`` in the execution metadata when a sync
        migration function is used with an async driver.

        Args:
            migration_func: The migration function to validate.
        """
        func_is_async = inspect.iscoroutinefunction(migration_func)

        if func_is_async and not self.is_async_execution and not self.is_async_driver:
            msg = (
                "Async migration function detected but execution context is sync. "
                "Consider using async database configuration or sync migration functions."
            )
            logger.warning(msg)

        if not func_is_async and self.is_async_driver:
            self.set_execution_metadata("mixed_execution", value=True)
            logger.debug("Sync migration function in async driver context - using compatibility mode")
@@ -0,0 +1,199 @@
1
+ """Migration file fix operations for converting timestamp to sequential versions.
2
+
3
+ This module provides utilities to convert timestamp-format migration files to
4
+ sequential format, supporting the hybrid versioning workflow where development
5
+ uses timestamps and production uses sequential numbers.
6
+ """
7
+
8
+ import logging
9
+ import re
10
+ import shutil
11
+ from dataclasses import dataclass
12
+ from datetime import datetime, timezone
13
+ from pathlib import Path
14
+
15
+ __all__ = ("MigrationFixer", "MigrationRename")
16
+
17
logger = logging.getLogger(__name__)


@dataclass
class MigrationRename:
    """Represents a planned migration file rename operation.

    Attributes:
        old_path: Current file path.
        new_path: Target file path after rename.
        old_version: Current version string.
        new_version: Target version string.
        needs_content_update: Whether file content needs updating.
            True for SQL files that contain query names.
    """

    old_path: Path
    new_path: Path
    old_version: str
    new_version: str
    needs_content_update: bool


class MigrationFixer:
    """Handles migration file conversion operations.

    Provides backup/rollback functionality and manages conversion from
    timestamp-based migration files to sequential format.
    """

    def __init__(self, migrations_path: Path) -> None:
        """Initialize migration fixer.

        Args:
            migrations_path: Path to migrations directory.
        """
        self.migrations_path = migrations_path
        self.backup_path: Path | None = None

    def plan_renames(self, conversion_map: dict[str, str]) -> list[MigrationRename]:
        """Plan all file rename operations from conversion map.

        Scans the migration directory for entries named ``{old_version}_*`` and
        builds a MigrationRename for each one. Nothing on disk is modified.

        Args:
            conversion_map: Dictionary mapping old versions to new versions.

        Returns:
            List of planned rename operations.

        Raises:
            ValueError: If a target file already exists, or if two planned
                renames resolve to the same target path.
        """
        if not conversion_map:
            return []

        renames: list[MigrationRename] = []
        # target path -> source path already planned for it, for collision checks
        planned_targets: dict[Path, Path] = {}

        for old_version, new_version in conversion_map.items():
            prefix = f"{old_version}_"

            # Sorted so the resulting plan is deterministic across platforms.
            for old_path in sorted(self.migrations_path.glob(f"{prefix}*")):
                suffix = old_path.suffix
                # Strip only the leading version prefix; a description that
                # happens to repeat the version token must stay intact.
                description = old_path.stem.removeprefix(prefix)

                new_path = self.migrations_path / f"{new_version}_{description}{suffix}"

                if new_path.exists() and new_path != old_path:
                    msg = f"Target file already exists: {new_path}"
                    raise ValueError(msg)

                earlier_source = planned_targets.get(new_path)
                if earlier_source is not None:
                    msg = f"Rename collision: {earlier_source.name} and {old_path.name} both map to {new_path.name}"
                    raise ValueError(msg)
                planned_targets[new_path] = old_path

                renames.append(
                    MigrationRename(
                        old_path=old_path,
                        new_path=new_path,
                        old_version=old_version,
                        new_version=new_version,
                        needs_content_update=suffix == ".sql",
                    )
                )

        return renames

    def create_backup(self) -> Path:
        """Create timestamped backup directory with all migration files.

        Copies every non-hidden regular file in the migrations directory into
        ``.backup_<UTC timestamp>`` inside that directory and remembers the
        location in ``backup_path``.

        Returns:
            Path to created backup directory.

        Raises:
            FileExistsError: If a backup directory with the same (one-second
                resolution) timestamp already exists.
        """
        timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S")
        backup_dir = self.migrations_path / f".backup_{timestamp}"

        backup_dir.mkdir(parents=True, exist_ok=False)

        for file_path in self.migrations_path.iterdir():
            if file_path.is_file() and not file_path.name.startswith("."):
                shutil.copy2(file_path, backup_dir / file_path.name)

        self.backup_path = backup_dir
        return backup_dir

    def apply_renames(self, renames: "list[MigrationRename]", dry_run: bool = False) -> None:
        """Execute planned rename operations.

        For each rename, the file content is rewritten first (when required)
        and the file is then renamed. An empty list is a no-op.

        Args:
            renames: List of planned rename operations.
            dry_run: If True, log operations without executing.
        """
        for rename in renames:
            if dry_run:
                logger.debug("Dry run: would rename %s -> %s", rename.old_path.name, rename.new_path.name)
                continue

            if rename.needs_content_update:
                self.update_file_content(rename.old_path, rename.old_version, rename.new_version)

            rename.old_path.rename(rename.new_path)

    def update_file_content(self, file_path: Path, old_version: str, new_version: str) -> None:
        """Update SQL query names and version comments in file content.

        Transforms query names and version metadata from old version to new version:
            -- name: migrate-{old_version}-up → -- name: migrate-{new_version}-up
            -- name: migrate-{old_version}-down → -- name: migrate-{new_version}-down
            -- Version: {old_version} → -- Version: {new_version}

        Patterns are built for this specific version so other migrate-*
        query names in the file are left untouched.

        Args:
            file_path: Path to file to update.
            old_version: Old version string.
            new_version: New version string.
        """
        content = file_path.read_text(encoding="utf-8")

        escaped_old = re.escape(old_version)
        name_pattern = re.compile(rf"(-- name:\s+migrate-){escaped_old}(-(?:up|down))")
        version_pattern = re.compile(rf"(-- Version:\s+){escaped_old}")

        # Callable replacements insert new_version literally; a replacement
        # template would interpret any backslash sequences it contains.
        content = name_pattern.sub(lambda m: f"{m.group(1)}{new_version}{m.group(2)}", content)
        content = version_pattern.sub(lambda m: f"{m.group(1)}{new_version}", content)

        file_path.write_text(content, encoding="utf-8")
        logger.debug("Updated content in %s", file_path.name)

    def rollback(self) -> None:
        """Restore migration files from backup.

        Deletes the current non-hidden files in the migrations directory and
        copies every file from the backup directory back. Does nothing when no
        backup has been recorded or the backup directory no longer exists.
        """
        if not self.backup_path or not self.backup_path.exists():
            return

        # Snapshot the listing before deleting so removal never happens while
        # the directory is still being iterated.
        for file_path in list(self.migrations_path.iterdir()):
            if file_path.is_file() and not file_path.name.startswith("."):
                file_path.unlink()

        for backup_file in self.backup_path.iterdir():
            if backup_file.is_file():
                shutil.copy2(backup_file, self.migrations_path / backup_file.name)

    def cleanup(self) -> None:
        """Remove backup directory after successful conversion.

        Silently does nothing when no backup has been recorded or the backup
        directory no longer exists. Resets ``backup_path`` to None after the
        directory is removed.
        """
        if not self.backup_path or not self.backup_path.exists():
            return

        shutil.rmtree(self.backup_path)
        self.backup_path = None
@@ -10,7 +10,7 @@ import types
10
10
  from collections.abc import Iterator
11
11
  from contextlib import contextmanager
12
12
  from pathlib import Path
13
- from typing import Any, Final, Optional
13
+ from typing import Any, Final
14
14
 
15
15
  from sqlspec.loader import SQLFileLoader as CoreSQLFileLoader
16
16
 
@@ -77,13 +77,22 @@ class SQLFileLoader(BaseMigrationLoader):
77
77
 
78
78
  __slots__ = ("sql_loader",)
79
79
 
80
- def __init__(self) -> None:
81
- """Initialize SQL file loader."""
82
- self.sql_loader: CoreSQLFileLoader = CoreSQLFileLoader()
80
+ def __init__(self, sql_loader: "CoreSQLFileLoader | None" = None) -> None:
81
+ """Initialize SQL file loader.
82
+
83
+ Args:
84
+ sql_loader: Optional shared SQLFileLoader instance to reuse.
85
+ If not provided, creates a new instance.
86
+ """
87
+ self.sql_loader: CoreSQLFileLoader = sql_loader if sql_loader is not None else CoreSQLFileLoader()
83
88
 
84
89
  async def get_up_sql(self, path: Path) -> list[str]:
85
90
  """Extract the 'up' SQL from a SQL migration file.
86
91
 
92
+ The SQL file must already be loaded via validate_migration_file()
93
+ before calling this method. This design ensures the file is loaded
94
+ exactly once during the migration process.
95
+
87
96
  Args:
88
97
  path: Path to SQL migration file.
89
98
 
@@ -93,9 +102,6 @@ class SQLFileLoader(BaseMigrationLoader):
93
102
  Raises:
94
103
  MigrationLoadError: If migration file is invalid or missing up query.
95
104
  """
96
- self.sql_loader.clear_cache()
97
- self.sql_loader.load_sql(path)
98
-
99
105
  version = self._extract_version(path.name)
100
106
  up_query = f"migrate-{version}-up"
101
107
 
@@ -109,15 +115,16 @@ class SQLFileLoader(BaseMigrationLoader):
109
115
  async def get_down_sql(self, path: Path) -> list[str]:
110
116
  """Extract the 'down' SQL from a SQL migration file.
111
117
 
118
+ The SQL file must already be loaded via validate_migration_file()
119
+ before calling this method. This design ensures the file is loaded
120
+ exactly once during the migration process.
121
+
112
122
  Args:
113
123
  path: Path to SQL migration file.
114
124
 
115
125
  Returns:
116
126
  List containing single SQL statement for downgrade, or empty list.
117
127
  """
118
- self.sql_loader.clear_cache()
119
- self.sql_loader.load_sql(path)
120
-
121
128
  version = self._extract_version(path.name)
122
129
  down_query = f"migrate-{version}-down"
123
130
 
@@ -141,7 +148,6 @@ class SQLFileLoader(BaseMigrationLoader):
141
148
  msg = f"Invalid migration filename: {path.name}"
142
149
  raise MigrationLoadError(msg)
143
150
 
144
- self.sql_loader.clear_cache()
145
151
  self.sql_loader.load_sql(path)
146
152
  up_query = f"migrate-{version}-up"
147
153
  if not self.sql_loader.has_query(up_query):
@@ -151,30 +157,49 @@ class SQLFileLoader(BaseMigrationLoader):
151
157
  def _extract_version(self, filename: str) -> str:
152
158
  """Extract version from filename.
153
159
 
160
+ Supports sequential (0001), timestamp (20251011120000), and extension-prefixed
161
+ (ext_litestar_0001) version formats.
162
+
154
163
  Args:
155
164
  filename: Migration filename to parse.
156
165
 
157
166
  Returns:
158
- Zero-padded version string or empty string if invalid.
167
+ Version string or empty string if invalid.
159
168
  """
160
- parts = filename.split("_", 1)
161
- return parts[0].zfill(4) if parts and parts[0].isdigit() else ""
169
+ extension_version_parts = 3
170
+ timestamp_min_length = 4
171
+
172
+ name_without_ext = filename.rsplit(".", 1)[0]
173
+
174
+ if name_without_ext.startswith("ext_"):
175
+ parts = name_without_ext.split("_", 3)
176
+ if len(parts) >= extension_version_parts:
177
+ return f"{parts[0]}_{parts[1]}_{parts[2]}"
178
+ return ""
179
+
180
+ parts = name_without_ext.split("_", 1)
181
+ if parts and parts[0].isdigit():
182
+ return parts[0] if len(parts[0]) > timestamp_min_length else parts[0].zfill(4)
183
+
184
+ return ""
162
185
 
163
186
 
164
187
  class PythonFileLoader(BaseMigrationLoader):
165
188
  """Loader for Python migration files."""
166
189
 
167
- __slots__ = ("migrations_dir", "project_root")
190
+ __slots__ = ("context", "migrations_dir", "project_root")
168
191
 
169
- def __init__(self, migrations_dir: Path, project_root: "Optional[Path]" = None) -> None:
192
+ def __init__(self, migrations_dir: Path, project_root: "Path | None" = None, context: "Any | None" = None) -> None:
170
193
  """Initialize Python file loader.
171
194
 
172
195
  Args:
173
196
  migrations_dir: Directory containing migration files.
174
197
  project_root: Optional project root directory for imports.
198
+ context: Optional migration context to pass to functions.
175
199
  """
176
200
  self.migrations_dir = migrations_dir
177
201
  self.project_root = project_root if project_root is not None else self._find_project_root(migrations_dir)
202
+ self.context = context
178
203
 
179
204
  async def get_up_sql(self, path: Path) -> list[str]:
180
205
  """Load Python migration and execute upgrade function.
@@ -208,10 +233,16 @@ class PythonFileLoader(BaseMigrationLoader):
208
233
  msg = f"'{func_name}' is not callable in {path}"
209
234
  raise MigrationLoadError(msg)
210
235
 
236
+ # Check if function accepts context parameter
237
+ sig = inspect.signature(upgrade_func)
238
+ accepts_context = "context" in sig.parameters or len(sig.parameters) > 0
239
+
211
240
  if inspect.iscoroutinefunction(upgrade_func):
212
- sql_result = await upgrade_func()
241
+ sql_result = (
242
+ await upgrade_func(self.context) if accepts_context and self.context else await upgrade_func()
243
+ )
213
244
  else:
214
- sql_result = upgrade_func()
245
+ sql_result = upgrade_func(self.context) if accepts_context and self.context else upgrade_func()
215
246
 
216
247
  return self._normalize_and_validate_sql(sql_result, path)
217
248
 
@@ -239,10 +270,16 @@ class PythonFileLoader(BaseMigrationLoader):
239
270
  if not callable(downgrade_func):
240
271
  return []
241
272
 
273
+ # Check if function accepts context parameter
274
+ sig = inspect.signature(downgrade_func)
275
+ accepts_context = "context" in sig.parameters or len(sig.parameters) > 0
276
+
242
277
  if inspect.iscoroutinefunction(downgrade_func):
243
- sql_result = await downgrade_func()
278
+ sql_result = (
279
+ await downgrade_func(self.context) if accepts_context and self.context else await downgrade_func()
280
+ )
244
281
  else:
245
- sql_result = downgrade_func()
282
+ sql_result = downgrade_func(self.context) if accepts_context and self.context else downgrade_func()
246
283
 
247
284
  return self._normalize_and_validate_sql(sql_result, path)
248
285
 
@@ -380,7 +417,11 @@ class PythonFileLoader(BaseMigrationLoader):
380
417
 
381
418
 
382
419
  def get_migration_loader(
383
- file_path: Path, migrations_dir: Path, project_root: "Optional[Path]" = None
420
+ file_path: Path,
421
+ migrations_dir: Path,
422
+ project_root: "Path | None" = None,
423
+ context: "Any | None" = None,
424
+ sql_loader: "CoreSQLFileLoader | None" = None,
384
425
  ) -> BaseMigrationLoader:
385
426
  """Factory function to get appropriate loader for migration file.
386
427
 
@@ -388,6 +429,10 @@ def get_migration_loader(
388
429
  file_path: Path to the migration file.
389
430
  migrations_dir: Directory containing migration files.
390
431
  project_root: Optional project root directory for Python imports.
432
+ context: Optional migration context to pass to Python migrations.
433
+ sql_loader: Optional shared SQLFileLoader instance for SQL migrations.
434
+ When provided, SQL files are loaded using this shared instance,
435
+ avoiding redundant file parsing.
391
436
 
392
437
  Returns:
393
438
  Appropriate loader instance for the file type.
@@ -398,8 +443,8 @@ def get_migration_loader(
398
443
  suffix = file_path.suffix
399
444
 
400
445
  if suffix == ".py":
401
- return PythonFileLoader(migrations_dir, project_root)
446
+ return PythonFileLoader(migrations_dir, project_root, context)
402
447
  if suffix == ".sql":
403
- return SQLFileLoader()
448
+ return SQLFileLoader(sql_loader)
404
449
  msg = f"Unsupported migration file type: {suffix}"
405
450
  raise MigrationLoadError(msg)