sqlspec 0.13.1__py3-none-any.whl → 0.16.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (185)
  1. sqlspec/__init__.py +71 -8
  2. sqlspec/__main__.py +12 -0
  3. sqlspec/__metadata__.py +1 -3
  4. sqlspec/_serialization.py +1 -2
  5. sqlspec/_sql.py +930 -136
  6. sqlspec/_typing.py +278 -142
  7. sqlspec/adapters/adbc/__init__.py +4 -3
  8. sqlspec/adapters/adbc/_types.py +12 -0
  9. sqlspec/adapters/adbc/config.py +116 -285
  10. sqlspec/adapters/adbc/driver.py +462 -340
  11. sqlspec/adapters/aiosqlite/__init__.py +18 -3
  12. sqlspec/adapters/aiosqlite/_types.py +13 -0
  13. sqlspec/adapters/aiosqlite/config.py +202 -150
  14. sqlspec/adapters/aiosqlite/driver.py +226 -247
  15. sqlspec/adapters/asyncmy/__init__.py +18 -3
  16. sqlspec/adapters/asyncmy/_types.py +12 -0
  17. sqlspec/adapters/asyncmy/config.py +80 -199
  18. sqlspec/adapters/asyncmy/driver.py +257 -215
  19. sqlspec/adapters/asyncpg/__init__.py +19 -4
  20. sqlspec/adapters/asyncpg/_types.py +17 -0
  21. sqlspec/adapters/asyncpg/config.py +81 -214
  22. sqlspec/adapters/asyncpg/driver.py +284 -359
  23. sqlspec/adapters/bigquery/__init__.py +17 -3
  24. sqlspec/adapters/bigquery/_types.py +12 -0
  25. sqlspec/adapters/bigquery/config.py +191 -299
  26. sqlspec/adapters/bigquery/driver.py +474 -634
  27. sqlspec/adapters/duckdb/__init__.py +14 -3
  28. sqlspec/adapters/duckdb/_types.py +12 -0
  29. sqlspec/adapters/duckdb/config.py +414 -397
  30. sqlspec/adapters/duckdb/driver.py +342 -393
  31. sqlspec/adapters/oracledb/__init__.py +19 -5
  32. sqlspec/adapters/oracledb/_types.py +14 -0
  33. sqlspec/adapters/oracledb/config.py +123 -458
  34. sqlspec/adapters/oracledb/driver.py +505 -531
  35. sqlspec/adapters/psqlpy/__init__.py +13 -3
  36. sqlspec/adapters/psqlpy/_types.py +11 -0
  37. sqlspec/adapters/psqlpy/config.py +93 -307
  38. sqlspec/adapters/psqlpy/driver.py +504 -213
  39. sqlspec/adapters/psycopg/__init__.py +19 -5
  40. sqlspec/adapters/psycopg/_types.py +17 -0
  41. sqlspec/adapters/psycopg/config.py +143 -472
  42. sqlspec/adapters/psycopg/driver.py +704 -825
  43. sqlspec/adapters/sqlite/__init__.py +14 -3
  44. sqlspec/adapters/sqlite/_types.py +11 -0
  45. sqlspec/adapters/sqlite/config.py +208 -142
  46. sqlspec/adapters/sqlite/driver.py +263 -278
  47. sqlspec/base.py +105 -9
  48. sqlspec/{statement/builder → builder}/__init__.py +12 -14
  49. sqlspec/{statement/builder/base.py → builder/_base.py} +184 -86
  50. sqlspec/{statement/builder/column.py → builder/_column.py} +97 -60
  51. sqlspec/{statement/builder/ddl.py → builder/_ddl.py} +61 -131
  52. sqlspec/{statement/builder → builder}/_ddl_utils.py +4 -10
  53. sqlspec/{statement/builder/delete.py → builder/_delete.py} +10 -30
  54. sqlspec/builder/_insert.py +421 -0
  55. sqlspec/builder/_merge.py +71 -0
  56. sqlspec/{statement/builder → builder}/_parsing_utils.py +49 -26
  57. sqlspec/builder/_select.py +170 -0
  58. sqlspec/{statement/builder/update.py → builder/_update.py} +16 -20
  59. sqlspec/builder/mixins/__init__.py +55 -0
  60. sqlspec/builder/mixins/_cte_and_set_ops.py +222 -0
  61. sqlspec/{statement/builder/mixins/_delete_from.py → builder/mixins/_delete_operations.py} +8 -1
  62. sqlspec/builder/mixins/_insert_operations.py +244 -0
  63. sqlspec/{statement/builder/mixins/_join.py → builder/mixins/_join_operations.py} +45 -13
  64. sqlspec/{statement/builder/mixins/_merge_clauses.py → builder/mixins/_merge_operations.py} +188 -30
  65. sqlspec/builder/mixins/_order_limit_operations.py +135 -0
  66. sqlspec/builder/mixins/_pivot_operations.py +153 -0
  67. sqlspec/builder/mixins/_select_operations.py +604 -0
  68. sqlspec/builder/mixins/_update_operations.py +202 -0
  69. sqlspec/builder/mixins/_where_clause.py +644 -0
  70. sqlspec/cli.py +247 -0
  71. sqlspec/config.py +183 -138
  72. sqlspec/core/__init__.py +63 -0
  73. sqlspec/core/cache.py +871 -0
  74. sqlspec/core/compiler.py +417 -0
  75. sqlspec/core/filters.py +830 -0
  76. sqlspec/core/hashing.py +310 -0
  77. sqlspec/core/parameters.py +1237 -0
  78. sqlspec/core/result.py +677 -0
  79. sqlspec/{statement → core}/splitter.py +321 -191
  80. sqlspec/core/statement.py +676 -0
  81. sqlspec/driver/__init__.py +7 -10
  82. sqlspec/driver/_async.py +422 -163
  83. sqlspec/driver/_common.py +545 -287
  84. sqlspec/driver/_sync.py +426 -160
  85. sqlspec/driver/mixins/__init__.py +2 -13
  86. sqlspec/driver/mixins/_result_tools.py +193 -0
  87. sqlspec/driver/mixins/_sql_translator.py +65 -14
  88. sqlspec/exceptions.py +5 -252
  89. sqlspec/extensions/aiosql/adapter.py +93 -96
  90. sqlspec/extensions/litestar/__init__.py +2 -1
  91. sqlspec/extensions/litestar/cli.py +48 -0
  92. sqlspec/extensions/litestar/config.py +0 -1
  93. sqlspec/extensions/litestar/handlers.py +15 -26
  94. sqlspec/extensions/litestar/plugin.py +21 -16
  95. sqlspec/extensions/litestar/providers.py +17 -52
  96. sqlspec/loader.py +423 -104
  97. sqlspec/migrations/__init__.py +35 -0
  98. sqlspec/migrations/base.py +414 -0
  99. sqlspec/migrations/commands.py +443 -0
  100. sqlspec/migrations/loaders.py +402 -0
  101. sqlspec/migrations/runner.py +213 -0
  102. sqlspec/migrations/tracker.py +140 -0
  103. sqlspec/migrations/utils.py +129 -0
  104. sqlspec/protocols.py +51 -186
  105. sqlspec/storage/__init__.py +1 -1
  106. sqlspec/storage/backends/base.py +37 -40
  107. sqlspec/storage/backends/fsspec.py +136 -112
  108. sqlspec/storage/backends/obstore.py +138 -160
  109. sqlspec/storage/capabilities.py +5 -4
  110. sqlspec/storage/registry.py +57 -106
  111. sqlspec/typing.py +136 -115
  112. sqlspec/utils/__init__.py +2 -2
  113. sqlspec/utils/correlation.py +0 -3
  114. sqlspec/utils/deprecation.py +6 -6
  115. sqlspec/utils/fixtures.py +6 -6
  116. sqlspec/utils/logging.py +0 -2
  117. sqlspec/utils/module_loader.py +7 -12
  118. sqlspec/utils/singleton.py +0 -1
  119. sqlspec/utils/sync_tools.py +17 -38
  120. sqlspec/utils/text.py +12 -51
  121. sqlspec/utils/type_guards.py +482 -235
  122. {sqlspec-0.13.1.dist-info → sqlspec-0.16.2.dist-info}/METADATA +7 -2
  123. sqlspec-0.16.2.dist-info/RECORD +134 -0
  124. sqlspec-0.16.2.dist-info/entry_points.txt +2 -0
  125. sqlspec/driver/connection.py +0 -207
  126. sqlspec/driver/mixins/_csv_writer.py +0 -91
  127. sqlspec/driver/mixins/_pipeline.py +0 -512
  128. sqlspec/driver/mixins/_result_utils.py +0 -140
  129. sqlspec/driver/mixins/_storage.py +0 -926
  130. sqlspec/driver/mixins/_type_coercion.py +0 -130
  131. sqlspec/driver/parameters.py +0 -138
  132. sqlspec/service/__init__.py +0 -4
  133. sqlspec/service/_util.py +0 -147
  134. sqlspec/service/base.py +0 -1131
  135. sqlspec/service/pagination.py +0 -26
  136. sqlspec/statement/__init__.py +0 -21
  137. sqlspec/statement/builder/insert.py +0 -288
  138. sqlspec/statement/builder/merge.py +0 -95
  139. sqlspec/statement/builder/mixins/__init__.py +0 -65
  140. sqlspec/statement/builder/mixins/_aggregate_functions.py +0 -250
  141. sqlspec/statement/builder/mixins/_case_builder.py +0 -91
  142. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -90
  143. sqlspec/statement/builder/mixins/_from.py +0 -63
  144. sqlspec/statement/builder/mixins/_group_by.py +0 -118
  145. sqlspec/statement/builder/mixins/_having.py +0 -35
  146. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -47
  147. sqlspec/statement/builder/mixins/_insert_into.py +0 -36
  148. sqlspec/statement/builder/mixins/_insert_values.py +0 -67
  149. sqlspec/statement/builder/mixins/_limit_offset.py +0 -53
  150. sqlspec/statement/builder/mixins/_order_by.py +0 -46
  151. sqlspec/statement/builder/mixins/_pivot.py +0 -79
  152. sqlspec/statement/builder/mixins/_returning.py +0 -37
  153. sqlspec/statement/builder/mixins/_select_columns.py +0 -61
  154. sqlspec/statement/builder/mixins/_set_ops.py +0 -122
  155. sqlspec/statement/builder/mixins/_unpivot.py +0 -77
  156. sqlspec/statement/builder/mixins/_update_from.py +0 -55
  157. sqlspec/statement/builder/mixins/_update_set.py +0 -94
  158. sqlspec/statement/builder/mixins/_update_table.py +0 -29
  159. sqlspec/statement/builder/mixins/_where.py +0 -401
  160. sqlspec/statement/builder/mixins/_window_functions.py +0 -86
  161. sqlspec/statement/builder/select.py +0 -221
  162. sqlspec/statement/filters.py +0 -596
  163. sqlspec/statement/parameter_manager.py +0 -220
  164. sqlspec/statement/parameters.py +0 -867
  165. sqlspec/statement/pipelines/__init__.py +0 -210
  166. sqlspec/statement/pipelines/analyzers/__init__.py +0 -9
  167. sqlspec/statement/pipelines/analyzers/_analyzer.py +0 -646
  168. sqlspec/statement/pipelines/context.py +0 -115
  169. sqlspec/statement/pipelines/transformers/__init__.py +0 -7
  170. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +0 -88
  171. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +0 -1247
  172. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +0 -76
  173. sqlspec/statement/pipelines/validators/__init__.py +0 -23
  174. sqlspec/statement/pipelines/validators/_dml_safety.py +0 -290
  175. sqlspec/statement/pipelines/validators/_parameter_style.py +0 -370
  176. sqlspec/statement/pipelines/validators/_performance.py +0 -718
  177. sqlspec/statement/pipelines/validators/_security.py +0 -967
  178. sqlspec/statement/result.py +0 -435
  179. sqlspec/statement/sql.py +0 -1704
  180. sqlspec/statement/sql_compiler.py +0 -140
  181. sqlspec/utils/cached_property.py +0 -25
  182. sqlspec-0.13.1.dist-info/RECORD +0 -150
  183. {sqlspec-0.13.1.dist-info → sqlspec-0.16.2.dist-info}/WHEEL +0 -0
  184. {sqlspec-0.13.1.dist-info → sqlspec-0.16.2.dist-info}/licenses/LICENSE +0 -0
  185. {sqlspec-0.13.1.dist-info → sqlspec-0.16.2.dist-info}/licenses/NOTICE +0 -0
@@ -0,0 +1,402 @@
1
+ """Migration loader abstractions for SQLSpec."""
2
+
3
+ import abc
4
+ import inspect
5
+ import sys
6
+ import types
7
+ from collections.abc import Iterator
8
+ from contextlib import contextmanager
9
+ from pathlib import Path
10
+ from typing import Any, Final, Optional
11
+
12
+ from sqlspec.loader import SQLFileLoader as CoreSQLFileLoader
13
+
14
+ __all__ = ("BaseMigrationLoader", "MigrationLoadError", "PythonFileLoader", "SQLFileLoader", "get_migration_loader")
15
+
16
+ PROJECT_ROOT_MARKERS: Final[list[str]] = ["pyproject.toml", ".git", "setup.cfg", "setup.py"]
17
+
18
+
19
class MigrationLoadError(Exception):
    """Signal that a migration file could not be loaded or validated."""
21
+
22
+
23
class BaseMigrationLoader(abc.ABC):
    """Interface every migration loader must implement.

    A loader turns one migration file into the SQL statements needed to
    apply ("up") or revert ("down") that migration, and can check that the
    file is well-formed.
    """

    __slots__ = ()

    @abc.abstractmethod
    async def get_up_sql(self, path: Path) -> list[str]:
        """Return the upgrade statements defined by a migration file.

        Args:
            path: Location of the migration file.

        Returns:
            The SQL statements to run when applying the migration.

        Raises:
            MigrationLoadError: If the file cannot be loaded.
        """
        ...

    @abc.abstractmethod
    async def get_down_sql(self, path: Path) -> list[str]:
        """Return the downgrade statements defined by a migration file.

        Args:
            path: Location of the migration file.

        Returns:
            The SQL statements to run when reverting the migration. An empty
            list means the migration offers no downgrade.

        Raises:
            MigrationLoadError: If the file cannot be loaded.
        """
        ...

    @abc.abstractmethod
    def validate_migration_file(self, path: Path) -> None:
        """Check that a migration file contains everything it needs.

        Args:
            path: Location of the migration file.

        Raises:
            MigrationLoadError: If the file fails validation.
        """
        ...
70
+
71
+
72
class SQLFileLoader(BaseMigrationLoader):
    """Loader for ``.sql`` migration files, backed by the core SQLFileLoader.

    Statements are looked up by query name: ``migrate-<version>-up`` and
    ``migrate-<version>-down``, where ``<version>`` is the numeric filename
    prefix (the part before the first underscore) zero-padded to four digits.
    """

    __slots__ = ("sql_loader",)

    def __init__(self) -> None:
        """Create the loader with its own core SQL file loader instance."""
        self.sql_loader: CoreSQLFileLoader = CoreSQLFileLoader()

    async def get_up_sql(self, path: Path) -> list[str]:
        """Extract the 'up' SQL from a SQL migration file.

        Args:
            path: Path to SQL migration file.

        Returns:
            List containing the single SQL statement for upgrade.

        Raises:
            MigrationLoadError: If the filename has no numeric version prefix
                or the file does not define the expected 'up' query.
        """
        # Reject bad filenames up front, as validate_migration_file() does;
        # otherwise the lookup would be for the meaningless "migrate--up".
        version = self._require_version(path)
        up_query = f"migrate-{version}-up"

        if not self._reload_and_has_query(path, up_query):
            msg = f"Migration {path} missing 'up' query: {up_query}"
            raise MigrationLoadError(msg)

        return [self.sql_loader.get_sql(up_query).sql]

    async def get_down_sql(self, path: Path) -> list[str]:
        """Extract the 'down' SQL from a SQL migration file.

        Downgrades are optional, so this is deliberately lenient: a missing
        'down' query (including the case where the filename yields no
        version) results in an empty list rather than an error.

        Args:
            path: Path to SQL migration file.

        Returns:
            List containing the single SQL statement for downgrade, or an
            empty list when no 'down' query is defined.
        """
        version = self._extract_version(path.name)
        down_query = f"migrate-{version}-down"

        if not self._reload_and_has_query(path, down_query):
            return []

        return [self.sql_loader.get_sql(down_query).sql]

    def validate_migration_file(self, path: Path) -> None:
        """Validate SQL migration file has required up query.

        Args:
            path: Path to SQL migration file.

        Raises:
            MigrationLoadError: If the filename is invalid or the required
                'up' query is missing.
        """
        version = self._require_version(path)
        up_query = f"migrate-{version}-up"

        if not self._reload_and_has_query(path, up_query):
            msg = f"Migration {path} missing required 'up' query: {up_query}"
            raise MigrationLoadError(msg)

    def _require_version(self, path: Path) -> str:
        """Return the version encoded in ``path.name`` or raise.

        Raises:
            MigrationLoadError: If the filename has no numeric version prefix.
        """
        version = self._extract_version(path.name)
        if not version:
            msg = f"Invalid migration filename: {path.name}"
            raise MigrationLoadError(msg)
        return version

    def _reload_and_has_query(self, path: Path, query_name: str) -> bool:
        """Re-read ``path`` and report whether it defines ``query_name``."""
        # The cache is cleared before every load; presumably so previously
        # loaded content is not reused -- confirm against the core loader.
        self.sql_loader.clear_cache()
        self.sql_loader.load_sql(path)
        return self.sql_loader.has_query(query_name)

    def _extract_version(self, filename: str) -> str:
        """Extract version from filename.

        Args:
            filename: Migration filename to parse.

        Returns:
            The text before the first underscore, zero-padded to four
            characters, or an empty string if that prefix is not all digits.
        """
        prefix = filename.split("_", 1)[0]
        return prefix.zfill(4) if prefix.isdigit() else ""
159
+
160
+
161
class PythonFileLoader(BaseMigrationLoader):
    """Loader for Python migration files.

    A migration module provides its upgrade as ``up()`` or ``migrate_up()``
    and, optionally, its downgrade as ``down()`` or ``migrate_down()``. Each
    function may be sync or async and must return a ``str`` or a list of
    ``str``.

    Loading a migration executes the module's top-level code.
    """

    __slots__ = ("migrations_dir", "project_root")

    def __init__(self, migrations_dir: Path, project_root: "Optional[Path]" = None) -> None:
        """Initialize Python file loader.

        Args:
            migrations_dir: Directory containing migration files.
            project_root: Optional project root directory for imports. When
                omitted it is discovered by walking up from ``migrations_dir``.
        """
        self.migrations_dir = migrations_dir
        self.project_root = project_root if project_root is not None else self._find_project_root(migrations_dir)

    async def get_up_sql(self, path: Path) -> list[str]:
        """Load Python migration and execute upgrade function.

        Args:
            path: Path to Python migration file.

        Returns:
            List of SQL statements for upgrade.

        Raises:
            MigrationLoadError: If the module cannot be executed, no upgrade
                function exists, or the function returns an invalid type.
        """
        with self._temporary_project_path():
            module = self._load_module_from_path(path)

            upgrade_func = self._find_callable(module, ("up", "migrate_up"))
            if upgrade_func is None:
                msg = f"No upgrade function found in {path}. Expected 'up()' or 'migrate_up()'"
                raise MigrationLoadError(msg)

            # Called while the project root is still importable so that
            # imports performed lazily inside the function resolve.
            sql_result = await self._call_migration_function(upgrade_func)

        return self._normalize_and_validate_sql(sql_result, path)

    async def get_down_sql(self, path: Path) -> list[str]:
        """Load Python migration and execute downgrade function.

        Args:
            path: Path to Python migration file.

        Returns:
            List of SQL statements for downgrade, or empty list if the module
            defines no downgrade function.

        Raises:
            MigrationLoadError: If the module cannot be executed or the
                function returns an invalid type.
        """
        with self._temporary_project_path():
            module = self._load_module_from_path(path)

            downgrade_func = self._find_callable(module, ("down", "migrate_down"))
            if downgrade_func is None:
                return []

            sql_result = await self._call_migration_function(downgrade_func)

        return self._normalize_and_validate_sql(sql_result, path)

    def validate_migration_file(self, path: Path) -> None:
        """Validate Python migration file has required upgrade function.

        Validation executes the module (but not its migration functions).

        Args:
            path: Path to Python migration file.

        Raises:
            MigrationLoadError: If the module cannot be executed or defines
                no callable ``up``/``migrate_up``.
        """
        with self._temporary_project_path():
            module = self._load_module_from_path(path)

        if self._find_callable(module, ("up", "migrate_up")) is None:
            msg = f"Migration {path} missing required upgrade function. Expected 'up()' or 'migrate_up()'"
            raise MigrationLoadError(msg)

    @staticmethod
    def _find_callable(module: Any, names: "tuple[str, ...]") -> Any:
        """Return the first callable attribute of ``module`` among ``names``, else None."""
        for name in names:
            candidate = getattr(module, name, None)
            if callable(candidate):
                return candidate
        return None

    @staticmethod
    async def _call_migration_function(func: Any) -> Any:
        """Call ``func`` with no arguments, awaiting the result if needed.

        Checking the *result* rather than the function covers plain
        ``async def`` functions as well as sync callables that hand back an
        awaitable (e.g. ``functools.partial`` or decorated coroutines).
        """
        result = func()
        if inspect.isawaitable(result):
            result = await result
        return result

    def _find_project_root(self, start_path: Path) -> Path:
        """Find project root by searching upwards for marker files.

        Args:
            start_path: Directory to start searching from.

        Returns:
            The nearest ancestor (including ``start_path`` itself, excluding
            the filesystem root) that contains one of PROJECT_ROOT_MARKERS;
            otherwise the parent of the resolved ``start_path``.
        """
        current_path = start_path.resolve()

        # The loop stops once the path is its own parent, i.e. at the
        # filesystem root, which is therefore never tested for markers.
        while current_path != current_path.parent:
            if any((current_path / marker).exists() for marker in PROJECT_ROOT_MARKERS):
                return current_path
            current_path = current_path.parent

        return start_path.resolve().parent

    @contextmanager
    def _temporary_project_path(self) -> Iterator[None]:
        """Temporarily add project root to sys.path for imports.

        If the root is already on ``sys.path`` nothing is changed, and
        nothing is removed on exit.
        """
        path_to_add = str(self.project_root)
        if path_to_add in sys.path:
            yield
            return

        sys.path.insert(0, path_to_add)
        try:
            yield
        finally:
            try:
                sys.path.remove(path_to_add)
            except ValueError:
                # Migration code removed the entry itself; there is nothing
                # left to undo, and raising here would mask any in-flight error.
                pass

    def _load_module_from_path(self, path: Path) -> Any:
        """Load a Python module from file path.

        The source is read as UTF-8, compiled, and executed in a fresh module
        registered in ``sys.modules`` as ``sqlspec_migration_<stem>``. Any
        previously registered module of that name is replaced.

        Args:
            path: Path to Python migration file.

        Returns:
            Loaded module object.

        Raises:
            MigrationLoadError: If reading, compiling, or executing fails.
        """
        module_name = f"sqlspec_migration_{path.stem}"

        # Drop any earlier load so the file is always executed afresh.
        sys.modules.pop(module_name, None)

        try:
            source_code = path.read_text(encoding="utf-8")
            compiled_code = compile(source_code, str(path), "exec")

            module = types.ModuleType(module_name)
            module.__file__ = str(path)

            # Registered before execution so the module is visible in
            # sys.modules while its own top-level code runs.
            sys.modules[module_name] = module

            # NOTE: executes arbitrary code from the migration file; migration
            # files must come from a trusted source.
            exec(compiled_code, module.__dict__)  # noqa: S102

        except Exception as e:
            # Do not leave a half-initialised module registered.
            sys.modules.pop(module_name, None)
            msg = f"Failed to execute migration module {path}: {e}"
            raise MigrationLoadError(msg) from e

        return module

    def _normalize_and_validate_sql(self, sql: Any, migration_path: Path) -> list[str]:
        """Validate return type and normalize to list of strings.

        Statements are stripped of surrounding whitespace and empty ones are
        dropped.

        Args:
            sql: Return value from migration function.
            migration_path: Path to migration file for error messages.

        Returns:
            List of SQL statements.

        Raises:
            MigrationLoadError: If ``sql`` is neither a string nor a list of
                strings.
        """
        if isinstance(sql, str):
            stripped = sql.strip()
            return [stripped] if stripped else []

        if isinstance(sql, list):
            result = []
            for i, item in enumerate(sql):
                if not isinstance(item, str):
                    msg = (
                        f"Migration {migration_path} returned a list containing a non-string "
                        f"element at index {i} (type: {type(item).__name__})."
                    )
                    raise MigrationLoadError(msg)
                stripped_item = item.strip()
                if stripped_item:
                    result.append(stripped_item)
            return result

        msg = (
            f"Migration {migration_path} must return a 'str' or 'List[str]', but returned type '{type(sql).__name__}'."
        )
        raise MigrationLoadError(msg)
377
+
378
+
379
def get_migration_loader(
    file_path: Path, migrations_dir: Path, project_root: "Optional[Path]" = None
) -> BaseMigrationLoader:
    """Select the loader that matches a migration file's extension.

    Args:
        file_path: Path to the migration file; only its suffix is inspected.
        migrations_dir: Directory containing migration files (passed to the
            Python loader).
        project_root: Optional project root directory for Python imports.

    Returns:
        A ``SQLFileLoader`` for ``.sql`` files or a ``PythonFileLoader`` for
        ``.py`` files.

    Raises:
        MigrationLoadError: If the suffix is neither ``.py`` nor ``.sql``.
    """
    extension = file_path.suffix

    if extension == ".sql":
        return SQLFileLoader()
    if extension != ".py":
        msg = f"Unsupported migration file type: {extension}"
        raise MigrationLoadError(msg)
    return PythonFileLoader(migrations_dir, project_root)
@@ -0,0 +1,213 @@
1
+ """Migration execution engine for SQLSpec.
2
+
3
+ This module handles migration file loading and execution using SQLFileLoader.
4
+ """
5
+
6
+ import time
7
+ from pathlib import Path
8
+ from typing import TYPE_CHECKING, Any, Optional
9
+
10
+ from sqlspec.core.statement import SQL
11
+ from sqlspec.migrations.base import BaseMigrationRunner
12
+ from sqlspec.migrations.loaders import get_migration_loader
13
+ from sqlspec.utils.logging import get_logger
14
+ from sqlspec.utils.sync_tools import run_
15
+
16
+ if TYPE_CHECKING:
17
+ from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
18
+
19
+ __all__ = ("AsyncMigrationRunner", "SyncMigrationRunner")
20
+
21
+ logger = get_logger("migrations.runner")
22
+
23
+
24
class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
    """Synchronous migration runner.

    Discovers migration files, loads their metadata, and executes their
    upgrade/downgrade SQL against a synchronous driver.
    """

    def get_migration_files(self) -> "list[tuple[str, Path]]":
        """Get all migration files sorted by version.

        Returns:
            List of (version, path) tuples sorted by version.
        """
        return self._get_migration_files_sync()

    def load_migration(self, file_path: Path) -> "dict[str, Any]":
        """Load a migration file and extract its components.

        Args:
            file_path: Path to the migration file.

        Returns:
            Dictionary containing migration metadata and queries.
        """
        return self._load_migration_metadata(file_path)

    def execute_upgrade(
        self, driver: "SyncDriverAdapterBase", migration: "dict[str, Any]"
    ) -> "tuple[Optional[str], int]":
        """Execute an upgrade migration.

        Args:
            driver: The database driver to use.
            migration: Migration metadata dictionary.

        Returns:
            Tuple of (sql_content, execution_time_ms). ``sql_content`` is
            always ``None`` in this implementation; the time is ``0`` when the
            migration has no upgrade SQL.
        """
        upgrade_sql = self._get_migration_sql(migration, "up")
        if upgrade_sql is None:
            return None, 0

        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed (or go negative) when the system clock is adjusted.
        start_time = time.perf_counter()
        driver.execute(upgrade_sql)
        execution_time = int((time.perf_counter() - start_time) * 1000)
        return None, execution_time

    def execute_downgrade(
        self, driver: "SyncDriverAdapterBase", migration: "dict[str, Any]"
    ) -> "tuple[Optional[str], int]":
        """Execute a downgrade migration.

        Args:
            driver: The database driver to use.
            migration: Migration metadata dictionary.

        Returns:
            Tuple of (sql_content, execution_time_ms). ``sql_content`` is
            always ``None`` in this implementation; the time is ``0`` when the
            migration has no downgrade SQL.
        """
        downgrade_sql = self._get_migration_sql(migration, "down")
        if downgrade_sql is None:
            return None, 0

        # Monotonic clock; see execute_upgrade().
        start_time = time.perf_counter()
        driver.execute(downgrade_sql)
        execution_time = int((time.perf_counter() - start_time) * 1000)
        return None, execution_time

    def load_all_migrations(self) -> "dict[str, SQL]":
        """Load all migrations into a single namespace for bulk operations.

        SQL files contribute every query the shared loader reports; other
        files are loaded through their migration loader and contribute
        ``migrate-<version>-up`` / ``migrate-<version>-down`` entries. A
        non-SQL migration that fails to load is skipped and logged at debug
        level.

        Returns:
            Dictionary mapping query names to SQL objects.
        """
        all_queries: "dict[str, SQL]" = {}
        migrations = self.get_migration_files()

        for version, file_path in migrations:
            if file_path.suffix == ".sql":
                self.loader.load_sql(file_path)
                for query_name in self.loader.list_queries():
                    all_queries[query_name] = self.loader.get_sql(query_name)
            else:
                loader = get_migration_loader(file_path, self.migrations_path, self.project_root)

                try:
                    # The loader API is async; run_ drives it synchronously.
                    up_sql = run_(loader.get_up_sql)(file_path)
                    down_sql = run_(loader.get_down_sql)(file_path)

                    # NOTE(review): only the first returned statement is kept;
                    # any further statements are dropped -- confirm intended.
                    if up_sql:
                        all_queries[f"migrate-{version}-up"] = SQL(up_sql[0])
                    if down_sql:
                        all_queries[f"migrate-{version}-down"] = SQL(down_sql[0])

                except Exception as e:
                    # Best effort: one broken migration must not prevent the
                    # remaining ones from being collected.
                    logger.debug("Failed to load Python migration %s: %s", file_path, e)

        return all_queries
118
+
119
+
120
class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
    """Asynchronous migration runner.

    Discovers migration files, loads their metadata, and executes their
    upgrade/downgrade SQL against an asynchronous driver.
    """

    async def get_migration_files(self) -> "list[tuple[str, Path]]":
        """Get all migration files sorted by version.

        Returns:
            List of tuples containing (version, file_path).
        """
        return self._get_migration_files_sync()

    async def load_migration(self, file_path: Path) -> "dict[str, Any]":
        """Load a migration file and extract its components.

        Args:
            file_path: Path to the migration file.

        Returns:
            Dictionary containing migration metadata.
        """
        return self._load_migration_metadata(file_path)

    async def execute_upgrade(
        self, driver: "AsyncDriverAdapterBase", migration: "dict[str, Any]"
    ) -> "tuple[Optional[str], int]":
        """Execute an upgrade migration.

        Args:
            driver: The async database driver to use.
            migration: Migration metadata dictionary.

        Returns:
            Tuple of (sql_content, execution_time_ms). ``sql_content`` is
            always ``None`` in this implementation; the time is ``0`` when the
            migration has no upgrade SQL.
        """
        upgrade_sql = self._get_migration_sql(migration, "up")
        if upgrade_sql is None:
            return None, 0

        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed (or go negative) when the system clock is adjusted.
        start_time = time.perf_counter()
        await driver.execute(upgrade_sql)
        execution_time = int((time.perf_counter() - start_time) * 1000)
        return None, execution_time

    async def execute_downgrade(
        self, driver: "AsyncDriverAdapterBase", migration: "dict[str, Any]"
    ) -> "tuple[Optional[str], int]":
        """Execute a downgrade migration.

        Args:
            driver: The async database driver to use.
            migration: Migration metadata dictionary.

        Returns:
            Tuple of (sql_content, execution_time_ms). ``sql_content`` is
            always ``None`` in this implementation; the time is ``0`` when the
            migration has no downgrade SQL.
        """
        downgrade_sql = self._get_migration_sql(migration, "down")
        if downgrade_sql is None:
            return None, 0

        # Monotonic clock; see execute_upgrade().
        start_time = time.perf_counter()
        await driver.execute(downgrade_sql)
        execution_time = int((time.perf_counter() - start_time) * 1000)
        return None, execution_time

    async def load_all_migrations(self) -> "dict[str, SQL]":
        """Load all migrations into a single namespace for bulk operations.

        SQL files contribute every query the shared loader reports; other
        files are loaded through their migration loader and contribute
        ``migrate-<version>-up`` / ``migrate-<version>-down`` entries. A
        non-SQL migration that fails to load is skipped and logged at debug
        level.

        Returns:
            Dictionary mapping query names to SQL objects.
        """
        all_queries: "dict[str, SQL]" = {}
        migrations = await self.get_migration_files()

        for version, file_path in migrations:
            if file_path.suffix == ".sql":
                self.loader.load_sql(file_path)
                for query_name in self.loader.list_queries():
                    all_queries[query_name] = self.loader.get_sql(query_name)
            else:
                loader = get_migration_loader(file_path, self.migrations_path, self.project_root)

                try:
                    up_sql = await loader.get_up_sql(file_path)
                    down_sql = await loader.get_down_sql(file_path)

                    # NOTE(review): only the first returned statement is kept;
                    # any further statements are dropped -- confirm intended.
                    if up_sql:
                        all_queries[f"migrate-{version}-up"] = SQL(up_sql[0])
                    if down_sql:
                        all_queries[f"migrate-{version}-down"] = SQL(down_sql[0])

                except Exception as e:
                    # Best effort: one broken migration must not prevent the
                    # remaining ones from being collected.
                    logger.debug("Failed to load Python migration %s: %s", file_path, e)

        return all_queries