sqlspec 0.14.1__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (158)
  1. sqlspec/__init__.py +50 -25
  2. sqlspec/__main__.py +1 -1
  3. sqlspec/__metadata__.py +1 -3
  4. sqlspec/_serialization.py +1 -2
  5. sqlspec/_sql.py +256 -120
  6. sqlspec/_typing.py +278 -142
  7. sqlspec/adapters/adbc/__init__.py +4 -3
  8. sqlspec/adapters/adbc/_types.py +12 -0
  9. sqlspec/adapters/adbc/config.py +115 -260
  10. sqlspec/adapters/adbc/driver.py +462 -367
  11. sqlspec/adapters/aiosqlite/__init__.py +18 -3
  12. sqlspec/adapters/aiosqlite/_types.py +13 -0
  13. sqlspec/adapters/aiosqlite/config.py +199 -129
  14. sqlspec/adapters/aiosqlite/driver.py +230 -269
  15. sqlspec/adapters/asyncmy/__init__.py +18 -3
  16. sqlspec/adapters/asyncmy/_types.py +12 -0
  17. sqlspec/adapters/asyncmy/config.py +80 -168
  18. sqlspec/adapters/asyncmy/driver.py +260 -225
  19. sqlspec/adapters/asyncpg/__init__.py +19 -4
  20. sqlspec/adapters/asyncpg/_types.py +17 -0
  21. sqlspec/adapters/asyncpg/config.py +82 -181
  22. sqlspec/adapters/asyncpg/driver.py +285 -383
  23. sqlspec/adapters/bigquery/__init__.py +17 -3
  24. sqlspec/adapters/bigquery/_types.py +12 -0
  25. sqlspec/adapters/bigquery/config.py +191 -258
  26. sqlspec/adapters/bigquery/driver.py +474 -646
  27. sqlspec/adapters/duckdb/__init__.py +14 -3
  28. sqlspec/adapters/duckdb/_types.py +12 -0
  29. sqlspec/adapters/duckdb/config.py +415 -351
  30. sqlspec/adapters/duckdb/driver.py +343 -413
  31. sqlspec/adapters/oracledb/__init__.py +19 -5
  32. sqlspec/adapters/oracledb/_types.py +14 -0
  33. sqlspec/adapters/oracledb/config.py +123 -379
  34. sqlspec/adapters/oracledb/driver.py +507 -560
  35. sqlspec/adapters/psqlpy/__init__.py +13 -3
  36. sqlspec/adapters/psqlpy/_types.py +11 -0
  37. sqlspec/adapters/psqlpy/config.py +93 -254
  38. sqlspec/adapters/psqlpy/driver.py +505 -234
  39. sqlspec/adapters/psycopg/__init__.py +19 -5
  40. sqlspec/adapters/psycopg/_types.py +17 -0
  41. sqlspec/adapters/psycopg/config.py +143 -403
  42. sqlspec/adapters/psycopg/driver.py +706 -872
  43. sqlspec/adapters/sqlite/__init__.py +14 -3
  44. sqlspec/adapters/sqlite/_types.py +11 -0
  45. sqlspec/adapters/sqlite/config.py +202 -118
  46. sqlspec/adapters/sqlite/driver.py +264 -303
  47. sqlspec/base.py +105 -9
  48. sqlspec/{statement/builder → builder}/__init__.py +12 -14
  49. sqlspec/{statement/builder → builder}/_base.py +120 -55
  50. sqlspec/{statement/builder → builder}/_column.py +17 -6
  51. sqlspec/{statement/builder → builder}/_ddl.py +46 -79
  52. sqlspec/{statement/builder → builder}/_ddl_utils.py +5 -10
  53. sqlspec/{statement/builder → builder}/_delete.py +6 -25
  54. sqlspec/{statement/builder → builder}/_insert.py +6 -64
  55. sqlspec/builder/_merge.py +56 -0
  56. sqlspec/{statement/builder → builder}/_parsing_utils.py +3 -10
  57. sqlspec/{statement/builder → builder}/_select.py +11 -56
  58. sqlspec/{statement/builder → builder}/_update.py +12 -18
  59. sqlspec/{statement/builder → builder}/mixins/__init__.py +10 -14
  60. sqlspec/{statement/builder → builder}/mixins/_cte_and_set_ops.py +48 -59
  61. sqlspec/{statement/builder → builder}/mixins/_insert_operations.py +22 -16
  62. sqlspec/{statement/builder → builder}/mixins/_join_operations.py +1 -3
  63. sqlspec/{statement/builder → builder}/mixins/_merge_operations.py +3 -5
  64. sqlspec/{statement/builder → builder}/mixins/_order_limit_operations.py +3 -3
  65. sqlspec/{statement/builder → builder}/mixins/_pivot_operations.py +4 -8
  66. sqlspec/{statement/builder → builder}/mixins/_select_operations.py +21 -36
  67. sqlspec/{statement/builder → builder}/mixins/_update_operations.py +3 -14
  68. sqlspec/{statement/builder → builder}/mixins/_where_clause.py +52 -79
  69. sqlspec/cli.py +4 -5
  70. sqlspec/config.py +180 -133
  71. sqlspec/core/__init__.py +63 -0
  72. sqlspec/core/cache.py +873 -0
  73. sqlspec/core/compiler.py +396 -0
  74. sqlspec/core/filters.py +828 -0
  75. sqlspec/core/hashing.py +310 -0
  76. sqlspec/core/parameters.py +1209 -0
  77. sqlspec/core/result.py +664 -0
  78. sqlspec/{statement → core}/splitter.py +321 -191
  79. sqlspec/core/statement.py +651 -0
  80. sqlspec/driver/__init__.py +7 -10
  81. sqlspec/driver/_async.py +387 -176
  82. sqlspec/driver/_common.py +527 -289
  83. sqlspec/driver/_sync.py +390 -172
  84. sqlspec/driver/mixins/__init__.py +2 -19
  85. sqlspec/driver/mixins/_result_tools.py +168 -0
  86. sqlspec/driver/mixins/_sql_translator.py +6 -3
  87. sqlspec/exceptions.py +5 -252
  88. sqlspec/extensions/aiosql/adapter.py +93 -96
  89. sqlspec/extensions/litestar/config.py +0 -1
  90. sqlspec/extensions/litestar/handlers.py +15 -26
  91. sqlspec/extensions/litestar/plugin.py +16 -14
  92. sqlspec/extensions/litestar/providers.py +17 -52
  93. sqlspec/loader.py +424 -105
  94. sqlspec/migrations/__init__.py +12 -0
  95. sqlspec/migrations/base.py +92 -68
  96. sqlspec/migrations/commands.py +24 -106
  97. sqlspec/migrations/loaders.py +402 -0
  98. sqlspec/migrations/runner.py +49 -51
  99. sqlspec/migrations/tracker.py +31 -44
  100. sqlspec/migrations/utils.py +64 -24
  101. sqlspec/protocols.py +7 -183
  102. sqlspec/storage/__init__.py +1 -1
  103. sqlspec/storage/backends/base.py +37 -40
  104. sqlspec/storage/backends/fsspec.py +136 -112
  105. sqlspec/storage/backends/obstore.py +138 -160
  106. sqlspec/storage/capabilities.py +5 -4
  107. sqlspec/storage/registry.py +57 -106
  108. sqlspec/typing.py +136 -115
  109. sqlspec/utils/__init__.py +2 -3
  110. sqlspec/utils/correlation.py +0 -3
  111. sqlspec/utils/deprecation.py +6 -6
  112. sqlspec/utils/fixtures.py +6 -6
  113. sqlspec/utils/logging.py +0 -2
  114. sqlspec/utils/module_loader.py +7 -12
  115. sqlspec/utils/singleton.py +0 -1
  116. sqlspec/utils/sync_tools.py +16 -37
  117. sqlspec/utils/text.py +12 -51
  118. sqlspec/utils/type_guards.py +443 -232
  119. {sqlspec-0.14.1.dist-info → sqlspec-0.15.0.dist-info}/METADATA +7 -2
  120. sqlspec-0.15.0.dist-info/RECORD +134 -0
  121. sqlspec/adapters/adbc/transformers.py +0 -108
  122. sqlspec/driver/connection.py +0 -207
  123. sqlspec/driver/mixins/_cache.py +0 -114
  124. sqlspec/driver/mixins/_csv_writer.py +0 -91
  125. sqlspec/driver/mixins/_pipeline.py +0 -508
  126. sqlspec/driver/mixins/_query_tools.py +0 -796
  127. sqlspec/driver/mixins/_result_utils.py +0 -138
  128. sqlspec/driver/mixins/_storage.py +0 -912
  129. sqlspec/driver/mixins/_type_coercion.py +0 -128
  130. sqlspec/driver/parameters.py +0 -138
  131. sqlspec/statement/__init__.py +0 -21
  132. sqlspec/statement/builder/_merge.py +0 -95
  133. sqlspec/statement/cache.py +0 -50
  134. sqlspec/statement/filters.py +0 -625
  135. sqlspec/statement/parameters.py +0 -956
  136. sqlspec/statement/pipelines/__init__.py +0 -210
  137. sqlspec/statement/pipelines/analyzers/__init__.py +0 -9
  138. sqlspec/statement/pipelines/analyzers/_analyzer.py +0 -646
  139. sqlspec/statement/pipelines/context.py +0 -109
  140. sqlspec/statement/pipelines/transformers/__init__.py +0 -7
  141. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +0 -88
  142. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +0 -1247
  143. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +0 -76
  144. sqlspec/statement/pipelines/validators/__init__.py +0 -23
  145. sqlspec/statement/pipelines/validators/_dml_safety.py +0 -290
  146. sqlspec/statement/pipelines/validators/_parameter_style.py +0 -370
  147. sqlspec/statement/pipelines/validators/_performance.py +0 -714
  148. sqlspec/statement/pipelines/validators/_security.py +0 -967
  149. sqlspec/statement/result.py +0 -435
  150. sqlspec/statement/sql.py +0 -1774
  151. sqlspec/utils/cached_property.py +0 -25
  152. sqlspec/utils/statement_hashing.py +0 -203
  153. sqlspec-0.14.1.dist-info/RECORD +0 -145
  154. sqlspec/{statement/builder → builder}/mixins/_delete_operations.py +0 -0
  155. {sqlspec-0.14.1.dist-info → sqlspec-0.15.0.dist-info}/WHEEL +0 -0
  156. {sqlspec-0.14.1.dist-info → sqlspec-0.15.0.dist-info}/entry_points.txt +0 -0
  157. {sqlspec-0.14.1.dist-info → sqlspec-0.15.0.dist-info}/licenses/LICENSE +0 -0
  158. {sqlspec-0.14.1.dist-info → sqlspec-0.15.0.dist-info}/licenses/NOTICE +0 -0
@@ -0,0 +1,402 @@
1
+ """Migration loader abstractions for SQLSpec."""
2
+
3
+ import abc
4
+ import inspect
5
+ import sys
6
+ import types
7
+ from collections.abc import Iterator
8
+ from contextlib import contextmanager
9
+ from pathlib import Path
10
+ from typing import Any, Final, Optional
11
+
12
+ from sqlspec.loader import SQLFileLoader as CoreSQLFileLoader
13
+
14
+ __all__ = ("BaseMigrationLoader", "MigrationLoadError", "PythonFileLoader", "SQLFileLoader", "get_migration_loader")
15
+
16
# File or directory names whose presence marks a directory as the project root.
# Checked in this order while walking upward from the migrations directory.
PROJECT_ROOT_MARKERS: Final[list[str]] = [
    "pyproject.toml",
    ".git",
    "setup.cfg",
    "setup.py",
]
17
+
18
+
19
class MigrationLoadError(Exception):
    """Signals that a migration file could not be loaded or failed validation."""
21
+
22
+
23
class BaseMigrationLoader(abc.ABC):
    """Common interface implemented by every migration-file loader.

    Concrete loaders turn a migration file into lists of SQL statements for
    the upgrade and downgrade directions, and can check a file up front.
    """

    __slots__ = ()

    @abc.abstractmethod
    def validate_migration_file(self, path: Path) -> None:
        """Check that ``path`` contains everything a migration needs.

        Args:
            path: Location of the migration file.

        Raises:
            MigrationLoadError: When the file does not pass validation.
        """
        ...

    @abc.abstractmethod
    async def get_up_sql(self, path: Path) -> list[str]:
        """Produce the SQL statements that apply (upgrade) the migration.

        Args:
            path: Location of the migration file.

        Returns:
            The upgrade statements, in execution order.

        Raises:
            MigrationLoadError: When the file cannot be loaded.
        """
        ...

    @abc.abstractmethod
    async def get_down_sql(self, path: Path) -> list[str]:
        """Produce the SQL statements that revert (downgrade) the migration.

        Args:
            path: Location of the migration file.

        Returns:
            The downgrade statements, in execution order; an empty list when
            the migration offers no downgrade.

        Raises:
            MigrationLoadError: When the file cannot be loaded.
        """
        ...
70
+
71
+
72
class SQLFileLoader(BaseMigrationLoader):
    """Loader for ``.sql`` migration files backed by the core ``SQLFileLoader``.

    A migration file named ``<digits>_<anything>.sql`` must define a named
    query ``migrate-<version>-up`` and may define ``migrate-<version>-down``,
    where ``<version>`` is the numeric filename prefix zero-padded to four
    digits.
    """

    __slots__ = ("sql_loader",)

    def __init__(self) -> None:
        """Initialize SQL file loader."""
        self.sql_loader: CoreSQLFileLoader = CoreSQLFileLoader()

    async def get_up_sql(self, path: Path) -> list[str]:
        """Extract the 'up' SQL from a SQL migration file.

        Args:
            path: Path to SQL migration file.

        Returns:
            List containing single SQL statement for upgrade.

        Raises:
            MigrationLoadError: If the filename has no numeric version prefix
                or the file is missing the up query.
        """
        up_query = self._load_migration_file(path, "up")
        if not self.sql_loader.has_query(up_query):
            msg = f"Migration {path} missing 'up' query: {up_query}"
            raise MigrationLoadError(msg)
        return [self.sql_loader.get_sql(up_query).sql]

    async def get_down_sql(self, path: Path) -> list[str]:
        """Extract the 'down' SQL from a SQL migration file.

        Args:
            path: Path to SQL migration file.

        Returns:
            List containing single SQL statement for downgrade, or an empty
            list when the file defines no down query.

        Raises:
            MigrationLoadError: If the filename has no numeric version prefix.
        """
        down_query = self._load_migration_file(path, "down")
        if not self.sql_loader.has_query(down_query):
            return []
        return [self.sql_loader.get_sql(down_query).sql]

    def validate_migration_file(self, path: Path) -> None:
        """Validate SQL migration file has required up query.

        Args:
            path: Path to SQL migration file.

        Raises:
            MigrationLoadError: If file is invalid or missing required query.
        """
        up_query = self._load_migration_file(path, "up")
        if not self.sql_loader.has_query(up_query):
            msg = f"Migration {path} missing required 'up' query: {up_query}"
            raise MigrationLoadError(msg)

    def _load_migration_file(self, path: Path, direction: str) -> str:
        """Load ``path`` into a cleared loader and return the expected query name.

        Args:
            path: Path to SQL migration file.
            direction: Either ``"up"`` or ``"down"``.

        Returns:
            The query name ``migrate-<version>-<direction>``.

        Raises:
            MigrationLoadError: If the filename has no numeric version prefix.
        """
        version = self._extract_version(path.name)
        if not version:
            # Without a version the lookup would be for "migrate--up"/"migrate--down",
            # which can only mislead; reject the filename explicitly instead.
            msg = f"Invalid migration filename: {path.name}"
            raise MigrationLoadError(msg)
        # Clear before loading so the lookup only sees this file's queries
        # (assumes clear_cache() discards previously loaded queries).
        self.sql_loader.clear_cache()
        self.sql_loader.load_sql(path)
        return f"migrate-{version}-{direction}"

    def _extract_version(self, filename: str) -> str:
        """Extract version from filename.

        Args:
            filename: Migration filename to parse.

        Returns:
            Zero-padded version string or empty string if invalid.
        """
        prefix = filename.partition("_")[0]
        # isascii() excludes non-ASCII digit characters (e.g. "²") that isdigit() accepts.
        return prefix.zfill(4) if prefix.isascii() and prefix.isdigit() else ""
159
+
160
+
161
class PythonFileLoader(BaseMigrationLoader):
    """Loader for ``.py`` migration files.

    A migration module must define an upgrade function named ``up`` or
    ``migrate_up`` and may define a downgrade function named ``down`` or
    ``migrate_down``. Each function may be sync or async and must return a
    SQL string or a list/tuple of SQL strings.
    """

    __slots__ = ("migrations_dir", "project_root")

    def __init__(self, migrations_dir: Path, project_root: "Optional[Path]" = None) -> None:
        """Initialize Python file loader.

        Args:
            migrations_dir: Directory containing migration files.
            project_root: Optional project root directory for imports. When
                omitted it is discovered by walking up from ``migrations_dir``.
        """
        self.migrations_dir = migrations_dir
        self.project_root = project_root if project_root is not None else self._find_project_root(migrations_dir)

    async def get_up_sql(self, path: Path) -> list[str]:
        """Load Python migration and execute upgrade function.

        Args:
            path: Path to Python migration file.

        Returns:
            List of SQL statements for upgrade.

        Raises:
            MigrationLoadError: If the module fails to load, the function is
                missing or not callable, or it returns an invalid type.
        """
        # The function runs while the project root is still on sys.path so that
        # imports performed inside up()/migrate_up() resolve project modules too.
        # NOTE(review): the await happens with sys.path modified; fine for the
        # sequential runner, but concurrent loaders could observe each other's entry.
        with self._temporary_project_path():
            module = self._load_module_from_path(path)
            func_name, upgrade_func = self._find_function(module, ("up", "migrate_up"))
            if upgrade_func is None:
                if func_name is None:
                    msg = f"No upgrade function found in {path}. Expected 'up()' or 'migrate_up()'"
                else:
                    msg = f"'{func_name}' is not callable in {path}"
                raise MigrationLoadError(msg)
            sql_result = await self._call_migration_function(upgrade_func)

        return self._normalize_and_validate_sql(sql_result, path)

    async def get_down_sql(self, path: Path) -> list[str]:
        """Load Python migration and execute downgrade function.

        Args:
            path: Path to Python migration file.

        Returns:
            List of SQL statements for downgrade, or empty list if not available.

        Raises:
            MigrationLoadError: If the module fails to load or the function
                returns an invalid type.
        """
        with self._temporary_project_path():
            module = self._load_module_from_path(path)
            _, downgrade_func = self._find_function(module, ("down", "migrate_down"))
            if downgrade_func is None:
                return []
            sql_result = await self._call_migration_function(downgrade_func)

        return self._normalize_and_validate_sql(sql_result, path)

    def validate_migration_file(self, path: Path) -> None:
        """Validate Python migration file has required upgrade function.

        Args:
            path: Path to Python migration file.

        Raises:
            MigrationLoadError: If validation fails.
        """
        with self._temporary_project_path():
            module = self._load_module_from_path(path)

        func_name, upgrade_func = self._find_function(module, ("up", "migrate_up"))
        if upgrade_func is None:
            if func_name is None:
                msg = f"Migration {path} missing required upgrade function. Expected 'up()' or 'migrate_up()'"
            else:
                msg = f"Migration {path} '{func_name}' is not callable"
            raise MigrationLoadError(msg)

    @staticmethod
    def _find_function(module: Any, names: "tuple[str, ...]") -> "tuple[Optional[str], Any]":
        """Return ``(name, function)`` for the first callable attribute in ``names``.

        Args:
            module: Loaded migration module.
            names: Candidate attribute names, in order of preference.

        Returns:
            ``(name, function)`` for the first callable candidate. If no
            candidate is callable, ``(name, None)`` where ``name`` is the first
            candidate that exists but is not callable, or ``(None, None)`` when
            no candidate exists at all.
        """
        non_callable_name: Optional[str] = None
        for name in names:
            candidate = getattr(module, name, None)
            if callable(candidate):
                return name, candidate
            if non_callable_name is None and hasattr(module, name):
                non_callable_name = name
        return non_callable_name, None

    @staticmethod
    async def _call_migration_function(func: Any) -> Any:
        """Call a migration function, awaiting its result when it is awaitable.

        Args:
            func: Zero-argument callable from the migration module.

        Returns:
            The (awaited, if necessary) return value of ``func``.
        """
        result = func()
        # Covers ``async def`` functions as well as sync callables that return
        # an awaitable, which inspect.iscoroutinefunction() does not detect.
        if inspect.isawaitable(result):
            result = await result
        return result

    def _find_project_root(self, start_path: Path) -> Path:
        """Find project root by searching upwards for marker files.

        Args:
            start_path: Directory to start searching from.

        Returns:
            First directory (``start_path`` itself, any ancestor, or the
            filesystem root) containing a marker from ``PROJECT_ROOT_MARKERS``;
            otherwise the parent of ``start_path``.
        """
        resolved = start_path.resolve()
        for candidate in (resolved, *resolved.parents):
            if any((candidate / marker).exists() for marker in PROJECT_ROOT_MARKERS):
                return candidate
        return resolved.parent

    @contextmanager
    def _temporary_project_path(self) -> Iterator[None]:
        """Temporarily add project root to sys.path for imports.

        If the project root is already on ``sys.path`` it is left untouched.
        """
        path_to_add = str(self.project_root)
        if path_to_add in sys.path:
            yield
            return

        sys.path.insert(0, path_to_add)
        try:
            yield
        finally:
            # Migration code may have removed the entry itself; a ValueError here
            # would otherwise mask any exception raised inside the block.
            try:
                sys.path.remove(path_to_add)
            except ValueError:
                pass

    def _load_module_from_path(self, path: Path) -> Any:
        """Load a Python module from file path.

        Any previously loaded copy of the module is discarded first, so the
        file is always re-executed.

        Args:
            path: Path to Python migration file.

        Returns:
            Loaded module object.

        Raises:
            MigrationLoadError: If module loading fails.
        """
        module_name = f"sqlspec_migration_{path.stem}"
        sys.modules.pop(module_name, None)

        try:
            source_code = path.read_text(encoding="utf-8")
            compiled_code = compile(source_code, str(path), "exec")

            module = types.ModuleType(module_name)
            module.__file__ = str(path)

            # Registered before execution, presumably so code that resolves the
            # module through sys.modules while it executes can find it.
            sys.modules[module_name] = module

            # Migration files are project-authored code; executing them is the intent.
            exec(compiled_code, module.__dict__)  # noqa: S102

        except Exception as e:
            sys.modules.pop(module_name, None)
            msg = f"Failed to execute migration module {path}: {e}"
            raise MigrationLoadError(msg) from e

        return module

    def _normalize_and_validate_sql(self, sql: Any, migration_path: Path) -> list[str]:
        """Validate return type and normalize to list of strings.

        Strings are stripped and empty results are dropped.

        Args:
            sql: Return value from migration function (``str``, ``list`` or ``tuple``).
            migration_path: Path to migration file for error messages.

        Returns:
            List of SQL statements.

        Raises:
            MigrationLoadError: If return type is invalid.
        """
        if isinstance(sql, str):
            stripped = sql.strip()
            return [stripped] if stripped else []
        if isinstance(sql, (list, tuple)):
            result: list[str] = []
            for i, item in enumerate(sql):
                if not isinstance(item, str):
                    msg = (
                        f"Migration {migration_path} returned a {type(sql).__name__} containing a non-string "
                        f"element at index {i} (type: {type(item).__name__})."
                    )
                    raise MigrationLoadError(msg)
                stripped_item = item.strip()
                if stripped_item:
                    result.append(stripped_item)
            return result

        msg = (
            f"Migration {migration_path} must return a 'str' or 'List[str]', but returned type '{type(sql).__name__}'."
        )
        raise MigrationLoadError(msg)
377
+
378
+
379
def get_migration_loader(
    file_path: Path, migrations_dir: Path, project_root: "Optional[Path]" = None
) -> BaseMigrationLoader:
    """Pick the loader that understands ``file_path`` based on its extension.

    Args:
        file_path: Migration file to be loaded.
        migrations_dir: Directory holding the migration files (used by the
            Python loader to locate the project root).
        project_root: Explicit project root for Python migration imports.

    Returns:
        A ``SQLFileLoader`` for ``.sql`` files or a ``PythonFileLoader`` for
        ``.py`` files.

    Raises:
        MigrationLoadError: When the extension is neither ``.sql`` nor ``.py``.
    """
    extension = file_path.suffix
    if extension == ".sql":
        return SQLFileLoader()
    if extension == ".py":
        return PythonFileLoader(migrations_dir, project_root)

    msg = f"Unsupported migration file type: {extension}"
    raise MigrationLoadError(msg)
@@ -7,20 +7,22 @@ import time
7
7
  from pathlib import Path
8
8
  from typing import TYPE_CHECKING, Any, Optional
9
9
 
10
+ from sqlspec.core.statement import SQL
10
11
  from sqlspec.migrations.base import BaseMigrationRunner
12
+ from sqlspec.migrations.loaders import get_migration_loader
11
13
  from sqlspec.utils.logging import get_logger
14
+ from sqlspec.utils.sync_tools import run_
12
15
 
13
16
  if TYPE_CHECKING:
14
- from sqlspec.driver import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol
15
- from sqlspec.statement.sql import SQL
17
+ from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
16
18
 
17
19
  __all__ = ("AsyncMigrationRunner", "SyncMigrationRunner")
18
20
 
19
21
  logger = get_logger("migrations.runner")
20
22
 
21
23
 
22
- class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterProtocol[Any]"]):
23
- """Sync version - executes migrations using SQLFileLoader."""
24
+ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
25
+ """Executes migrations using SQLFileLoader."""
24
26
 
25
27
  def get_migration_files(self) -> "list[tuple[str, Path]]":
26
28
  """Get all migration files sorted by version.
@@ -42,7 +44,7 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterProtocol[Any]"])
42
44
  return self._load_migration_metadata(file_path)
43
45
 
44
46
  def execute_upgrade(
45
- self, driver: "SyncDriverAdapterProtocol[Any]", migration: "dict[str, Any]"
47
+ self, driver: "SyncDriverAdapterBase", migration: "dict[str, Any]"
46
48
  ) -> "tuple[Optional[str], int]":
47
49
  """Execute an upgrade migration.
48
50
 
@@ -58,15 +60,12 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterProtocol[Any]"])
58
60
  return None, 0
59
61
 
60
62
  start_time = time.time()
61
-
62
- # Execute migration
63
63
  driver.execute(upgrade_sql)
64
-
65
64
  execution_time = int((time.time() - start_time) * 1000)
66
65
  return None, execution_time
67
66
 
68
67
  def execute_downgrade(
69
- self, driver: "SyncDriverAdapterProtocol[Any]", migration: "dict[str, Any]"
68
+ self, driver: "SyncDriverAdapterBase", migration: "dict[str, Any]"
70
69
  ) -> "tuple[Optional[str], int]":
71
70
  """Execute a downgrade migration.
72
71
 
@@ -82,42 +81,44 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterProtocol[Any]"])
82
81
  return None, 0
83
82
 
84
83
  start_time = time.time()
85
-
86
- # Execute migration
87
84
  driver.execute(downgrade_sql)
88
-
89
85
  execution_time = int((time.time() - start_time) * 1000)
90
86
  return None, execution_time
91
87
 
92
88
  def load_all_migrations(self) -> "dict[str, SQL]":
93
89
  """Load all migrations into a single namespace for bulk operations.
94
90
 
95
- Returns a dictionary mapping query names to SQL objects.
96
- Useful for:
97
- - Migration analysis tools
98
- - Documentation generation
99
- - Validation and linting
100
- - Migration squashing
101
-
102
91
  Returns:
103
92
  Dictionary mapping query names to SQL objects.
104
93
  """
105
94
  all_queries = {}
106
95
  migrations = self.get_migration_files()
107
96
 
108
- for _version, file_path in migrations:
109
- self.loader.load_sql(file_path)
97
+ for version, file_path in migrations:
98
+ if file_path.suffix == ".sql":
99
+ self.loader.load_sql(file_path)
100
+ for query_name in self.loader.list_queries():
101
+ all_queries[query_name] = self.loader.get_sql(query_name)
102
+ else:
103
+ loader = get_migration_loader(file_path, self.migrations_path, self.project_root)
104
+
105
+ try:
106
+ up_sql = run_(loader.get_up_sql)(file_path)
107
+ down_sql = run_(loader.get_down_sql)(file_path)
108
+
109
+ if up_sql:
110
+ all_queries[f"migrate-{version}-up"] = SQL(up_sql[0])
111
+ if down_sql:
112
+ all_queries[f"migrate-{version}-down"] = SQL(down_sql[0])
110
113
 
111
- # Get all queries from this file
112
- for query_name in self.loader.list_queries():
113
- # Store with full query name for uniqueness
114
- all_queries[query_name] = self.loader.get_sql(query_name)
114
+ except Exception as e:
115
+ logger.debug("Failed to load Python migration %s: %s", file_path, e)
115
116
 
116
117
  return all_queries
117
118
 
118
119
 
119
- class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterProtocol[Any]"]):
120
- """Async version - executes migrations using SQLFileLoader."""
120
+ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
121
+ """Executes migrations using SQLFileLoader."""
121
122
 
122
123
  async def get_migration_files(self) -> "list[tuple[str, Path]]":
123
124
  """Get all migration files sorted by version.
@@ -125,7 +126,6 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterProtocol[Any]"
125
126
  Returns:
126
127
  List of tuples containing (version, file_path).
127
128
  """
128
- # For async, we still use the sync file operations since Path.glob is sync
129
129
  return self._get_migration_files_sync()
130
130
 
131
131
  async def load_migration(self, file_path: Path) -> "dict[str, Any]":
@@ -137,11 +137,10 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterProtocol[Any]"
137
137
  Returns:
138
138
  Dictionary containing migration metadata.
139
139
  """
140
- # File loading is still sync, so we use the base implementation
141
140
  return self._load_migration_metadata(file_path)
142
141
 
143
142
  async def execute_upgrade(
144
- self, driver: "AsyncDriverAdapterProtocol[Any]", migration: "dict[str, Any]"
143
+ self, driver: "AsyncDriverAdapterBase", migration: "dict[str, Any]"
145
144
  ) -> "tuple[Optional[str], int]":
146
145
  """Execute an upgrade migration.
147
146
 
@@ -157,15 +156,12 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterProtocol[Any]"
157
156
  return None, 0
158
157
 
159
158
  start_time = time.time()
160
-
161
- # Execute migration
162
159
  await driver.execute(upgrade_sql)
163
-
164
160
  execution_time = int((time.time() - start_time) * 1000)
165
161
  return None, execution_time
166
162
 
167
163
  async def execute_downgrade(
168
- self, driver: "AsyncDriverAdapterProtocol[Any]", migration: "dict[str, Any]"
164
+ self, driver: "AsyncDriverAdapterBase", migration: "dict[str, Any]"
169
165
  ) -> "tuple[Optional[str], int]":
170
166
  """Execute a downgrade migration.
171
167
 
@@ -181,35 +177,37 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterProtocol[Any]"
181
177
  return None, 0
182
178
 
183
179
  start_time = time.time()
184
-
185
- # Execute migration
186
180
  await driver.execute(downgrade_sql)
187
-
188
181
  execution_time = int((time.time() - start_time) * 1000)
189
182
  return None, execution_time
190
183
 
191
184
  async def load_all_migrations(self) -> "dict[str, SQL]":
192
185
  """Load all migrations into a single namespace for bulk operations.
193
186
 
194
- Returns a dictionary mapping query names to SQL objects.
195
- Useful for:
196
- - Migration analysis tools
197
- - Documentation generation
198
- - Validation and linting
199
- - Migration squashing
200
-
201
187
  Returns:
202
188
  Dictionary mapping query names to SQL objects.
203
189
  """
204
190
  all_queries = {}
205
191
  migrations = await self.get_migration_files()
206
192
 
207
- for _version, file_path in migrations:
208
- self.loader.load_sql(file_path)
209
-
210
- # Get all queries from this file
211
- for query_name in self.loader.list_queries():
212
- # Store with full query name for uniqueness
213
- all_queries[query_name] = self.loader.get_sql(query_name)
193
+ for version, file_path in migrations:
194
+ if file_path.suffix == ".sql":
195
+ self.loader.load_sql(file_path)
196
+ for query_name in self.loader.list_queries():
197
+ all_queries[query_name] = self.loader.get_sql(query_name)
198
+ else:
199
+ loader = get_migration_loader(file_path, self.migrations_path, self.project_root)
200
+
201
+ try:
202
+ up_sql = await loader.get_up_sql(file_path)
203
+ down_sql = await loader.get_down_sql(file_path)
204
+
205
+ if up_sql:
206
+ all_queries[f"migrate-{version}-up"] = SQL(up_sql[0])
207
+ if down_sql:
208
+ all_queries[f"migrate-{version}-down"] = SQL(down_sql[0])
209
+
210
+ except Exception as e:
211
+ logger.debug("Failed to load Python migration %s: %s", file_path, e)
214
212
 
215
213
  return all_queries