sqlspec 0.24.1__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic; see the registry's advisory for more details.

Files changed (95):
  1. sqlspec/_serialization.py +223 -21
  2. sqlspec/_sql.py +20 -62
  3. sqlspec/_typing.py +11 -0
  4. sqlspec/adapters/adbc/config.py +8 -1
  5. sqlspec/adapters/adbc/data_dictionary.py +290 -0
  6. sqlspec/adapters/adbc/driver.py +129 -20
  7. sqlspec/adapters/adbc/type_converter.py +159 -0
  8. sqlspec/adapters/aiosqlite/config.py +3 -0
  9. sqlspec/adapters/aiosqlite/data_dictionary.py +117 -0
  10. sqlspec/adapters/aiosqlite/driver.py +17 -3
  11. sqlspec/adapters/asyncmy/_types.py +1 -1
  12. sqlspec/adapters/asyncmy/config.py +11 -8
  13. sqlspec/adapters/asyncmy/data_dictionary.py +122 -0
  14. sqlspec/adapters/asyncmy/driver.py +31 -7
  15. sqlspec/adapters/asyncpg/config.py +3 -0
  16. sqlspec/adapters/asyncpg/data_dictionary.py +134 -0
  17. sqlspec/adapters/asyncpg/driver.py +19 -4
  18. sqlspec/adapters/bigquery/config.py +3 -0
  19. sqlspec/adapters/bigquery/data_dictionary.py +109 -0
  20. sqlspec/adapters/bigquery/driver.py +21 -3
  21. sqlspec/adapters/bigquery/type_converter.py +93 -0
  22. sqlspec/adapters/duckdb/_types.py +1 -1
  23. sqlspec/adapters/duckdb/config.py +2 -0
  24. sqlspec/adapters/duckdb/data_dictionary.py +124 -0
  25. sqlspec/adapters/duckdb/driver.py +32 -5
  26. sqlspec/adapters/duckdb/pool.py +1 -1
  27. sqlspec/adapters/duckdb/type_converter.py +103 -0
  28. sqlspec/adapters/oracledb/config.py +6 -0
  29. sqlspec/adapters/oracledb/data_dictionary.py +442 -0
  30. sqlspec/adapters/oracledb/driver.py +68 -9
  31. sqlspec/adapters/oracledb/migrations.py +51 -67
  32. sqlspec/adapters/oracledb/type_converter.py +132 -0
  33. sqlspec/adapters/psqlpy/config.py +3 -0
  34. sqlspec/adapters/psqlpy/data_dictionary.py +133 -0
  35. sqlspec/adapters/psqlpy/driver.py +23 -179
  36. sqlspec/adapters/psqlpy/type_converter.py +73 -0
  37. sqlspec/adapters/psycopg/config.py +8 -4
  38. sqlspec/adapters/psycopg/data_dictionary.py +257 -0
  39. sqlspec/adapters/psycopg/driver.py +40 -5
  40. sqlspec/adapters/sqlite/config.py +3 -0
  41. sqlspec/adapters/sqlite/data_dictionary.py +117 -0
  42. sqlspec/adapters/sqlite/driver.py +18 -3
  43. sqlspec/adapters/sqlite/pool.py +13 -4
  44. sqlspec/base.py +3 -4
  45. sqlspec/builder/_base.py +130 -48
  46. sqlspec/builder/_column.py +66 -24
  47. sqlspec/builder/_ddl.py +91 -41
  48. sqlspec/builder/_insert.py +40 -58
  49. sqlspec/builder/_parsing_utils.py +127 -12
  50. sqlspec/builder/_select.py +147 -2
  51. sqlspec/builder/_update.py +1 -1
  52. sqlspec/builder/mixins/_cte_and_set_ops.py +31 -23
  53. sqlspec/builder/mixins/_delete_operations.py +12 -7
  54. sqlspec/builder/mixins/_insert_operations.py +50 -36
  55. sqlspec/builder/mixins/_join_operations.py +15 -30
  56. sqlspec/builder/mixins/_merge_operations.py +210 -78
  57. sqlspec/builder/mixins/_order_limit_operations.py +4 -10
  58. sqlspec/builder/mixins/_pivot_operations.py +1 -0
  59. sqlspec/builder/mixins/_select_operations.py +44 -22
  60. sqlspec/builder/mixins/_update_operations.py +30 -37
  61. sqlspec/builder/mixins/_where_clause.py +52 -70
  62. sqlspec/cli.py +246 -140
  63. sqlspec/config.py +33 -19
  64. sqlspec/core/__init__.py +3 -2
  65. sqlspec/core/cache.py +298 -352
  66. sqlspec/core/compiler.py +61 -4
  67. sqlspec/core/filters.py +246 -213
  68. sqlspec/core/hashing.py +9 -11
  69. sqlspec/core/parameters.py +27 -10
  70. sqlspec/core/statement.py +72 -12
  71. sqlspec/core/type_conversion.py +234 -0
  72. sqlspec/driver/__init__.py +6 -3
  73. sqlspec/driver/_async.py +108 -5
  74. sqlspec/driver/_common.py +186 -17
  75. sqlspec/driver/_sync.py +108 -5
  76. sqlspec/driver/mixins/_result_tools.py +60 -7
  77. sqlspec/exceptions.py +5 -0
  78. sqlspec/loader.py +8 -9
  79. sqlspec/migrations/__init__.py +4 -3
  80. sqlspec/migrations/base.py +153 -14
  81. sqlspec/migrations/commands.py +34 -96
  82. sqlspec/migrations/context.py +145 -0
  83. sqlspec/migrations/loaders.py +25 -8
  84. sqlspec/migrations/runner.py +352 -82
  85. sqlspec/storage/backends/fsspec.py +1 -0
  86. sqlspec/typing.py +4 -0
  87. sqlspec/utils/config_resolver.py +153 -0
  88. sqlspec/utils/serializers.py +50 -2
  89. {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/METADATA +1 -1
  90. sqlspec-0.26.0.dist-info/RECORD +157 -0
  91. sqlspec-0.24.1.dist-info/RECORD +0 -139
  92. {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/WHEEL +0 -0
  93. {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/entry_points.txt +0 -0
  94. {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/licenses/LICENSE +0 -0
  95. {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/licenses/NOTICE +0 -0
@@ -15,6 +15,7 @@ from sqlspec.builder._ddl import CreateTable
15
15
  from sqlspec.loader import SQLFileLoader
16
16
  from sqlspec.migrations.loaders import get_migration_loader
17
17
  from sqlspec.utils.logging import get_logger
18
+ from sqlspec.utils.module_loader import module_to_os_path
18
19
  from sqlspec.utils.sync_tools import await_
19
20
 
20
21
  __all__ = ("BaseMigrationCommands", "BaseMigrationRunner", "BaseMigrationTracker")
@@ -135,15 +136,29 @@ class BaseMigrationTracker(ABC, Generic[DriverT]):
135
136
  class BaseMigrationRunner(ABC, Generic[DriverT]):
136
137
  """Base class for migration execution."""
137
138
 
138
- def __init__(self, migrations_path: Path) -> None:
139
+ extension_configs: "dict[str, dict[str, Any]]"
140
+
141
def __init__(
    self,
    migrations_path: Path,
    extension_migrations: "Optional[dict[str, Path]]" = None,
    context: "Optional[Any]" = None,
    extension_configs: "Optional[dict[str, dict[str, Any]]]" = None,
) -> None:
    """Set up the runner for a migrations directory.

    Args:
        migrations_path: Directory holding the project's migration files.
        extension_migrations: Mapping of extension name to that extension's
            migration directory; a missing or empty value becomes ``{}``.
        context: Migration context handed to Python migrations, if any.
        extension_configs: Mapping of extension name to its configuration
            options; a missing or empty value becomes ``{}``.
    """
    self.migrations_path = migrations_path
    self.project_root: Optional[Path] = None
    self.context = context
    # Falsy inputs (None or an empty mapping) are normalised to a fresh dict.
    self.extension_migrations = extension_migrations if extension_migrations else {}
    self.extension_configs = extension_configs if extension_configs else {}
    self.loader = SQLFileLoader()
147
162
 
148
163
  def _extract_version(self, filename: str) -> Optional[str]:
149
164
  """Extract version from filename.
@@ -154,6 +169,12 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
154
169
  Returns:
155
170
  The extracted version string or None.
156
171
  """
172
+ # Handle extension-prefixed versions (e.g., "ext_litestar_0001")
173
+ if filename.startswith("ext_"):
174
+ # This is already a prefixed version, return as-is
175
+ return filename
176
+
177
+ # Regular version extraction
157
178
  parts = filename.split("_", 1)
158
179
  return parts[0].zfill(4) if parts and parts[0].isdigit() else None
159
180
 
@@ -175,17 +196,31 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
175
196
  Returns:
176
197
  List of tuples containing (version, file_path).
177
198
  """
178
- if not self.migrations_path.exists():
179
- return []
180
-
181
199
  migrations = []
182
- for pattern in ["*.sql", "*.py"]:
183
- for file_path in self.migrations_path.glob(pattern):
184
- if file_path.name.startswith("."):
185
- continue
186
- version = self._extract_version(file_path.name)
187
- if version:
188
- migrations.append((version, file_path))
200
+
201
+ # Scan primary migration path
202
+ if self.migrations_path.exists():
203
+ for pattern in ("*.sql", "*.py"):
204
+ for file_path in self.migrations_path.glob(pattern):
205
+ if file_path.name.startswith("."):
206
+ continue
207
+ version = self._extract_version(file_path.name)
208
+ if version:
209
+ migrations.append((version, file_path))
210
+
211
+ # Scan extension migration paths
212
+ for ext_name, ext_path in self.extension_migrations.items():
213
+ if ext_path.exists():
214
+ for pattern in ("*.sql", "*.py"):
215
+ for file_path in ext_path.glob(pattern):
216
+ if file_path.name.startswith("."):
217
+ continue
218
+ # Prefix extension migrations to avoid version conflicts
219
+ version = self._extract_version(file_path.name)
220
+ if version:
221
+ # Use ext_ prefix to distinguish extension migrations
222
+ prefixed_version = f"ext_{ext_name}_{version}"
223
+ migrations.append((prefixed_version, file_path))
189
224
 
190
225
  return sorted(migrations, key=operator.itemgetter(0))
191
226
 
@@ -199,7 +234,45 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
199
234
  Migration metadata dictionary.
200
235
  """
201
236
 
202
- loader = get_migration_loader(file_path, self.migrations_path, self.project_root)
237
+ # Check if this is an extension migration and update context accordingly
238
+ context_to_use = self.context
239
+ if context_to_use and file_path.name.startswith("ext_"):
240
+ # Try to extract extension name from the version
241
+ version = self._extract_version(file_path.name)
242
+ if version and version.startswith("ext_"):
243
+ # Parse extension name from version like "ext_litestar_0001"
244
+ min_extension_version_parts = 3
245
+ parts = version.split("_", 2)
246
+ if len(parts) >= min_extension_version_parts:
247
+ ext_name = parts[1]
248
+ if ext_name in self.extension_configs:
249
+ # Create a new context with the extension config
250
+ from sqlspec.migrations.context import MigrationContext
251
+
252
+ context_to_use = MigrationContext(
253
+ dialect=self.context.dialect if self.context else None,
254
+ config=self.context.config if self.context else None,
255
+ driver=self.context.driver if self.context else None,
256
+ metadata=self.context.metadata.copy() if self.context and self.context.metadata else {},
257
+ extension_config=self.extension_configs[ext_name],
258
+ )
259
+
260
+ # For extension migrations, check by path
261
+ for ext_name, ext_path in self.extension_migrations.items():
262
+ if file_path.parent == ext_path:
263
+ if ext_name in self.extension_configs and self.context:
264
+ from sqlspec.migrations.context import MigrationContext
265
+
266
+ context_to_use = MigrationContext(
267
+ dialect=self.context.dialect,
268
+ config=self.context.config,
269
+ driver=self.context.driver,
270
+ metadata=self.context.metadata.copy() if self.context.metadata else {},
271
+ extension_config=self.extension_configs[ext_name],
272
+ )
273
+ break
274
+
275
+ loader = get_migration_loader(file_path, self.migrations_path, self.project_root, context_to_use)
203
276
  loader.validate_migration_file(file_path)
204
277
  content = file_path.read_text(encoding="utf-8")
205
278
  checksum = self._calculate_checksum(content)
@@ -292,6 +365,8 @@ class BaseMigrationRunner(ABC, Generic[DriverT]):
292
365
  class BaseMigrationCommands(ABC, Generic[ConfigT, DriverT]):
293
366
  """Base class for migration commands."""
294
367
 
368
+ extension_configs: "dict[str, dict[str, Any]]"
369
+
295
370
  def __init__(self, config: ConfigT) -> None:
296
371
  """Initialize migration commands.
297
372
 
@@ -304,6 +379,72 @@ class BaseMigrationCommands(ABC, Generic[ConfigT, DriverT]):
304
379
  self.version_table = migration_config.get("version_table_name", "ddl_migrations")
305
380
  self.migrations_path = Path(migration_config.get("script_location", "migrations"))
306
381
  self.project_root = Path(migration_config["project_root"]) if "project_root" in migration_config else None
382
+ self.include_extensions = migration_config.get("include_extensions", [])
383
+ self.extension_configs = self._parse_extension_configs()
384
+
385
+ def _parse_extension_configs(self) -> "dict[str, dict[str, Any]]":
386
+ """Parse extension configurations from include_extensions.
387
+
388
+ Supports both string format (extension name) and dict format
389
+ (extension name with configuration).
390
+
391
+ Returns:
392
+ Dictionary mapping extension names to their configurations.
393
+ """
394
+ configs = {}
395
+
396
+ for ext_config in self.include_extensions:
397
+ if isinstance(ext_config, str):
398
+ # Simple string format: just the extension name
399
+ ext_name = ext_config
400
+ ext_options = {}
401
+ elif isinstance(ext_config, dict):
402
+ # Dict format: {"name": "litestar", "session_table": "custom_sessions"}
403
+ ext_name_raw = ext_config.get("name")
404
+ if not ext_name_raw:
405
+ logger.warning("Extension configuration missing 'name' field: %s", ext_config)
406
+ continue
407
+ # Assert for type narrowing: ext_name_raw is guaranteed to be str here
408
+ assert isinstance(ext_name_raw, str)
409
+ ext_name = ext_name_raw
410
+ ext_options = {k: v for k, v in ext_config.items() if k != "name"}
411
+ else:
412
+ logger.warning("Invalid extension configuration format: %s", ext_config)
413
+ continue
414
+
415
+ # Apply default configurations for known extensions
416
+ if ext_name == "litestar" and "session_table" not in ext_options:
417
+ ext_options["session_table"] = "litestar_sessions"
418
+
419
+ configs[ext_name] = ext_options
420
+
421
+ return configs
422
+
423
+ def _discover_extension_migrations(self) -> "dict[str, Path]":
424
+ """Discover migration paths for configured extensions.
425
+
426
+ Returns:
427
+ Dictionary mapping extension names to their migration paths.
428
+ """
429
+
430
+ extension_migrations = {}
431
+
432
+ for ext_name in self.extension_configs:
433
+ module_name = "sqlspec.extensions.litestar" if ext_name == "litestar" else f"sqlspec.extensions.{ext_name}"
434
+
435
+ try:
436
+ module_path = module_to_os_path(module_name)
437
+ migrations_dir = module_path / "migrations"
438
+
439
+ if migrations_dir.exists():
440
+ extension_migrations[ext_name] = migrations_dir
441
+ logger.debug("Found migrations for extension %s at %s", ext_name, migrations_dir)
442
+ else:
443
+ logger.warning("No migrations directory found for extension %s", ext_name)
444
+ except TypeError:
445
+ logger.warning("Extension %s not found", ext_name)
446
+
447
+ return extension_migrations
307
448
 
308
449
  def _get_init_readme_content(self) -> str:
309
450
  """Get README content for migration directory initialization.
@@ -368,8 +509,6 @@ This naming ensures proper sorting and avoids conflicts when loading multiple fi
368
509
  readme = migrations_dir / "README.md"
369
510
  readme.write_text(self._get_init_readme_content())
370
511
 
371
- (migrations_dir / ".gitkeep").touch()
372
-
373
512
  console.print(f"[green]Initialized migrations in {directory}[/]")
374
513
 
375
514
  @abstractmethod
@@ -10,15 +10,15 @@ from rich.table import Table
10
10
 
11
11
  from sqlspec._sql import sql
12
12
  from sqlspec.migrations.base import BaseMigrationCommands
13
+ from sqlspec.migrations.context import MigrationContext
13
14
  from sqlspec.migrations.runner import AsyncMigrationRunner, SyncMigrationRunner
14
15
  from sqlspec.migrations.utils import create_migration_file
15
16
  from sqlspec.utils.logging import get_logger
16
- from sqlspec.utils.sync_tools import await_
17
17
 
18
18
  if TYPE_CHECKING:
19
19
  from sqlspec.config import AsyncConfigT, SyncConfigT
20
20
 
21
- __all__ = ("AsyncMigrationCommands", "MigrationCommands", "SyncMigrationCommands")
21
+ __all__ = ("AsyncMigrationCommands", "SyncMigrationCommands", "create_migration_commands")
22
22
 
23
23
  logger = get_logger("migrations.commands")
24
24
  console = Console()
@@ -35,7 +35,14 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
35
35
  """
36
36
  super().__init__(config)
37
37
  self.tracker = config.migration_tracker_type(self.version_table)
38
- self.runner = SyncMigrationRunner(self.migrations_path)
38
+
39
+ # Create context with extension configurations
40
+ context = MigrationContext.from_config(config)
41
+ context.extension_config = self.extension_configs
42
+
43
+ self.runner = SyncMigrationRunner(
44
+ self.migrations_path, self._discover_extension_migrations(), context, self.extension_configs
45
+ )
39
46
 
40
47
  def init(self, directory: str, package: bool = True) -> None:
41
48
  """Initialize migration directory structure.
@@ -203,15 +210,22 @@ class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
203
210
  class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
204
211
  """Asynchronous migration commands."""
205
212
 
206
- def __init__(self, sqlspec_config: "AsyncConfigT") -> None:
213
+ def __init__(self, config: "AsyncConfigT") -> None:
207
214
  """Initialize migration commands.
208
215
 
209
216
  Args:
210
- sqlspec_config: The SQLSpec configuration.
217
+ config: The SQLSpec configuration.
211
218
  """
212
- super().__init__(sqlspec_config)
213
- self.tracker = sqlspec_config.migration_tracker_type(self.version_table)
214
- self.runner = AsyncMigrationRunner(self.migrations_path)
219
+ super().__init__(config)
220
+ self.tracker = config.migration_tracker_type(self.version_table)
221
+
222
+ # Create context with extension configurations
223
+ context = MigrationContext.from_config(config)
224
+ context.extension_config = self.extension_configs
225
+
226
+ self.runner = AsyncMigrationRunner(
227
+ self.migrations_path, self._discover_extension_migrations(), context, self.extension_configs
228
+ )
215
229
 
216
230
  async def init(self, directory: str, package: bool = True) -> None:
217
231
  """Initialize migration directory structure.
@@ -370,93 +384,17 @@ class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
370
384
  console.print(f"[green]Created migration:[/] {file_path}")
371
385
 
372
386
 
373
- class MigrationCommands:
374
- """Unified migration commands that adapt to sync/async configs."""
375
-
376
- def __init__(self, config: "Union[SyncConfigT, AsyncConfigT]") -> None:
377
- """Initialize migration commands with sync/async implementation.
378
-
379
- Args:
380
- config: The SQLSpec configuration.
381
- """
382
- if config.is_async:
383
- self._impl: Union[AsyncMigrationCommands[Any], SyncMigrationCommands[Any]] = AsyncMigrationCommands(
384
- cast("AsyncConfigT", config)
385
- )
386
- else:
387
- self._impl = SyncMigrationCommands(cast("SyncConfigT", config))
388
- self._is_async = config.is_async
389
-
390
- def init(self, directory: str, package: bool = True) -> None:
391
- """Initialize migration directory structure.
387
+ def create_migration_commands(
388
+ config: "Union[SyncConfigT, AsyncConfigT]",
389
+ ) -> "Union[SyncMigrationCommands[Any], AsyncMigrationCommands[Any]]":
390
+ """Factory function to create the appropriate migration commands.
392
391
 
393
- Args:
394
- directory: Directory to initialize migrations in.
395
- package: Whether to create __init__.py file.
396
- """
397
- if self._is_async:
398
- await_(cast("AsyncMigrationCommands[Any]", self._impl).init, raise_sync_error=False)(
399
- directory, package=package
400
- )
401
- else:
402
- cast("SyncMigrationCommands[Any]", self._impl).init(directory, package=package)
403
-
404
- def current(self, verbose: bool = False) -> "Optional[str]":
405
- """Show current migration version.
406
-
407
- Args:
408
- verbose: Whether to show detailed migration history.
409
-
410
- Returns:
411
- The current migration version or None if no migrations applied.
412
- """
413
- if self._is_async:
414
- return await_(cast("AsyncMigrationCommands[Any]", self._impl).current, raise_sync_error=False)(
415
- verbose=verbose
416
- )
417
- return cast("SyncMigrationCommands[Any]", self._impl).current(verbose=verbose)
392
+ Args:
393
+ config: The SQLSpec configuration.
418
394
 
419
- def upgrade(self, revision: str = "head") -> None:
420
- """Upgrade to a target revision.
421
-
422
- Args:
423
- revision: Target revision or "head" for latest.
424
- """
425
- if self._is_async:
426
- await_(cast("AsyncMigrationCommands[Any]", self._impl).upgrade, raise_sync_error=False)(revision=revision)
427
- else:
428
- cast("SyncMigrationCommands[Any]", self._impl).upgrade(revision=revision)
429
-
430
- def downgrade(self, revision: str = "-1") -> None:
431
- """Downgrade to a target revision.
432
-
433
- Args:
434
- revision: Target revision or "-1" for one step back.
435
- """
436
- if self._is_async:
437
- await_(cast("AsyncMigrationCommands[Any]", self._impl).downgrade, raise_sync_error=False)(revision=revision)
438
- else:
439
- cast("SyncMigrationCommands[Any]", self._impl).downgrade(revision=revision)
440
-
441
- def stamp(self, revision: str) -> None:
442
- """Mark database as being at a specific revision without running migrations.
443
-
444
- Args:
445
- revision: The revision to stamp.
446
- """
447
- if self._is_async:
448
- await_(cast("AsyncMigrationCommands[Any]", self._impl).stamp, raise_sync_error=False)(revision)
449
- else:
450
- cast("SyncMigrationCommands[Any]", self._impl).stamp(revision)
451
-
452
- def revision(self, message: str, file_type: str = "sql") -> None:
453
- """Create a new migration file.
454
-
455
- Args:
456
- message: Description for the migration.
457
- file_type: Type of migration file to create ('sql' or 'py').
458
- """
459
- if self._is_async:
460
- await_(cast("AsyncMigrationCommands[Any]", self._impl).revision, raise_sync_error=False)(message, file_type)
461
- else:
462
- cast("SyncMigrationCommands[Any]", self._impl).revision(message, file_type)
395
+ Returns:
396
+ Appropriate migration commands instance.
397
+ """
398
+ if config.is_async:
399
+ return AsyncMigrationCommands(cast("AsyncConfigT", config))
400
+ return SyncMigrationCommands(cast("SyncConfigT", config))
@@ -0,0 +1,145 @@
1
+ """Migration context for passing runtime information to migrations."""
2
+
3
+ import asyncio
4
+ import inspect
5
+ from dataclasses import dataclass, field
6
+ from typing import TYPE_CHECKING, Any, Optional, Union
7
+
8
+ from sqlspec.utils.logging import get_logger
9
+
10
+ if TYPE_CHECKING:
11
+ from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
12
+
13
+ logger = get_logger("migrations.context")
14
+
15
+ __all__ = ("MigrationContext",)
16
+
17
+
18
@dataclass
class MigrationContext:
    """Context object passed to migration functions.

    Provides runtime information about the database environment to migration
    functions, allowing them to generate dialect-specific SQL.
    """

    config: "Optional[Any]" = None
    """Database configuration object."""
    dialect: "Optional[str]" = None
    """Database dialect (e.g., 'postgres', 'mysql', 'sqlite')."""
    metadata: "Optional[dict[str, Any]]" = None
    """Additional metadata for the migration."""
    # NOTE(review): the migration command classes assign the full
    # {extension name: options} mapping to this field, while the per-extension
    # contexts built by the runner hold a single extension's options; confirm
    # which shape migration functions should expect.
    extension_config: "Optional[dict[str, Any]]" = None
    """Extension-specific configuration options."""

    driver: "Optional[Union[SyncDriverAdapterBase, AsyncDriverAdapterBase]]" = None
    """Database driver instance (available during execution)."""

    _execution_metadata: "dict[str, Any]" = field(default_factory=dict)
    """Internal execution metadata for tracking async operations."""

    def __post_init__(self) -> None:
        """Replace missing metadata / extension config with fresh empty dicts.

        Only ``None`` is replaced. A caller-supplied dict, even an empty one,
        is kept as the same object so the caller can keep sharing it.
        """
        if self.metadata is None:
            self.metadata = {}
        if self.extension_config is None:
            self.extension_config = {}

    @classmethod
    def from_config(cls, config: Any) -> "MigrationContext":
        """Create a context from a database configuration.

        The dialect is read from ``config.statement_config`` when that is set.
        Otherwise it is read from the result of
        ``config._create_statement_config()`` when such a callable exists.
        Any error while probing is logged at debug level and leaves the
        dialect as ``None``.

        Args:
            config: Database configuration object.

        Returns:
            Migration context carrying ``config`` and the detected dialect.
        """
        dialect = None
        try:
            if hasattr(config, "statement_config") and config.statement_config:
                dialect = getattr(config.statement_config, "dialect", None)
            elif hasattr(config, "_create_statement_config") and callable(config._create_statement_config):
                stmt_config = config._create_statement_config()
                dialect = getattr(stmt_config, "dialect", None)
        except Exception:
            # Deliberate best effort: a config we cannot introspect yields no dialect.
            logger.debug("Unable to extract dialect from config")

        return cls(dialect=dialect, config=config)

    @property
    def is_async_execution(self) -> bool:
        """Check whether an asyncio event loop is running in this thread.

        Returns:
            True if executing in an async context.
        """
        try:
            asyncio.current_task()
        except RuntimeError:
            # current_task() raises when there is no running event loop.
            return False
        return True

    @property
    def is_async_driver(self) -> bool:
        """Check if the current driver is async.

        Returns:
            True if the driver's ``execute_script`` is a coroutine function;
            False when there is no driver or no such method.
        """
        if self.driver is None:
            return False

        execute_method = getattr(self.driver, "execute_script", None)
        return execute_method is not None and inspect.iscoroutinefunction(execute_method)

    @property
    def execution_mode(self) -> str:
        """Get the current execution mode.

        Returns:
            'async' if in async context, 'sync' otherwise.
        """
        return "async" if self.is_async_execution else "sync"

    def set_execution_metadata(self, key: str, value: Any) -> None:
        """Set execution metadata for tracking migration state.

        Args:
            key: Metadata key.
            value: Metadata value.
        """
        self._execution_metadata[key] = value

    def get_execution_metadata(self, key: str, default: Any = None) -> Any:
        """Get execution metadata.

        Args:
            key: Metadata key.
            default: Default value if key not found.

        Returns:
            Metadata value or default.
        """
        return self._execution_metadata.get(key, default)

    def validate_async_usage(self, migration_func: Any) -> None:
        """Check how a migration function's sync/async nature fits this context.

        Behaviour depends on the function and the context:

        - A coroutine function used while no event loop is running and the
          driver is not async: a warning is logged. Nothing is raised.
        - A plain function used with an async driver: ``mixed_execution=True``
          is recorded in the execution metadata.

        Args:
            migration_func: The migration function to validate.
        """
        is_coroutine = inspect.iscoroutinefunction(migration_func)

        if is_coroutine and not self.is_async_execution and not self.is_async_driver:
            msg = (
                "Async migration function detected but execution context is sync. "
                "Consider using async database configuration or sync migration functions."
            )
            logger.warning(msg)

        if not is_coroutine and self.is_async_driver:
            self.set_execution_metadata("mixed_execution", value=True)
            logger.debug("Sync migration function in async driver context - using compatibility mode")
@@ -164,17 +164,21 @@ class SQLFileLoader(BaseMigrationLoader):
164
164
  class PythonFileLoader(BaseMigrationLoader):
165
165
  """Loader for Python migration files."""
166
166
 
167
- __slots__ = ("migrations_dir", "project_root")
167
+ __slots__ = ("context", "migrations_dir", "project_root")
168
168
 
169
- def __init__(self, migrations_dir: Path, project_root: "Optional[Path]" = None) -> None:
169
def __init__(
    self,
    migrations_dir: Path,
    project_root: "Optional[Path]" = None,
    context: "Optional[Any]" = None,
) -> None:
    """Set up a loader for Python migration scripts.

    Args:
        migrations_dir: Directory that holds the migration scripts.
        project_root: Project root used for imports; when omitted it is
            derived from ``migrations_dir`` via ``_find_project_root``.
        context: Optional migration context to pass to functions.
    """
    self.migrations_dir = migrations_dir
    if project_root is None:
        project_root = self._find_project_root(migrations_dir)
    self.project_root = project_root
    self.context = context
178
182
 
179
183
  async def get_up_sql(self, path: Path) -> list[str]:
180
184
  """Load Python migration and execute upgrade function.
@@ -208,10 +212,16 @@ class PythonFileLoader(BaseMigrationLoader):
208
212
  msg = f"'{func_name}' is not callable in {path}"
209
213
  raise MigrationLoadError(msg)
210
214
 
215
+ # Check if function accepts context parameter
216
+ sig = inspect.signature(upgrade_func)
217
+ accepts_context = "context" in sig.parameters or len(sig.parameters) > 0
218
+
211
219
  if inspect.iscoroutinefunction(upgrade_func):
212
- sql_result = await upgrade_func()
220
+ sql_result = (
221
+ await upgrade_func(self.context) if accepts_context and self.context else await upgrade_func()
222
+ )
213
223
  else:
214
- sql_result = upgrade_func()
224
+ sql_result = upgrade_func(self.context) if accepts_context and self.context else upgrade_func()
215
225
 
216
226
  return self._normalize_and_validate_sql(sql_result, path)
217
227
 
@@ -239,10 +249,16 @@ class PythonFileLoader(BaseMigrationLoader):
239
249
  if not callable(downgrade_func):
240
250
  return []
241
251
 
252
+ # Check if function accepts context parameter
253
+ sig = inspect.signature(downgrade_func)
254
+ accepts_context = "context" in sig.parameters or len(sig.parameters) > 0
255
+
242
256
  if inspect.iscoroutinefunction(downgrade_func):
243
- sql_result = await downgrade_func()
257
+ sql_result = (
258
+ await downgrade_func(self.context) if accepts_context and self.context else await downgrade_func()
259
+ )
244
260
  else:
245
- sql_result = downgrade_func()
261
+ sql_result = downgrade_func(self.context) if accepts_context and self.context else downgrade_func()
246
262
 
247
263
  return self._normalize_and_validate_sql(sql_result, path)
248
264
 
@@ -380,7 +396,7 @@ class PythonFileLoader(BaseMigrationLoader):
380
396
 
381
397
 
382
398
  def get_migration_loader(
383
- file_path: Path, migrations_dir: Path, project_root: "Optional[Path]" = None
399
+ file_path: Path, migrations_dir: Path, project_root: "Optional[Path]" = None, context: "Optional[Any]" = None
384
400
  ) -> BaseMigrationLoader:
385
401
  """Factory function to get appropriate loader for migration file.
386
402
 
@@ -388,6 +404,7 @@ def get_migration_loader(
388
404
  file_path: Path to the migration file.
389
405
  migrations_dir: Directory containing migration files.
390
406
  project_root: Optional project root directory for Python imports.
407
+ context: Optional migration context to pass to Python migrations.
391
408
 
392
409
  Returns:
393
410
  Appropriate loader instance for the file type.
@@ -398,7 +415,7 @@ def get_migration_loader(
398
415
  suffix = file_path.suffix
399
416
 
400
417
  if suffix == ".py":
401
- return PythonFileLoader(migrations_dir, project_root)
418
+ return PythonFileLoader(migrations_dir, project_root, context)
402
419
  if suffix == ".sql":
403
420
  return SQLFileLoader()
404
421
  msg = f"Unsupported migration file type: {suffix}"