sqlspec 0.14.0__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +50 -25
- sqlspec/__main__.py +12 -0
- sqlspec/__metadata__.py +1 -3
- sqlspec/_serialization.py +1 -2
- sqlspec/_sql.py +256 -120
- sqlspec/_typing.py +278 -142
- sqlspec/adapters/adbc/__init__.py +4 -3
- sqlspec/adapters/adbc/_types.py +12 -0
- sqlspec/adapters/adbc/config.py +115 -248
- sqlspec/adapters/adbc/driver.py +462 -353
- sqlspec/adapters/aiosqlite/__init__.py +18 -3
- sqlspec/adapters/aiosqlite/_types.py +13 -0
- sqlspec/adapters/aiosqlite/config.py +199 -129
- sqlspec/adapters/aiosqlite/driver.py +230 -269
- sqlspec/adapters/asyncmy/__init__.py +18 -3
- sqlspec/adapters/asyncmy/_types.py +12 -0
- sqlspec/adapters/asyncmy/config.py +80 -168
- sqlspec/adapters/asyncmy/driver.py +260 -225
- sqlspec/adapters/asyncpg/__init__.py +19 -4
- sqlspec/adapters/asyncpg/_types.py +17 -0
- sqlspec/adapters/asyncpg/config.py +82 -181
- sqlspec/adapters/asyncpg/driver.py +285 -383
- sqlspec/adapters/bigquery/__init__.py +17 -3
- sqlspec/adapters/bigquery/_types.py +12 -0
- sqlspec/adapters/bigquery/config.py +191 -258
- sqlspec/adapters/bigquery/driver.py +474 -646
- sqlspec/adapters/duckdb/__init__.py +14 -3
- sqlspec/adapters/duckdb/_types.py +12 -0
- sqlspec/adapters/duckdb/config.py +415 -351
- sqlspec/adapters/duckdb/driver.py +343 -413
- sqlspec/adapters/oracledb/__init__.py +19 -5
- sqlspec/adapters/oracledb/_types.py +14 -0
- sqlspec/adapters/oracledb/config.py +123 -379
- sqlspec/adapters/oracledb/driver.py +507 -560
- sqlspec/adapters/psqlpy/__init__.py +13 -3
- sqlspec/adapters/psqlpy/_types.py +11 -0
- sqlspec/adapters/psqlpy/config.py +93 -254
- sqlspec/adapters/psqlpy/driver.py +505 -234
- sqlspec/adapters/psycopg/__init__.py +19 -5
- sqlspec/adapters/psycopg/_types.py +17 -0
- sqlspec/adapters/psycopg/config.py +143 -403
- sqlspec/adapters/psycopg/driver.py +706 -872
- sqlspec/adapters/sqlite/__init__.py +14 -3
- sqlspec/adapters/sqlite/_types.py +11 -0
- sqlspec/adapters/sqlite/config.py +202 -118
- sqlspec/adapters/sqlite/driver.py +264 -303
- sqlspec/base.py +105 -9
- sqlspec/{statement/builder → builder}/__init__.py +12 -14
- sqlspec/{statement/builder → builder}/_base.py +120 -55
- sqlspec/{statement/builder → builder}/_column.py +17 -6
- sqlspec/{statement/builder → builder}/_ddl.py +46 -79
- sqlspec/{statement/builder → builder}/_ddl_utils.py +5 -10
- sqlspec/{statement/builder → builder}/_delete.py +6 -25
- sqlspec/{statement/builder → builder}/_insert.py +6 -64
- sqlspec/builder/_merge.py +56 -0
- sqlspec/{statement/builder → builder}/_parsing_utils.py +3 -10
- sqlspec/{statement/builder → builder}/_select.py +11 -56
- sqlspec/{statement/builder → builder}/_update.py +12 -18
- sqlspec/{statement/builder → builder}/mixins/__init__.py +10 -14
- sqlspec/{statement/builder → builder}/mixins/_cte_and_set_ops.py +48 -59
- sqlspec/{statement/builder → builder}/mixins/_insert_operations.py +22 -16
- sqlspec/{statement/builder → builder}/mixins/_join_operations.py +1 -3
- sqlspec/{statement/builder → builder}/mixins/_merge_operations.py +3 -5
- sqlspec/{statement/builder → builder}/mixins/_order_limit_operations.py +3 -3
- sqlspec/{statement/builder → builder}/mixins/_pivot_operations.py +4 -8
- sqlspec/{statement/builder → builder}/mixins/_select_operations.py +21 -36
- sqlspec/{statement/builder → builder}/mixins/_update_operations.py +3 -14
- sqlspec/{statement/builder → builder}/mixins/_where_clause.py +52 -79
- sqlspec/cli.py +4 -5
- sqlspec/config.py +180 -133
- sqlspec/core/__init__.py +63 -0
- sqlspec/core/cache.py +873 -0
- sqlspec/core/compiler.py +396 -0
- sqlspec/core/filters.py +828 -0
- sqlspec/core/hashing.py +310 -0
- sqlspec/core/parameters.py +1209 -0
- sqlspec/core/result.py +664 -0
- sqlspec/{statement → core}/splitter.py +321 -191
- sqlspec/core/statement.py +651 -0
- sqlspec/driver/__init__.py +7 -10
- sqlspec/driver/_async.py +387 -176
- sqlspec/driver/_common.py +527 -289
- sqlspec/driver/_sync.py +390 -172
- sqlspec/driver/mixins/__init__.py +2 -19
- sqlspec/driver/mixins/_result_tools.py +168 -0
- sqlspec/driver/mixins/_sql_translator.py +6 -3
- sqlspec/exceptions.py +5 -252
- sqlspec/extensions/aiosql/adapter.py +93 -96
- sqlspec/extensions/litestar/config.py +0 -1
- sqlspec/extensions/litestar/handlers.py +15 -26
- sqlspec/extensions/litestar/plugin.py +16 -14
- sqlspec/extensions/litestar/providers.py +17 -52
- sqlspec/loader.py +424 -105
- sqlspec/migrations/__init__.py +12 -0
- sqlspec/migrations/base.py +92 -68
- sqlspec/migrations/commands.py +24 -106
- sqlspec/migrations/loaders.py +402 -0
- sqlspec/migrations/runner.py +49 -51
- sqlspec/migrations/tracker.py +31 -44
- sqlspec/migrations/utils.py +64 -24
- sqlspec/protocols.py +7 -183
- sqlspec/storage/__init__.py +1 -1
- sqlspec/storage/backends/base.py +37 -40
- sqlspec/storage/backends/fsspec.py +136 -112
- sqlspec/storage/backends/obstore.py +138 -160
- sqlspec/storage/capabilities.py +5 -4
- sqlspec/storage/registry.py +57 -106
- sqlspec/typing.py +136 -115
- sqlspec/utils/__init__.py +2 -3
- sqlspec/utils/correlation.py +0 -3
- sqlspec/utils/deprecation.py +6 -6
- sqlspec/utils/fixtures.py +6 -6
- sqlspec/utils/logging.py +0 -2
- sqlspec/utils/module_loader.py +7 -12
- sqlspec/utils/singleton.py +0 -1
- sqlspec/utils/sync_tools.py +16 -37
- sqlspec/utils/text.py +12 -51
- sqlspec/utils/type_guards.py +443 -232
- {sqlspec-0.14.0.dist-info → sqlspec-0.15.0.dist-info}/METADATA +7 -2
- sqlspec-0.15.0.dist-info/RECORD +134 -0
- sqlspec-0.15.0.dist-info/entry_points.txt +2 -0
- sqlspec/driver/connection.py +0 -207
- sqlspec/driver/mixins/_cache.py +0 -114
- sqlspec/driver/mixins/_csv_writer.py +0 -91
- sqlspec/driver/mixins/_pipeline.py +0 -508
- sqlspec/driver/mixins/_query_tools.py +0 -796
- sqlspec/driver/mixins/_result_utils.py +0 -138
- sqlspec/driver/mixins/_storage.py +0 -912
- sqlspec/driver/mixins/_type_coercion.py +0 -128
- sqlspec/driver/parameters.py +0 -138
- sqlspec/statement/__init__.py +0 -21
- sqlspec/statement/builder/_merge.py +0 -95
- sqlspec/statement/cache.py +0 -50
- sqlspec/statement/filters.py +0 -625
- sqlspec/statement/parameters.py +0 -996
- sqlspec/statement/pipelines/__init__.py +0 -210
- sqlspec/statement/pipelines/analyzers/__init__.py +0 -9
- sqlspec/statement/pipelines/analyzers/_analyzer.py +0 -646
- sqlspec/statement/pipelines/context.py +0 -115
- sqlspec/statement/pipelines/transformers/__init__.py +0 -7
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +0 -88
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +0 -1247
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +0 -76
- sqlspec/statement/pipelines/validators/__init__.py +0 -23
- sqlspec/statement/pipelines/validators/_dml_safety.py +0 -290
- sqlspec/statement/pipelines/validators/_parameter_style.py +0 -370
- sqlspec/statement/pipelines/validators/_performance.py +0 -714
- sqlspec/statement/pipelines/validators/_security.py +0 -967
- sqlspec/statement/result.py +0 -435
- sqlspec/statement/sql.py +0 -1774
- sqlspec/utils/cached_property.py +0 -25
- sqlspec/utils/statement_hashing.py +0 -203
- sqlspec-0.14.0.dist-info/RECORD +0 -143
- sqlspec-0.14.0.dist-info/entry_points.txt +0 -2
- /sqlspec/{statement/builder → builder}/mixins/_delete_operations.py +0 -0
- {sqlspec-0.14.0.dist-info → sqlspec-0.15.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.14.0.dist-info → sqlspec-0.15.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.14.0.dist-info → sqlspec-0.15.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,402 @@
|
|
|
1
|
+
"""Migration loader abstractions for SQLSpec."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
import inspect
|
|
5
|
+
import sys
|
|
6
|
+
import types
|
|
7
|
+
from collections.abc import Iterator
|
|
8
|
+
from contextlib import contextmanager
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any, Final, Optional
|
|
11
|
+
|
|
12
|
+
from sqlspec.loader import SQLFileLoader as CoreSQLFileLoader
|
|
13
|
+
|
|
14
|
+
__all__ = (
    "BaseMigrationLoader",
    "MigrationLoadError",
    "PythonFileLoader",
    "SQLFileLoader",
    "get_migration_loader",
)

# Names whose presence in a directory marks it as the project root when
# PythonFileLoader searches upward from the migrations directory.
PROJECT_ROOT_MARKERS: Final[list[str]] = ["pyproject.toml", ".git", "setup.cfg", "setup.py"]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class MigrationLoadError(Exception):
    """Raised when a migration file cannot be loaded, validated, or turned into SQL statements."""
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class BaseMigrationLoader(abc.ABC):
    """Interface shared by all migration-file loaders.

    A concrete loader turns one migration file into the SQL statements that
    apply it (upgrade) or revert it (downgrade), and can check a file for its
    required parts before anything is executed.
    """

    __slots__ = ()

    @abc.abstractmethod
    async def get_up_sql(self, path: Path) -> list[str]:
        """Return the upgrade SQL statements defined by a migration file.

        Args:
            path: Location of the migration file.

        Returns:
            The SQL statements to execute, in order, when upgrading.

        Raises:
            MigrationLoadError: When the file cannot be loaded.
        """
        ...

    @abc.abstractmethod
    async def get_down_sql(self, path: Path) -> list[str]:
        """Return the downgrade SQL statements defined by a migration file.

        Args:
            path: Location of the migration file.

        Returns:
            The SQL statements to execute, in order, when downgrading; an empty
            list when the migration provides no downgrade.

        Raises:
            MigrationLoadError: When the file cannot be loaded.
        """
        ...

    @abc.abstractmethod
    def validate_migration_file(self, path: Path) -> None:
        """Check that a migration file contains everything it must provide.

        Args:
            path: Location of the migration file.

        Raises:
            MigrationLoadError: When the file fails validation.
        """
        ...
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class SQLFileLoader(BaseMigrationLoader):
    """Loader for ``.sql`` migration files, backed by the core ``sqlspec.loader.SQLFileLoader``.

    A migration file whose name begins with a numeric version followed by ``_``
    (e.g. ``0001_init.sql``) must define a query named ``migrate-{version}-up``
    and may define ``migrate-{version}-down``, where ``version`` is the numeric
    prefix left-padded with zeros to at least four digits.
    """

    __slots__ = ("sql_loader",)

    def __init__(self) -> None:
        """Initialize SQL file loader."""
        self.sql_loader: CoreSQLFileLoader = CoreSQLFileLoader()

    async def get_up_sql(self, path: Path) -> list[str]:
        """Extract the 'up' SQL from a SQL migration file.

        Args:
            path: Path to SQL migration file.

        Returns:
            List containing the single SQL statement for upgrade.

        Raises:
            MigrationLoadError: If the filename has no numeric version prefix or
                the file does not define the ``migrate-{version}-up`` query.
        """
        # Validate the filename first (same rule as validate_migration_file) so an
        # unparseable name is reported as such instead of as a missing
        # "migrate--up" query.
        version = self._require_version(path)
        self._load_file(path)

        up_query = f"migrate-{version}-up"
        if not self.sql_loader.has_query(up_query):
            msg = f"Migration {path} missing 'up' query: {up_query}"
            raise MigrationLoadError(msg)

        return [self.sql_loader.get_sql(up_query).sql]

    async def get_down_sql(self, path: Path) -> list[str]:
        """Extract the 'down' SQL from a SQL migration file.

        Args:
            path: Path to SQL migration file.

        Returns:
            List containing the single SQL statement for downgrade, or an empty
            list when the file defines no ``migrate-{version}-down`` query.
        """
        self._load_file(path)

        down_query = f"migrate-{self._extract_version(path.name)}-down"
        if not self.sql_loader.has_query(down_query):
            return []

        return [self.sql_loader.get_sql(down_query).sql]

    def validate_migration_file(self, path: Path) -> None:
        """Validate SQL migration file has required up query.

        Args:
            path: Path to SQL migration file.

        Raises:
            MigrationLoadError: If the filename is invalid or the file is missing
                the required ``migrate-{version}-up`` query.
        """
        version = self._require_version(path)
        self._load_file(path)

        up_query = f"migrate-{version}-up"
        if not self.sql_loader.has_query(up_query):
            msg = f"Migration {path} missing required 'up' query: {up_query}"
            raise MigrationLoadError(msg)

    def _load_file(self, path: Path) -> None:
        """Clear the core loader's cache, then load ``path`` into it.

        NOTE(review): clearing first presumably keeps queries from previously
        loaded files from being reused; confirm against ``sqlspec.loader``.
        """
        self.sql_loader.clear_cache()
        self.sql_loader.load_sql(path)

    def _require_version(self, path: Path) -> str:
        """Return the zero-padded version of ``path`` or raise if its name has none.

        Raises:
            MigrationLoadError: If the filename has no numeric version prefix.
        """
        version = self._extract_version(path.name)
        if not version:
            msg = f"Invalid migration filename: {path.name}"
            raise MigrationLoadError(msg)
        return version

    def _extract_version(self, filename: str) -> str:
        """Extract version from filename.

        Args:
            filename: Migration filename to parse.

        Returns:
            The text before the first ``_``, zero-padded to at least four
            characters, or an empty string if that text is not all digits.
        """
        # str.split always returns at least one element, so [0] is safe.
        prefix = filename.split("_", 1)[0]
        return prefix.zfill(4) if prefix.isdigit() else ""
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class PythonFileLoader(BaseMigrationLoader):
    """Loader for Python migration files.

    A Python migration module must define a callable named ``up`` or
    ``migrate_up`` and may define ``down`` or ``migrate_down``. Each is called
    with no arguments. It may be synchronous or return an awaitable, and must
    produce a SQL string or a list/tuple of SQL strings.
    """

    __slots__ = ("migrations_dir", "project_root")

    # Candidate function names, in lookup order.
    _UP_FUNCTION_NAMES = ("up", "migrate_up")
    _DOWN_FUNCTION_NAMES = ("down", "migrate_down")

    def __init__(self, migrations_dir: Path, project_root: "Optional[Path]" = None) -> None:
        """Initialize Python file loader.

        Args:
            migrations_dir: Directory containing migration files.
            project_root: Directory placed on ``sys.path`` while a migration is
                imported and executed. When omitted it is discovered by searching
                upward from ``migrations_dir`` for a ``PROJECT_ROOT_MARKERS`` entry.
        """
        self.migrations_dir = migrations_dir
        self.project_root = project_root if project_root is not None else self._find_project_root(migrations_dir)

    async def get_up_sql(self, path: Path) -> list[str]:
        """Load Python migration and execute upgrade function.

        Args:
            path: Path to Python migration file.

        Returns:
            List of non-empty, stripped SQL statements for upgrade.

        Raises:
            MigrationLoadError: If the module cannot be loaded, defines no upgrade
                function, or the function's result is not a string or a list/tuple
                of strings. Exceptions raised by the migration function itself
                propagate unchanged.
        """
        with self._temporary_project_path():
            module = self._load_module_from_path(path)
            upgrade_func = self._find_migration_function(module, self._UP_FUNCTION_NAMES)
            if upgrade_func is None:
                msg = f"No upgrade function found in {path}. Expected 'up()' or 'migrate_up()'"
                raise MigrationLoadError(msg)
            # Call while the project root is still on sys.path so imports made
            # inside the function resolve like the module's top-level imports.
            sql_result = await self._call_migration_function(upgrade_func)

        return self._normalize_and_validate_sql(sql_result, path)

    async def get_down_sql(self, path: Path) -> list[str]:
        """Load Python migration and execute downgrade function.

        Args:
            path: Path to Python migration file.

        Returns:
            List of non-empty, stripped SQL statements for downgrade, or an empty
            list if the module defines no downgrade function.

        Raises:
            MigrationLoadError: If the module cannot be loaded or the function's
                result is not a string or a list/tuple of strings. Exceptions raised
                by the migration function itself propagate unchanged.
        """
        with self._temporary_project_path():
            module = self._load_module_from_path(path)
            downgrade_func = self._find_migration_function(module, self._DOWN_FUNCTION_NAMES)
            if downgrade_func is None:
                return []
            sql_result = await self._call_migration_function(downgrade_func)

        return self._normalize_and_validate_sql(sql_result, path)

    def validate_migration_file(self, path: Path) -> None:
        """Validate Python migration file has required upgrade function.

        Note that validation imports (executes) the module's top-level code.

        Args:
            path: Path to Python migration file.

        Raises:
            MigrationLoadError: If the module cannot be loaded or defines no
                callable ``up``/``migrate_up``.
        """
        with self._temporary_project_path():
            module = self._load_module_from_path(path)

        if self._find_migration_function(module, self._UP_FUNCTION_NAMES) is None:
            msg = f"Migration {path} missing required upgrade function. Expected 'up()' or 'migrate_up()'"
            raise MigrationLoadError(msg)

    @staticmethod
    def _find_migration_function(module: Any, names: "tuple[str, ...]") -> Any:
        """Return the first attribute of ``module`` listed in ``names`` that is callable, else None."""
        for name in names:
            candidate = getattr(module, name, None)
            if callable(candidate):
                return candidate
        return None

    @staticmethod
    async def _call_migration_function(func: Any) -> Any:
        """Call a migration function and await its result when it is awaitable.

        Inspecting the result rather than ``inspect.iscoroutinefunction(func)``
        also covers callables that return a coroutine without being declared
        ``async def`` (e.g. objects with an ``async def __call__``).
        """
        result = func()
        if inspect.isawaitable(result):
            result = await result
        return result

    def _find_project_root(self, start_path: Path) -> Path:
        """Find project root by searching upwards for marker files.

        The filesystem root itself is never checked.

        Args:
            start_path: Directory to start searching from.

        Returns:
            The first directory (starting at ``start_path``) containing any of
            ``PROJECT_ROOT_MARKERS``, otherwise the parent of ``start_path``.
        """
        resolved = start_path.resolve()
        current_path = resolved

        while current_path != current_path.parent:
            if any((current_path / marker).exists() for marker in PROJECT_ROOT_MARKERS):
                return current_path
            current_path = current_path.parent

        return resolved.parent

    @contextmanager
    def _temporary_project_path(self) -> Iterator[None]:
        """Temporarily prepend the project root to ``sys.path`` for imports.

        Nothing is added (or removed) when the path is already present.
        """
        path_to_add = str(self.project_root)
        if path_to_add in sys.path:
            yield
            return

        sys.path.insert(0, path_to_add)
        try:
            yield
        finally:
            # Migration code may have removed the entry itself; an unguarded
            # remove() would raise ValueError and mask the original exception.
            if path_to_add in sys.path:
                sys.path.remove(path_to_add)

    def _load_module_from_path(self, path: Path) -> Any:
        """Load a Python module from file path.

        The module is registered in ``sys.modules`` as
        ``sqlspec_migration_<stem>`` before its code runs, replacing any module
        previously loaded under that name.

        Args:
            path: Path to Python migration file.

        Returns:
            Loaded module object.

        Raises:
            MigrationLoadError: If reading, compiling, or executing the module fails.
        """
        module_name = f"sqlspec_migration_{path.stem}"
        sys.modules.pop(module_name, None)

        try:
            source_code = path.read_text(encoding="utf-8")
            compiled_code = compile(source_code, str(path), "exec")

            module = types.ModuleType(module_name)
            module.__file__ = str(path)

            # Registered before exec so code in the module that looks itself up
            # via sys.modules during import can find it.
            sys.modules[module_name] = module

            exec(compiled_code, module.__dict__)  # noqa: S102

        except Exception as e:
            sys.modules.pop(module_name, None)
            msg = f"Failed to execute migration module {path}: {e}"
            raise MigrationLoadError(msg) from e

        return module

    def _normalize_and_validate_sql(self, sql: Any, migration_path: Path) -> list[str]:
        """Validate return type and normalize to list of strings.

        Args:
            sql: Return value from migration function (a string, or a list/tuple
                of strings).
            migration_path: Path to migration file for error messages.

        Returns:
            List of stripped SQL statements with empty entries dropped.

        Raises:
            MigrationLoadError: If return type is invalid.
        """
        if isinstance(sql, str):
            stripped = sql.strip()
            return [stripped] if stripped else []

        if isinstance(sql, (list, tuple)):
            container_name = "list" if isinstance(sql, list) else "tuple"
            statements: list[str] = []
            for index, item in enumerate(sql):
                if not isinstance(item, str):
                    msg = (
                        f"Migration {migration_path} returned a {container_name} containing a non-string "
                        f"element at index {index} (type: {type(item).__name__})."
                    )
                    raise MigrationLoadError(msg)
                stripped_item = item.strip()
                if stripped_item:
                    statements.append(stripped_item)
            return statements

        msg = (
            f"Migration {migration_path} must return a 'str' or 'List[str]', but returned type '{type(sql).__name__}'."
        )
        raise MigrationLoadError(msg)
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def get_migration_loader(
    file_path: Path, migrations_dir: Path, project_root: "Optional[Path]" = None
) -> BaseMigrationLoader:
    """Pick a loader for a migration file based on its extension.

    Args:
        file_path: Migration file whose suffix selects the loader.
        migrations_dir: Directory holding the migrations; passed to the Python loader.
        project_root: Optional project root used for Python migration imports;
            passed to the Python loader.

    Returns:
        An ``SQLFileLoader`` for ``.sql`` files or a ``PythonFileLoader`` for ``.py`` files.

    Raises:
        MigrationLoadError: If the extension is neither ``.sql`` nor ``.py``.
    """
    extension = file_path.suffix

    if extension == ".sql":
        return SQLFileLoader()
    if extension != ".py":
        msg = f"Unsupported migration file type: {extension}"
        raise MigrationLoadError(msg)
    return PythonFileLoader(migrations_dir, project_root)
|
sqlspec/migrations/runner.py
CHANGED
|
@@ -7,20 +7,22 @@ import time
|
|
|
7
7
|
from pathlib import Path
|
|
8
8
|
from typing import TYPE_CHECKING, Any, Optional
|
|
9
9
|
|
|
10
|
+
from sqlspec.core.statement import SQL
|
|
10
11
|
from sqlspec.migrations.base import BaseMigrationRunner
|
|
12
|
+
from sqlspec.migrations.loaders import get_migration_loader
|
|
11
13
|
from sqlspec.utils.logging import get_logger
|
|
14
|
+
from sqlspec.utils.sync_tools import run_
|
|
12
15
|
|
|
13
16
|
if TYPE_CHECKING:
|
|
14
|
-
from sqlspec.driver import
|
|
15
|
-
from sqlspec.statement.sql import SQL
|
|
17
|
+
from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
|
|
16
18
|
|
|
17
19
|
__all__ = ("AsyncMigrationRunner", "SyncMigrationRunner")
|
|
18
20
|
|
|
19
21
|
logger = get_logger("migrations.runner")
|
|
20
22
|
|
|
21
23
|
|
|
22
|
-
class SyncMigrationRunner(BaseMigrationRunner["
|
|
23
|
-
"""
|
|
24
|
+
class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterBase"]):
|
|
25
|
+
"""Executes migrations using SQLFileLoader."""
|
|
24
26
|
|
|
25
27
|
def get_migration_files(self) -> "list[tuple[str, Path]]":
|
|
26
28
|
"""Get all migration files sorted by version.
|
|
@@ -42,7 +44,7 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterProtocol[Any]"])
|
|
|
42
44
|
return self._load_migration_metadata(file_path)
|
|
43
45
|
|
|
44
46
|
def execute_upgrade(
|
|
45
|
-
self, driver: "
|
|
47
|
+
self, driver: "SyncDriverAdapterBase", migration: "dict[str, Any]"
|
|
46
48
|
) -> "tuple[Optional[str], int]":
|
|
47
49
|
"""Execute an upgrade migration.
|
|
48
50
|
|
|
@@ -58,15 +60,12 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterProtocol[Any]"])
|
|
|
58
60
|
return None, 0
|
|
59
61
|
|
|
60
62
|
start_time = time.time()
|
|
61
|
-
|
|
62
|
-
# Execute migration
|
|
63
63
|
driver.execute(upgrade_sql)
|
|
64
|
-
|
|
65
64
|
execution_time = int((time.time() - start_time) * 1000)
|
|
66
65
|
return None, execution_time
|
|
67
66
|
|
|
68
67
|
def execute_downgrade(
|
|
69
|
-
self, driver: "
|
|
68
|
+
self, driver: "SyncDriverAdapterBase", migration: "dict[str, Any]"
|
|
70
69
|
) -> "tuple[Optional[str], int]":
|
|
71
70
|
"""Execute a downgrade migration.
|
|
72
71
|
|
|
@@ -82,42 +81,44 @@ class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterProtocol[Any]"])
|
|
|
82
81
|
return None, 0
|
|
83
82
|
|
|
84
83
|
start_time = time.time()
|
|
85
|
-
|
|
86
|
-
# Execute migration
|
|
87
84
|
driver.execute(downgrade_sql)
|
|
88
|
-
|
|
89
85
|
execution_time = int((time.time() - start_time) * 1000)
|
|
90
86
|
return None, execution_time
|
|
91
87
|
|
|
92
88
|
def load_all_migrations(self) -> "dict[str, SQL]":
|
|
93
89
|
"""Load all migrations into a single namespace for bulk operations.
|
|
94
90
|
|
|
95
|
-
Returns a dictionary mapping query names to SQL objects.
|
|
96
|
-
Useful for:
|
|
97
|
-
- Migration analysis tools
|
|
98
|
-
- Documentation generation
|
|
99
|
-
- Validation and linting
|
|
100
|
-
- Migration squashing
|
|
101
|
-
|
|
102
91
|
Returns:
|
|
103
92
|
Dictionary mapping query names to SQL objects.
|
|
104
93
|
"""
|
|
105
94
|
all_queries = {}
|
|
106
95
|
migrations = self.get_migration_files()
|
|
107
96
|
|
|
108
|
-
for
|
|
109
|
-
|
|
97
|
+
for version, file_path in migrations:
|
|
98
|
+
if file_path.suffix == ".sql":
|
|
99
|
+
self.loader.load_sql(file_path)
|
|
100
|
+
for query_name in self.loader.list_queries():
|
|
101
|
+
all_queries[query_name] = self.loader.get_sql(query_name)
|
|
102
|
+
else:
|
|
103
|
+
loader = get_migration_loader(file_path, self.migrations_path, self.project_root)
|
|
104
|
+
|
|
105
|
+
try:
|
|
106
|
+
up_sql = run_(loader.get_up_sql)(file_path)
|
|
107
|
+
down_sql = run_(loader.get_down_sql)(file_path)
|
|
108
|
+
|
|
109
|
+
if up_sql:
|
|
110
|
+
all_queries[f"migrate-{version}-up"] = SQL(up_sql[0])
|
|
111
|
+
if down_sql:
|
|
112
|
+
all_queries[f"migrate-{version}-down"] = SQL(down_sql[0])
|
|
110
113
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
# Store with full query name for uniqueness
|
|
114
|
-
all_queries[query_name] = self.loader.get_sql(query_name)
|
|
114
|
+
except Exception as e:
|
|
115
|
+
logger.debug("Failed to load Python migration %s: %s", file_path, e)
|
|
115
116
|
|
|
116
117
|
return all_queries
|
|
117
118
|
|
|
118
119
|
|
|
119
|
-
class AsyncMigrationRunner(BaseMigrationRunner["
|
|
120
|
-
"""
|
|
120
|
+
class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterBase"]):
|
|
121
|
+
"""Executes migrations using SQLFileLoader."""
|
|
121
122
|
|
|
122
123
|
async def get_migration_files(self) -> "list[tuple[str, Path]]":
|
|
123
124
|
"""Get all migration files sorted by version.
|
|
@@ -125,7 +126,6 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterProtocol[Any]"
|
|
|
125
126
|
Returns:
|
|
126
127
|
List of tuples containing (version, file_path).
|
|
127
128
|
"""
|
|
128
|
-
# For async, we still use the sync file operations since Path.glob is sync
|
|
129
129
|
return self._get_migration_files_sync()
|
|
130
130
|
|
|
131
131
|
async def load_migration(self, file_path: Path) -> "dict[str, Any]":
|
|
@@ -137,11 +137,10 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterProtocol[Any]"
|
|
|
137
137
|
Returns:
|
|
138
138
|
Dictionary containing migration metadata.
|
|
139
139
|
"""
|
|
140
|
-
# File loading is still sync, so we use the base implementation
|
|
141
140
|
return self._load_migration_metadata(file_path)
|
|
142
141
|
|
|
143
142
|
async def execute_upgrade(
|
|
144
|
-
self, driver: "
|
|
143
|
+
self, driver: "AsyncDriverAdapterBase", migration: "dict[str, Any]"
|
|
145
144
|
) -> "tuple[Optional[str], int]":
|
|
146
145
|
"""Execute an upgrade migration.
|
|
147
146
|
|
|
@@ -157,15 +156,12 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterProtocol[Any]"
|
|
|
157
156
|
return None, 0
|
|
158
157
|
|
|
159
158
|
start_time = time.time()
|
|
160
|
-
|
|
161
|
-
# Execute migration
|
|
162
159
|
await driver.execute(upgrade_sql)
|
|
163
|
-
|
|
164
160
|
execution_time = int((time.time() - start_time) * 1000)
|
|
165
161
|
return None, execution_time
|
|
166
162
|
|
|
167
163
|
async def execute_downgrade(
|
|
168
|
-
self, driver: "
|
|
164
|
+
self, driver: "AsyncDriverAdapterBase", migration: "dict[str, Any]"
|
|
169
165
|
) -> "tuple[Optional[str], int]":
|
|
170
166
|
"""Execute a downgrade migration.
|
|
171
167
|
|
|
@@ -181,35 +177,37 @@ class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterProtocol[Any]"
|
|
|
181
177
|
return None, 0
|
|
182
178
|
|
|
183
179
|
start_time = time.time()
|
|
184
|
-
|
|
185
|
-
# Execute migration
|
|
186
180
|
await driver.execute(downgrade_sql)
|
|
187
|
-
|
|
188
181
|
execution_time = int((time.time() - start_time) * 1000)
|
|
189
182
|
return None, execution_time
|
|
190
183
|
|
|
191
184
|
async def load_all_migrations(self) -> "dict[str, SQL]":
|
|
192
185
|
"""Load all migrations into a single namespace for bulk operations.
|
|
193
186
|
|
|
194
|
-
Returns a dictionary mapping query names to SQL objects.
|
|
195
|
-
Useful for:
|
|
196
|
-
- Migration analysis tools
|
|
197
|
-
- Documentation generation
|
|
198
|
-
- Validation and linting
|
|
199
|
-
- Migration squashing
|
|
200
|
-
|
|
201
187
|
Returns:
|
|
202
188
|
Dictionary mapping query names to SQL objects.
|
|
203
189
|
"""
|
|
204
190
|
all_queries = {}
|
|
205
191
|
migrations = await self.get_migration_files()
|
|
206
192
|
|
|
207
|
-
for
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
193
|
+
for version, file_path in migrations:
|
|
194
|
+
if file_path.suffix == ".sql":
|
|
195
|
+
self.loader.load_sql(file_path)
|
|
196
|
+
for query_name in self.loader.list_queries():
|
|
197
|
+
all_queries[query_name] = self.loader.get_sql(query_name)
|
|
198
|
+
else:
|
|
199
|
+
loader = get_migration_loader(file_path, self.migrations_path, self.project_root)
|
|
200
|
+
|
|
201
|
+
try:
|
|
202
|
+
up_sql = await loader.get_up_sql(file_path)
|
|
203
|
+
down_sql = await loader.get_down_sql(file_path)
|
|
204
|
+
|
|
205
|
+
if up_sql:
|
|
206
|
+
all_queries[f"migrate-{version}-up"] = SQL(up_sql[0])
|
|
207
|
+
if down_sql:
|
|
208
|
+
all_queries[f"migrate-{version}-down"] = SQL(down_sql[0])
|
|
209
|
+
|
|
210
|
+
except Exception as e:
|
|
211
|
+
logger.debug("Failed to load Python migration %s: %s", file_path, e)
|
|
214
212
|
|
|
215
213
|
return all_queries
|