sqlspec-0.24.1-py3-none-any.whl → sqlspec-0.26.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec has been flagged as potentially problematic; consult the registry's advisory page for details.

Files changed (95):
  1. sqlspec/_serialization.py +223 -21
  2. sqlspec/_sql.py +20 -62
  3. sqlspec/_typing.py +11 -0
  4. sqlspec/adapters/adbc/config.py +8 -1
  5. sqlspec/adapters/adbc/data_dictionary.py +290 -0
  6. sqlspec/adapters/adbc/driver.py +129 -20
  7. sqlspec/adapters/adbc/type_converter.py +159 -0
  8. sqlspec/adapters/aiosqlite/config.py +3 -0
  9. sqlspec/adapters/aiosqlite/data_dictionary.py +117 -0
  10. sqlspec/adapters/aiosqlite/driver.py +17 -3
  11. sqlspec/adapters/asyncmy/_types.py +1 -1
  12. sqlspec/adapters/asyncmy/config.py +11 -8
  13. sqlspec/adapters/asyncmy/data_dictionary.py +122 -0
  14. sqlspec/adapters/asyncmy/driver.py +31 -7
  15. sqlspec/adapters/asyncpg/config.py +3 -0
  16. sqlspec/adapters/asyncpg/data_dictionary.py +134 -0
  17. sqlspec/adapters/asyncpg/driver.py +19 -4
  18. sqlspec/adapters/bigquery/config.py +3 -0
  19. sqlspec/adapters/bigquery/data_dictionary.py +109 -0
  20. sqlspec/adapters/bigquery/driver.py +21 -3
  21. sqlspec/adapters/bigquery/type_converter.py +93 -0
  22. sqlspec/adapters/duckdb/_types.py +1 -1
  23. sqlspec/adapters/duckdb/config.py +2 -0
  24. sqlspec/adapters/duckdb/data_dictionary.py +124 -0
  25. sqlspec/adapters/duckdb/driver.py +32 -5
  26. sqlspec/adapters/duckdb/pool.py +1 -1
  27. sqlspec/adapters/duckdb/type_converter.py +103 -0
  28. sqlspec/adapters/oracledb/config.py +6 -0
  29. sqlspec/adapters/oracledb/data_dictionary.py +442 -0
  30. sqlspec/adapters/oracledb/driver.py +68 -9
  31. sqlspec/adapters/oracledb/migrations.py +51 -67
  32. sqlspec/adapters/oracledb/type_converter.py +132 -0
  33. sqlspec/adapters/psqlpy/config.py +3 -0
  34. sqlspec/adapters/psqlpy/data_dictionary.py +133 -0
  35. sqlspec/adapters/psqlpy/driver.py +23 -179
  36. sqlspec/adapters/psqlpy/type_converter.py +73 -0
  37. sqlspec/adapters/psycopg/config.py +8 -4
  38. sqlspec/adapters/psycopg/data_dictionary.py +257 -0
  39. sqlspec/adapters/psycopg/driver.py +40 -5
  40. sqlspec/adapters/sqlite/config.py +3 -0
  41. sqlspec/adapters/sqlite/data_dictionary.py +117 -0
  42. sqlspec/adapters/sqlite/driver.py +18 -3
  43. sqlspec/adapters/sqlite/pool.py +13 -4
  44. sqlspec/base.py +3 -4
  45. sqlspec/builder/_base.py +130 -48
  46. sqlspec/builder/_column.py +66 -24
  47. sqlspec/builder/_ddl.py +91 -41
  48. sqlspec/builder/_insert.py +40 -58
  49. sqlspec/builder/_parsing_utils.py +127 -12
  50. sqlspec/builder/_select.py +147 -2
  51. sqlspec/builder/_update.py +1 -1
  52. sqlspec/builder/mixins/_cte_and_set_ops.py +31 -23
  53. sqlspec/builder/mixins/_delete_operations.py +12 -7
  54. sqlspec/builder/mixins/_insert_operations.py +50 -36
  55. sqlspec/builder/mixins/_join_operations.py +15 -30
  56. sqlspec/builder/mixins/_merge_operations.py +210 -78
  57. sqlspec/builder/mixins/_order_limit_operations.py +4 -10
  58. sqlspec/builder/mixins/_pivot_operations.py +1 -0
  59. sqlspec/builder/mixins/_select_operations.py +44 -22
  60. sqlspec/builder/mixins/_update_operations.py +30 -37
  61. sqlspec/builder/mixins/_where_clause.py +52 -70
  62. sqlspec/cli.py +246 -140
  63. sqlspec/config.py +33 -19
  64. sqlspec/core/__init__.py +3 -2
  65. sqlspec/core/cache.py +298 -352
  66. sqlspec/core/compiler.py +61 -4
  67. sqlspec/core/filters.py +246 -213
  68. sqlspec/core/hashing.py +9 -11
  69. sqlspec/core/parameters.py +27 -10
  70. sqlspec/core/statement.py +72 -12
  71. sqlspec/core/type_conversion.py +234 -0
  72. sqlspec/driver/__init__.py +6 -3
  73. sqlspec/driver/_async.py +108 -5
  74. sqlspec/driver/_common.py +186 -17
  75. sqlspec/driver/_sync.py +108 -5
  76. sqlspec/driver/mixins/_result_tools.py +60 -7
  77. sqlspec/exceptions.py +5 -0
  78. sqlspec/loader.py +8 -9
  79. sqlspec/migrations/__init__.py +4 -3
  80. sqlspec/migrations/base.py +153 -14
  81. sqlspec/migrations/commands.py +34 -96
  82. sqlspec/migrations/context.py +145 -0
  83. sqlspec/migrations/loaders.py +25 -8
  84. sqlspec/migrations/runner.py +352 -82
  85. sqlspec/storage/backends/fsspec.py +1 -0
  86. sqlspec/typing.py +4 -0
  87. sqlspec/utils/config_resolver.py +153 -0
  88. sqlspec/utils/serializers.py +50 -2
  89. {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/METADATA +1 -1
  90. sqlspec-0.26.0.dist-info/RECORD +157 -0
  91. sqlspec-0.24.1.dist-info/RECORD +0 -139
  92. {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/WHEEL +0 -0
  93. {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/entry_points.txt +0 -0
  94. {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/licenses/LICENSE +0 -0
  95. {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/licenses/NOTICE +0 -0
@@ -1,28 +1,135 @@
1
1
  """Migration execution engine for SQLSpec.
2
2
 
3
- This module handles migration file loading and execution using SQLFileLoader.
3
+ This module provides separate sync and async migration runners with clean separation
4
+ of concerns and proper type safety.
4
5
  """
5
6
 
7
+ import operator
6
8
  import time
9
+ from abc import ABC, abstractmethod
7
10
  from pathlib import Path
8
- from typing import TYPE_CHECKING, Any, Optional, cast
11
+ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast, overload
9
12
 
10
13
  from sqlspec.core.statement import SQL
11
- from sqlspec.migrations.base import BaseMigrationRunner
14
+ from sqlspec.migrations.context import MigrationContext
12
15
  from sqlspec.migrations.loaders import get_migration_loader
13
16
  from sqlspec.utils.logging import get_logger
14
- from sqlspec.utils.sync_tools import await_
17
+ from sqlspec.utils.sync_tools import async_, await_
15
18
 
16
19
  if TYPE_CHECKING:
20
+ from collections.abc import Coroutine
21
+
17
22
  from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
18
23
 
19
- __all__ = ("AsyncMigrationRunner", "SyncMigrationRunner")
24
+ __all__ = ("AsyncMigrationRunner", "SyncMigrationRunner", "create_migration_runner")
20
25
 
21
26
  logger = get_logger("migrations.runner")
22
27
 
23
28
 
24
- class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
25
- """Synchronous migration executor."""
29
+ class BaseMigrationRunner(ABC):
30
+ """Base migration runner with common functionality shared between sync and async implementations."""
31
+
32
def __init__(
    self,
    migrations_path: Path,
    extension_migrations: "Optional[dict[str, Path]]" = None,
    context: "Optional[MigrationContext]" = None,
    extension_configs: "Optional[dict[str, dict[str, Any]]]" = None,
) -> None:
    """Set up the state shared by the sync and async runners.

    Args:
        migrations_path: Directory holding the project's own migration files.
        extension_migrations: Mapping of extension name to that extension's
            migration directory. ``None`` is stored as an empty dict.
        context: Migration context handed to migration loaders, if any.
        extension_configs: Mapping of extension name to its configuration dict.
            ``None`` is stored as an empty dict.
    """
    # Local import: sqlspec.loader is only imported once a runner is constructed.
    from sqlspec.loader import SQLFileLoader

    self.migrations_path = migrations_path
    self.extension_migrations = extension_migrations or {}
    self.loader = SQLFileLoader()
    # Unset until a caller assigns it; forwarded to get_migration_loader.
    self.project_root: Optional[Path] = None
    self.context = context
    self.extension_configs = extension_configs or {}
55
+
56
+ def _extract_version(self, filename: str) -> "Optional[str]":
57
+ """Extract version from filename.
58
+
59
+ Args:
60
+ filename: The migration filename.
61
+
62
+ Returns:
63
+ The extracted version string or None.
64
+ """
65
+ # Handle extension-prefixed versions (e.g., "ext_litestar_0001")
66
+ if filename.startswith("ext_"):
67
+ # This is already a prefixed version, return as-is
68
+ return filename
69
+
70
+ # Regular version extraction
71
+ parts = filename.split("_", 1)
72
+ return parts[0].zfill(4) if parts and parts[0].isdigit() else None
73
+
74
+ def _calculate_checksum(self, content: str) -> str:
75
+ """Calculate MD5 checksum of migration content.
76
+
77
+ Args:
78
+ content: The migration file content.
79
+
80
+ Returns:
81
+ MD5 checksum hex string.
82
+ """
83
+ import hashlib
84
+
85
+ return hashlib.md5(content.encode()).hexdigest() # noqa: S324
86
+
87
@abstractmethod
def load_migration(self, file_path: Path) -> Union["dict[str, Any]", "Coroutine[Any, Any, dict[str, Any]]"]:
    """Parse a single migration file into its metadata dictionary.

    Concrete runners implement this. A synchronous implementation returns the
    dictionary directly; an asynchronous implementation returns a coroutine
    resolving to it, which is why the return annotation is a ``Union``.

    Args:
        file_path: Location of the migration file to parse.

    Returns:
        Migration metadata and query information, or a coroutine producing it.
    """
98
+
99
+ def _get_migration_files_sync(self) -> "list[tuple[str, Path]]":
100
+ """Get all migration files sorted by version.
101
+
102
+ Returns:
103
+ List of tuples containing (version, file_path).
104
+ """
105
+
106
+ migrations = []
107
+
108
+ # Scan primary migration path
109
+ if self.migrations_path.exists():
110
+ for pattern in ("*.sql", "*.py"):
111
+ for file_path in self.migrations_path.glob(pattern):
112
+ if file_path.name.startswith("."):
113
+ continue
114
+ version = self._extract_version(file_path.name)
115
+ if version:
116
+ migrations.append((version, file_path))
117
+
118
+ # Scan extension migration paths
119
+ for ext_name, ext_path in self.extension_migrations.items():
120
+ if ext_path.exists():
121
+ for pattern in ("*.sql", "*.py"):
122
+ for file_path in ext_path.glob(pattern):
123
+ if file_path.name.startswith("."):
124
+ continue
125
+ # Prefix extension migrations to avoid version conflicts
126
+ version = self._extract_version(file_path.name)
127
+ if version:
128
+ # Use ext_ prefix to distinguish extension migrations
129
+ prefixed_version = f"ext_{ext_name}_{version}"
130
+ migrations.append((prefixed_version, file_path))
131
+
132
+ return sorted(migrations, key=operator.itemgetter(0))
26
133
 
27
134
  def get_migration_files(self) -> "list[tuple[str, Path]]":
28
135
  """Get all migration files sorted by version.
@@ -32,6 +139,72 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
32
139
  """
33
140
  return self._get_migration_files_sync()
34
141
 
142
+ def _load_migration_metadata_common(self, file_path: Path) -> "dict[str, Any]":
143
+ """Load common migration metadata that doesn't require async operations.
144
+
145
+ Args:
146
+ file_path: Path to the migration file.
147
+
148
+ Returns:
149
+ Partial migration metadata dictionary.
150
+ """
151
+ content = file_path.read_text(encoding="utf-8")
152
+ checksum = self._calculate_checksum(content)
153
+ version = self._extract_version(file_path.name)
154
+ description = file_path.stem.split("_", 1)[1] if "_" in file_path.stem else ""
155
+
156
+ return {
157
+ "version": version,
158
+ "description": description,
159
+ "file_path": file_path,
160
+ "checksum": checksum,
161
+ "content": content,
162
+ }
163
+
164
+ def _get_context_for_migration(self, file_path: Path) -> "Optional[MigrationContext]":
165
+ """Get the appropriate context for a migration file.
166
+
167
+ Args:
168
+ file_path: Path to the migration file.
169
+
170
+ Returns:
171
+ Migration context to use, or None to use default.
172
+ """
173
+ context_to_use = self.context
174
+ if context_to_use and file_path.name.startswith("ext_"):
175
+ version = self._extract_version(file_path.name)
176
+ if version and version.startswith("ext_"):
177
+ min_extension_version_parts = 3
178
+ parts = version.split("_", 2)
179
+ if len(parts) >= min_extension_version_parts:
180
+ ext_name = parts[1]
181
+ if ext_name in self.extension_configs:
182
+ context_to_use = MigrationContext(
183
+ dialect=self.context.dialect if self.context else None,
184
+ config=self.context.config if self.context else None,
185
+ driver=self.context.driver if self.context else None,
186
+ metadata=self.context.metadata.copy() if self.context and self.context.metadata else {},
187
+ extension_config=self.extension_configs[ext_name],
188
+ )
189
+
190
+ for ext_name, ext_path in self.extension_migrations.items():
191
+ if file_path.parent == ext_path:
192
+ if ext_name in self.extension_configs and self.context:
193
+ context_to_use = MigrationContext(
194
+ config=self.context.config,
195
+ dialect=self.context.dialect,
196
+ driver=self.context.driver,
197
+ metadata=self.context.metadata.copy() if self.context.metadata else {},
198
+ extension_config=self.extension_configs[ext_name],
199
+ )
200
+ break
201
+
202
+ return context_to_use
203
+
204
+
205
+ class SyncMigrationRunner(BaseMigrationRunner):
206
+ """Synchronous migration runner with pure sync methods."""
207
+
35
208
  def load_migration(self, file_path: Path) -> "dict[str, Any]":
36
209
  """Load a migration file and extract its components.
37
210
 
@@ -41,7 +214,29 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
41
214
  Returns:
42
215
  Dictionary containing migration metadata and queries.
43
216
  """
44
- return self._load_migration_metadata(file_path)
217
+ # Get common metadata
218
+ metadata = self._load_migration_metadata_common(file_path)
219
+ context_to_use = self._get_context_for_migration(file_path)
220
+
221
+ loader = get_migration_loader(file_path, self.migrations_path, self.project_root, context_to_use)
222
+ loader.validate_migration_file(file_path)
223
+
224
+ has_upgrade, has_downgrade = True, False
225
+
226
+ if file_path.suffix == ".sql":
227
+ version = metadata["version"]
228
+ up_query, down_query = f"migrate-{version}-up", f"migrate-{version}-down"
229
+ self.loader.clear_cache()
230
+ self.loader.load_sql(file_path)
231
+ has_upgrade, has_downgrade = self.loader.has_query(up_query), self.loader.has_query(down_query)
232
+ else:
233
+ try:
234
+ has_downgrade = bool(self._get_migration_sql_sync({"loader": loader, "file_path": file_path}, "down"))
235
+ except Exception:
236
+ has_downgrade = False
237
+
238
+ metadata.update({"has_upgrade": has_upgrade, "has_downgrade": has_downgrade, "loader": loader})
239
+ return metadata
45
240
 
46
241
  def execute_upgrade(
47
242
  self, driver: "SyncDriverAdapterBase", migration: "dict[str, Any]"
@@ -49,13 +244,13 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
49
244
  """Execute an upgrade migration.
50
245
 
51
246
  Args:
52
- driver: The database driver to use.
247
+ driver: The sync database driver to use.
53
248
  migration: Migration metadata dictionary.
54
249
 
55
250
  Returns:
56
251
  Tuple of (sql_content, execution_time_ms).
57
252
  """
58
- upgrade_sql_list = self._get_migration_sql(migration, "up")
253
+ upgrade_sql_list = self._get_migration_sql_sync(migration, "up")
59
254
  if upgrade_sql_list is None:
60
255
  return None, 0
61
256
 
@@ -73,13 +268,13 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
73
268
  """Execute a downgrade migration.
74
269
 
75
270
  Args:
76
- driver: The database driver to use.
271
+ driver: The sync database driver to use.
77
272
  migration: Migration metadata dictionary.
78
273
 
79
274
  Returns:
80
275
  Tuple of (sql_content, execution_time_ms).
81
276
  """
82
- downgrade_sql_list = self._get_migration_sql(migration, "down")
277
+ downgrade_sql_list = self._get_migration_sql_sync(migration, "down")
83
278
  if downgrade_sql_list is None:
84
279
  return None, 0
85
280
 
@@ -91,6 +286,50 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
91
286
  execution_time = int((time.time() - start_time) * 1000)
92
287
  return None, execution_time
93
288
 
289
+ def _get_migration_sql_sync(self, migration: "dict[str, Any]", direction: str) -> "Optional[list[str]]":
290
+ """Get migration SQL for given direction (sync version).
291
+
292
+ Args:
293
+ migration: Migration metadata.
294
+ direction: Either 'up' or 'down'.
295
+
296
+ Returns:
297
+ SQL statements for the migration.
298
+ """
299
+ # If this is being called during migration loading (no has_*grade field yet),
300
+ # don't raise/warn - just proceed to check if the method exists
301
+ if f"has_{direction}grade" in migration and not migration.get(f"has_{direction}grade"):
302
+ if direction == "down":
303
+ logger.warning("Migration %s has no downgrade query", migration.get("version"))
304
+ return None
305
+ msg = f"Migration {migration.get('version')} has no upgrade query"
306
+ raise ValueError(msg)
307
+
308
+ file_path, loader = migration["file_path"], migration["loader"]
309
+
310
+ try:
311
+ method = loader.get_up_sql if direction == "up" else loader.get_down_sql
312
+ # Check if the method is async and handle appropriately
313
+ import inspect
314
+
315
+ if inspect.iscoroutinefunction(method):
316
+ # For async methods, use await_ to run in sync context
317
+ sql_statements = await_(method, raise_sync_error=False)(file_path)
318
+ else:
319
+ # For sync methods, call directly
320
+ sql_statements = method(file_path)
321
+
322
+ except Exception as e:
323
+ if direction == "down":
324
+ logger.warning("Failed to load downgrade for migration %s: %s", migration.get("version"), e)
325
+ return None
326
+ msg = f"Failed to load upgrade for migration {migration.get('version')}: {e}"
327
+ raise ValueError(msg) from e
328
+ else:
329
+ if sql_statements:
330
+ return cast("list[str]", sql_statements)
331
+ return None
332
+
94
333
  def load_all_migrations(self) -> "dict[str, SQL]":
95
334
  """Load all migrations into a single namespace for bulk operations.
96
335
 
@@ -106,11 +345,11 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
106
345
  for query_name in self.loader.list_queries():
107
346
  all_queries[query_name] = self.loader.get_sql(query_name)
108
347
  else:
109
- loader = get_migration_loader(file_path, self.migrations_path, self.project_root)
348
+ loader = get_migration_loader(file_path, self.migrations_path, self.project_root, self.context)
110
349
 
111
350
  try:
112
- up_sql = await_(loader.get_up_sql, raise_sync_error=False)(file_path)
113
- down_sql = await_(loader.get_down_sql, raise_sync_error=False)(file_path)
351
+ up_sql = await_(loader.get_up_sql)(file_path)
352
+ down_sql = await_(loader.get_down_sql)(file_path)
114
353
 
115
354
  if up_sql:
116
355
  all_queries[f"migrate-{version}-up"] = SQL(up_sql[0])
@@ -123,14 +362,14 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
123
362
  return all_queries
124
363
 
125
364
 
126
- class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
127
- """Asynchronous migration executor."""
365
+ class AsyncMigrationRunner(BaseMigrationRunner):
366
+ """Asynchronous migration runner with pure async methods."""
128
367
 
129
- async def get_migration_files(self) -> "list[tuple[str, Path]]":
368
+ async def get_migration_files(self) -> "list[tuple[str, Path]]": # type: ignore[override]
130
369
  """Get all migration files sorted by version.
131
370
 
132
371
  Returns:
133
- List of tuples containing (version, file_path).
372
+ List of (version, path) tuples sorted by version.
134
373
  """
135
374
  return self._get_migration_files_sync()
136
375
 
@@ -141,82 +380,33 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
141
380
  file_path: Path to the migration file.
142
381
 
143
382
  Returns:
144
- Dictionary containing migration metadata.
383
+ Dictionary containing migration metadata and queries.
145
384
  """
146
- return await self._load_migration_metadata_async(file_path)
147
-
148
- async def _load_migration_metadata_async(self, file_path: Path) -> "dict[str, Any]":
149
- """Load migration metadata from file (async version).
150
-
151
- Args:
152
- file_path: Path to the migration file.
385
+ # Get common metadata
386
+ metadata = self._load_migration_metadata_common(file_path)
387
+ context_to_use = self._get_context_for_migration(file_path)
153
388
 
154
- Returns:
155
- Migration metadata dictionary.
156
- """
157
- loader = get_migration_loader(file_path, self.migrations_path, self.project_root)
389
+ loader = get_migration_loader(file_path, self.migrations_path, self.project_root, context_to_use)
158
390
  loader.validate_migration_file(file_path)
159
- content = file_path.read_text(encoding="utf-8")
160
- checksum = self._calculate_checksum(content)
161
- version = self._extract_version(file_path.name)
162
- description = file_path.stem.split("_", 1)[1] if "_" in file_path.stem else ""
163
391
 
164
392
  has_upgrade, has_downgrade = True, False
165
393
 
166
394
  if file_path.suffix == ".sql":
395
+ version = metadata["version"]
167
396
  up_query, down_query = f"migrate-{version}-up", f"migrate-{version}-down"
168
397
  self.loader.clear_cache()
169
- self.loader.load_sql(file_path)
398
+ await async_(self.loader.load_sql)(file_path)
170
399
  has_upgrade, has_downgrade = self.loader.has_query(up_query), self.loader.has_query(down_query)
171
400
  else:
172
401
  try:
173
- has_downgrade = bool(await loader.get_down_sql(file_path))
402
+ has_downgrade = bool(
403
+ await self._get_migration_sql_async({"loader": loader, "file_path": file_path}, "down")
404
+ )
174
405
  except Exception:
175
406
  has_downgrade = False
176
407
 
177
- return {
178
- "version": version,
179
- "description": description,
180
- "file_path": file_path,
181
- "checksum": checksum,
182
- "has_upgrade": has_upgrade,
183
- "has_downgrade": has_downgrade,
184
- "loader": loader,
185
- }
186
-
187
- async def _get_migration_sql_async(self, migration: "dict[str, Any]", direction: str) -> "Optional[list[str]]":
188
- """Get migration SQL for given direction (async version).
189
-
190
- Args:
191
- migration: Migration metadata.
192
- direction: Either 'up' or 'down'.
193
-
194
- Returns:
195
- SQL statements for the migration.
196
- """
197
- if not migration.get(f"has_{direction}grade"):
198
- if direction == "down":
199
- logger.warning("Migration %s has no downgrade query", migration["version"])
200
- return None
201
- msg = f"Migration {migration['version']} has no upgrade query"
202
- raise ValueError(msg)
203
-
204
- file_path, loader = migration["file_path"], migration["loader"]
205
-
206
- try:
207
- method = loader.get_up_sql if direction == "up" else loader.get_down_sql
208
- sql_statements = await method(file_path)
209
-
210
- except Exception as e:
211
- if direction == "down":
212
- logger.warning("Failed to load downgrade for migration %s: %s", migration["version"], e)
213
- return None
214
- msg = f"Failed to load upgrade for migration {migration['version']}: {e}"
215
- raise ValueError(msg) from e
216
- else:
217
- if sql_statements:
218
- return cast("list[str]", sql_statements)
219
- return None
408
+ metadata.update({"has_upgrade": has_upgrade, "has_downgrade": has_downgrade, "loader": loader})
409
+ return metadata
220
410
 
221
411
  async def execute_upgrade(
222
412
  self, driver: "AsyncDriverAdapterBase", migration: "dict[str, Any]"
@@ -266,6 +456,42 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
266
456
  execution_time = int((time.time() - start_time) * 1000)
267
457
  return None, execution_time
268
458
 
459
+ async def _get_migration_sql_async(self, migration: "dict[str, Any]", direction: str) -> "Optional[list[str]]":
460
+ """Get migration SQL for given direction (async version).
461
+
462
+ Args:
463
+ migration: Migration metadata.
464
+ direction: Either 'up' or 'down'.
465
+
466
+ Returns:
467
+ SQL statements for the migration.
468
+ """
469
+ # If this is being called during migration loading (no has_*grade field yet),
470
+ # don't raise/warn - just proceed to check if the method exists
471
+ if f"has_{direction}grade" in migration and not migration.get(f"has_{direction}grade"):
472
+ if direction == "down":
473
+ logger.warning("Migration %s has no downgrade query", migration.get("version"))
474
+ return None
475
+ msg = f"Migration {migration.get('version')} has no upgrade query"
476
+ raise ValueError(msg)
477
+
478
+ file_path, loader = migration["file_path"], migration["loader"]
479
+
480
+ try:
481
+ method = loader.get_up_sql if direction == "up" else loader.get_down_sql
482
+ sql_statements = await method(file_path)
483
+
484
+ except Exception as e:
485
+ if direction == "down":
486
+ logger.warning("Failed to load downgrade for migration %s: %s", migration.get("version"), e)
487
+ return None
488
+ msg = f"Failed to load upgrade for migration {migration.get('version')}: {e}"
489
+ raise ValueError(msg) from e
490
+ else:
491
+ if sql_statements:
492
+ return cast("list[str]", sql_statements)
493
+ return None
494
+
269
495
  async def load_all_migrations(self) -> "dict[str, SQL]":
270
496
  """Load all migrations into a single namespace for bulk operations.
271
497
 
@@ -277,11 +503,11 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
277
503
 
278
504
  for version, file_path in migrations:
279
505
  if file_path.suffix == ".sql":
280
- self.loader.load_sql(file_path)
506
+ await async_(self.loader.load_sql)(file_path)
281
507
  for query_name in self.loader.list_queries():
282
508
  all_queries[query_name] = self.loader.get_sql(query_name)
283
509
  else:
284
- loader = get_migration_loader(file_path, self.migrations_path, self.project_root)
510
+ loader = get_migration_loader(file_path, self.migrations_path, self.project_root, self.context)
285
511
 
286
512
  try:
287
513
  up_sql = await loader.get_up_sql(file_path)
@@ -296,3 +522,47 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
296
522
  logger.debug("Failed to load Python migration %s: %s", file_path, e)
297
523
 
298
524
  return all_queries
525
+
526
+
527
@overload
def create_migration_runner(
    migrations_path: Path,
    extension_migrations: "dict[str, Path]",
    context: "Optional[MigrationContext]",
    extension_configs: "dict[str, Any]",
    is_async: "Literal[False]" = False,
) -> SyncMigrationRunner: ...


@overload
def create_migration_runner(
    migrations_path: Path,
    extension_migrations: "dict[str, Path]",
    context: "Optional[MigrationContext]",
    extension_configs: "dict[str, Any]",
    is_async: "Literal[True]",
) -> AsyncMigrationRunner: ...


def create_migration_runner(
    migrations_path: Path,
    extension_migrations: "dict[str, Path]",
    context: "Optional[MigrationContext]",
    extension_configs: "dict[str, Any]",
    is_async: bool = False,
) -> "Union[SyncMigrationRunner, AsyncMigrationRunner]":
    """Instantiate the runner class that matches the caller's execution model.

    Args:
        migrations_path: Directory containing the project's migrations.
        extension_migrations: Mapping of extension name to migration directory.
        context: Migration context forwarded to the runner.
        extension_configs: Mapping of extension name to its configuration.
        is_async: True selects ``AsyncMigrationRunner``; otherwise
            ``SyncMigrationRunner`` is used.

    Returns:
        A new runner built from the given arguments, passed positionally.
    """
    runner_class = AsyncMigrationRunner if is_async else SyncMigrationRunner
    return runner_class(migrations_path, extension_migrations, context, extension_configs)
@@ -1,3 +1,4 @@
1
+ # pyright: reportPrivateUsage=false
1
2
  import logging
2
3
  from pathlib import Path
3
4
  from typing import TYPE_CHECKING, Any, Optional, Union
sqlspec/typing.py CHANGED
@@ -12,8 +12,10 @@ from sqlspec._typing import (
12
12
  FSSPEC_INSTALLED,
13
13
  LITESTAR_INSTALLED,
14
14
  MSGSPEC_INSTALLED,
15
+ NUMPY_INSTALLED,
15
16
  OBSTORE_INSTALLED,
16
17
  OPENTELEMETRY_INSTALLED,
18
+ ORJSON_INSTALLED,
17
19
  PGVECTOR_INSTALLED,
18
20
  PROMETHEUS_INSTALLED,
19
21
  PYARROW_INSTALLED,
@@ -187,8 +189,10 @@ __all__ = (
187
189
  "FSSPEC_INSTALLED",
188
190
  "LITESTAR_INSTALLED",
189
191
  "MSGSPEC_INSTALLED",
192
+ "NUMPY_INSTALLED",
190
193
  "OBSTORE_INSTALLED",
191
194
  "OPENTELEMETRY_INSTALLED",
195
+ "ORJSON_INSTALLED",
192
196
  "PGVECTOR_INSTALLED",
193
197
  "PROMETHEUS_INSTALLED",
194
198
  "PYARROW_INSTALLED",