sqlspec 0.18.0__py3-none-any.whl → 0.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic; see the package's registry page for more details.
- sqlspec/adapters/adbc/driver.py +192 -28
- sqlspec/adapters/asyncmy/driver.py +72 -15
- sqlspec/adapters/asyncpg/config.py +23 -3
- sqlspec/adapters/asyncpg/driver.py +30 -14
- sqlspec/adapters/bigquery/driver.py +79 -9
- sqlspec/adapters/duckdb/driver.py +39 -56
- sqlspec/adapters/oracledb/driver.py +99 -52
- sqlspec/adapters/psqlpy/driver.py +89 -31
- sqlspec/adapters/psycopg/driver.py +11 -23
- sqlspec/adapters/sqlite/driver.py +77 -8
- sqlspec/base.py +29 -25
- sqlspec/builder/__init__.py +1 -1
- sqlspec/builder/_base.py +4 -5
- sqlspec/builder/_column.py +3 -3
- sqlspec/builder/_ddl.py +5 -1
- sqlspec/builder/_delete.py +5 -6
- sqlspec/builder/_insert.py +6 -7
- sqlspec/builder/_merge.py +5 -5
- sqlspec/builder/_parsing_utils.py +3 -3
- sqlspec/builder/_select.py +6 -5
- sqlspec/builder/_update.py +4 -5
- sqlspec/builder/mixins/_cte_and_set_ops.py +5 -1
- sqlspec/builder/mixins/_delete_operations.py +5 -1
- sqlspec/builder/mixins/_insert_operations.py +5 -1
- sqlspec/builder/mixins/_join_operations.py +5 -0
- sqlspec/builder/mixins/_merge_operations.py +5 -1
- sqlspec/builder/mixins/_order_limit_operations.py +5 -1
- sqlspec/builder/mixins/_pivot_operations.py +4 -1
- sqlspec/builder/mixins/_select_operations.py +5 -1
- sqlspec/builder/mixins/_update_operations.py +5 -1
- sqlspec/builder/mixins/_where_clause.py +5 -1
- sqlspec/cli.py +281 -33
- sqlspec/config.py +160 -10
- sqlspec/core/compiler.py +11 -3
- sqlspec/core/filters.py +30 -9
- sqlspec/core/parameters.py +67 -67
- sqlspec/core/result.py +62 -31
- sqlspec/core/splitter.py +160 -34
- sqlspec/core/statement.py +95 -14
- sqlspec/driver/_common.py +12 -3
- sqlspec/driver/mixins/_result_tools.py +21 -4
- sqlspec/driver/mixins/_sql_translator.py +45 -7
- sqlspec/extensions/aiosql/adapter.py +1 -1
- sqlspec/extensions/litestar/_utils.py +1 -1
- sqlspec/extensions/litestar/handlers.py +21 -0
- sqlspec/extensions/litestar/plugin.py +15 -8
- sqlspec/loader.py +12 -12
- sqlspec/migrations/loaders.py +5 -2
- sqlspec/migrations/utils.py +2 -2
- sqlspec/storage/backends/obstore.py +1 -3
- sqlspec/storage/registry.py +1 -1
- sqlspec/utils/__init__.py +7 -0
- sqlspec/utils/deprecation.py +6 -0
- sqlspec/utils/fixtures.py +239 -30
- sqlspec/utils/module_loader.py +5 -1
- sqlspec/utils/serializers.py +6 -0
- sqlspec/utils/singleton.py +6 -0
- sqlspec/utils/sync_tools.py +10 -1
- {sqlspec-0.18.0.dist-info → sqlspec-0.20.0.dist-info}/METADATA +1 -1
- {sqlspec-0.18.0.dist-info → sqlspec-0.20.0.dist-info}/RECORD +64 -64
- {sqlspec-0.18.0.dist-info → sqlspec-0.20.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.18.0.dist-info → sqlspec-0.20.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.18.0.dist-info → sqlspec-0.20.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.18.0.dist-info → sqlspec-0.20.0.dist-info}/licenses/NOTICE +0 -0
|
sqlspec/driver/mixins/_sql_translator.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""SQL translation mixin for cross-database compatibility."""
|
|
2
|
+
|
|
1
3
|
from typing import Final, NoReturn, Optional
|
|
2
4
|
|
|
3
5
|
from mypy_extensions import trait
|
|
@@ -33,8 +35,7 @@ class SQLTranslatorMixin:
|
|
|
33
35
|
Returns:
|
|
34
36
|
SQL string in target dialect
|
|
35
37
|
|
|
36
|
-
|
|
37
|
-
SQLConversionError: If parsing or conversion fails
|
|
38
|
+
|
|
38
39
|
"""
|
|
39
40
|
|
|
40
41
|
parsed_expression: Optional[exp.Expression] = None
|
|
@@ -53,7 +54,15 @@ class SQLTranslatorMixin:
|
|
|
53
54
|
return self._generate_sql_safely(parsed_expression, target_dialect, pretty)
|
|
54
55
|
|
|
55
56
|
def _parse_statement_safely(self, statement: "Statement") -> "exp.Expression":
|
|
56
|
-
"""Parse statement with
|
|
57
|
+
"""Parse statement with error handling.
|
|
58
|
+
|
|
59
|
+
Args:
|
|
60
|
+
statement: SQL statement to parse
|
|
61
|
+
|
|
62
|
+
Returns:
|
|
63
|
+
Parsed expression
|
|
64
|
+
|
|
65
|
+
"""
|
|
57
66
|
try:
|
|
58
67
|
sql_string = str(statement)
|
|
59
68
|
|
|
@@ -62,23 +71,52 @@ class SQLTranslatorMixin:
|
|
|
62
71
|
self._raise_parse_error(e)
|
|
63
72
|
|
|
64
73
|
def _generate_sql_safely(self, expression: "exp.Expression", dialect: DialectType, pretty: bool) -> str:
|
|
65
|
-
"""Generate SQL with
|
|
74
|
+
"""Generate SQL with error handling.
|
|
75
|
+
|
|
76
|
+
Args:
|
|
77
|
+
expression: Parsed expression to convert
|
|
78
|
+
dialect: Target SQL dialect
|
|
79
|
+
pretty: Whether to format the output SQL
|
|
80
|
+
|
|
81
|
+
Returns:
|
|
82
|
+
Generated SQL string
|
|
83
|
+
|
|
84
|
+
"""
|
|
66
85
|
try:
|
|
67
86
|
return expression.sql(dialect=dialect, pretty=pretty)
|
|
68
87
|
except Exception as e:
|
|
69
88
|
self._raise_conversion_error(dialect, e)
|
|
70
89
|
|
|
71
90
|
def _raise_statement_parse_error(self) -> NoReturn:
|
|
72
|
-
"""Raise error for unparsable statements.
|
|
91
|
+
"""Raise error for unparsable statements.
|
|
92
|
+
|
|
93
|
+
Raises:
|
|
94
|
+
SQLConversionError: Always raised
|
|
95
|
+
"""
|
|
73
96
|
msg = "Statement could not be parsed"
|
|
74
97
|
raise SQLConversionError(msg)
|
|
75
98
|
|
|
76
99
|
def _raise_parse_error(self, e: Exception) -> NoReturn:
|
|
77
|
-
"""Raise error for parsing failures.
|
|
100
|
+
"""Raise error for parsing failures.
|
|
101
|
+
|
|
102
|
+
Args:
|
|
103
|
+
e: Original exception that caused the failure
|
|
104
|
+
|
|
105
|
+
Raises:
|
|
106
|
+
SQLConversionError: Always raised
|
|
107
|
+
"""
|
|
78
108
|
error_msg = f"Failed to parse SQL statement: {e!s}"
|
|
79
109
|
raise SQLConversionError(error_msg) from e
|
|
80
110
|
|
|
81
111
|
def _raise_conversion_error(self, dialect: DialectType, e: Exception) -> NoReturn:
|
|
82
|
-
"""Raise error for conversion failures.
|
|
112
|
+
"""Raise error for conversion failures.
|
|
113
|
+
|
|
114
|
+
Args:
|
|
115
|
+
dialect: Target dialect that caused the failure
|
|
116
|
+
e: Original exception that caused the failure
|
|
117
|
+
|
|
118
|
+
Raises:
|
|
119
|
+
SQLConversionError: Always raised
|
|
120
|
+
"""
|
|
83
121
|
error_msg = f"Failed to convert SQL expression to {dialect}: {e!s}"
|
|
84
122
|
raise SQLConversionError(error_msg) from e
|
|
sqlspec/extensions/aiosql/adapter.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
This module provides adapter classes that implement the aiosql adapter protocols
|
|
4
4
|
while using SQLSpec drivers under the hood. This enables users to load SQL queries
|
|
5
|
-
from files using aiosql while
|
|
5
|
+
from files using aiosql while using SQLSpec's features for execution and type mapping.
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
import logging
|
|
sqlspec/extensions/litestar/_utils.py
CHANGED
|
@@ -12,7 +12,7 @@ def get_sqlspec_scope_state(scope: "Scope", key: str, default: Any = None, pop:
|
|
|
12
12
|
"""Get an internal value from connection scope state.
|
|
13
13
|
|
|
14
14
|
Note:
|
|
15
|
-
If called with a default value, this method behaves like
|
|
15
|
+
If called with a default value, this method behaves like `dict.setdefault()`, both setting the key in the
|
|
16
16
|
namespace to the default value, and returning it.
|
|
17
17
|
|
|
18
18
|
If called without a default value, the method behaves like `dict.get()`, returning ``None`` if the key does not
|
|
sqlspec/extensions/litestar/handlers.py
CHANGED
|
@@ -199,6 +199,17 @@ def pool_provider_maker(
|
|
|
199
199
|
def connection_provider_maker(
|
|
200
200
|
config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]", pool_key: str, connection_key: str
|
|
201
201
|
) -> "Callable[[State, Scope], AsyncGenerator[ConnectionT, None]]":
|
|
202
|
+
"""Create provider for database connections with proper lifecycle management.
|
|
203
|
+
|
|
204
|
+
Args:
|
|
205
|
+
config: The database configuration object.
|
|
206
|
+
pool_key: The key used to retrieve the connection pool from `app.state`.
|
|
207
|
+
connection_key: The key used to store the connection in the ASGI scope.
|
|
208
|
+
|
|
209
|
+
Returns:
|
|
210
|
+
The connection provider function.
|
|
211
|
+
"""
|
|
212
|
+
|
|
202
213
|
async def provide_connection(state: "State", scope: "Scope") -> "AsyncGenerator[ConnectionT, None]":
|
|
203
214
|
if (db_pool := state.get(pool_key)) is None:
|
|
204
215
|
msg = f"Database pool with key '{pool_key}' not found. Cannot create a connection."
|
|
@@ -230,6 +241,16 @@ def connection_provider_maker(
|
|
|
230
241
|
def session_provider_maker(
|
|
231
242
|
config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]", connection_dependency_key: str
|
|
232
243
|
) -> "Callable[[Any], AsyncGenerator[DriverT, None]]":
|
|
244
|
+
"""Create provider for database driver sessions.
|
|
245
|
+
|
|
246
|
+
Args:
|
|
247
|
+
config: The database configuration object.
|
|
248
|
+
connection_dependency_key: The key used for connection dependency injection.
|
|
249
|
+
|
|
250
|
+
Returns:
|
|
251
|
+
The session provider function.
|
|
252
|
+
"""
|
|
253
|
+
|
|
233
254
|
async def provide_session(*args: Any, **kwargs: Any) -> "AsyncGenerator[DriverT, None]":
|
|
234
255
|
yield cast("DriverT", config.driver_type(connection=args[0] if args else kwargs.get(connection_dependency_key))) # pyright: ignore
|
|
235
256
|
|
|
sqlspec/extensions/litestar/plugin.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import TYPE_CHECKING,
|
|
1
|
+
from typing import TYPE_CHECKING, Optional, Union
|
|
2
2
|
|
|
3
3
|
from litestar.di import Provide
|
|
4
4
|
from litestar.plugins import CLIPlugin, InitPluginProtocol
|
|
@@ -14,23 +14,31 @@ if TYPE_CHECKING:
|
|
|
14
14
|
from click import Group
|
|
15
15
|
from litestar.config.app import AppConfig
|
|
16
16
|
|
|
17
|
+
from sqlspec.loader import SQLFileLoader
|
|
18
|
+
|
|
17
19
|
logger = get_logger("extensions.litestar")
|
|
18
20
|
|
|
19
21
|
|
|
20
|
-
class SQLSpec(InitPluginProtocol, CLIPlugin, SQLSpecBase):
|
|
22
|
+
class SQLSpec(SQLSpecBase, InitPluginProtocol, CLIPlugin):
|
|
21
23
|
"""Litestar plugin for SQLSpec database integration."""
|
|
22
24
|
|
|
23
|
-
__slots__ = ("
|
|
25
|
+
__slots__ = ("_plugin_configs",)
|
|
24
26
|
|
|
25
|
-
def __init__(
|
|
27
|
+
def __init__(
|
|
28
|
+
self,
|
|
29
|
+
config: Union["SyncConfigT", "AsyncConfigT", "DatabaseConfig", list["DatabaseConfig"]],
|
|
30
|
+
*,
|
|
31
|
+
loader: "Optional[SQLFileLoader]" = None,
|
|
32
|
+
) -> None:
|
|
26
33
|
"""Initialize SQLSpec plugin.
|
|
27
34
|
|
|
28
35
|
Args:
|
|
29
36
|
config: Database configuration for SQLSpec plugin.
|
|
37
|
+
loader: Optional SQL file loader instance.
|
|
30
38
|
"""
|
|
31
|
-
|
|
39
|
+
super().__init__(loader=loader)
|
|
32
40
|
if isinstance(config, DatabaseConfigProtocol):
|
|
33
|
-
self._plugin_configs: list[DatabaseConfig] = [DatabaseConfig(config=config)]
|
|
41
|
+
self._plugin_configs: list[DatabaseConfig] = [DatabaseConfig(config=config)] # pyright: ignore
|
|
34
42
|
elif isinstance(config, DatabaseConfig):
|
|
35
43
|
self._plugin_configs = [config]
|
|
36
44
|
else:
|
|
@@ -83,8 +91,7 @@ class SQLSpec(InitPluginProtocol, CLIPlugin, SQLSpecBase):
|
|
|
83
91
|
app_config.signature_types.append(c.config.connection_type) # type: ignore[union-attr]
|
|
84
92
|
app_config.signature_types.append(c.config.driver_type) # type: ignore[union-attr]
|
|
85
93
|
|
|
86
|
-
|
|
87
|
-
signature_namespace.update(c.config.get_signature_namespace()) # type: ignore[attr-defined]
|
|
94
|
+
signature_namespace.update(c.config.get_signature_namespace()) # type: ignore[union-attr]
|
|
88
95
|
|
|
89
96
|
app_config.before_send.append(c.before_send_handler)
|
|
90
97
|
app_config.lifespan.append(c.lifespan_handler) # pyright: ignore[reportUnknownMemberType]
|
sqlspec/loader.py
CHANGED
|
@@ -55,10 +55,10 @@ def _normalize_query_name(name: str) -> str:
|
|
|
55
55
|
"""Normalize query name to be a valid Python identifier.
|
|
56
56
|
|
|
57
57
|
Args:
|
|
58
|
-
name: Raw query name from SQL file
|
|
58
|
+
name: Raw query name from SQL file.
|
|
59
59
|
|
|
60
60
|
Returns:
|
|
61
|
-
Normalized query name suitable as Python identifier
|
|
61
|
+
Normalized query name suitable as Python identifier.
|
|
62
62
|
"""
|
|
63
63
|
return TRIM_SPECIAL_CHARS.sub("", name).replace("-", "_")
|
|
64
64
|
|
|
@@ -67,10 +67,10 @@ def _normalize_dialect(dialect: str) -> str:
|
|
|
67
67
|
"""Normalize dialect name with aliases.
|
|
68
68
|
|
|
69
69
|
Args:
|
|
70
|
-
dialect: Raw dialect name from SQL file
|
|
70
|
+
dialect: Raw dialect name from SQL file.
|
|
71
71
|
|
|
72
72
|
Returns:
|
|
73
|
-
Normalized dialect name
|
|
73
|
+
Normalized dialect name.
|
|
74
74
|
"""
|
|
75
75
|
normalized = dialect.lower().strip()
|
|
76
76
|
return DIALECT_ALIASES.get(normalized, normalized)
|
|
@@ -80,10 +80,10 @@ def _normalize_dialect_for_sqlglot(dialect: str) -> str:
|
|
|
80
80
|
"""Normalize dialect name for SQLGlot compatibility.
|
|
81
81
|
|
|
82
82
|
Args:
|
|
83
|
-
dialect: Dialect name from SQL file or parameter
|
|
83
|
+
dialect: Dialect name from SQL file or parameter.
|
|
84
84
|
|
|
85
85
|
Returns:
|
|
86
|
-
SQLGlot-compatible dialect name
|
|
86
|
+
SQLGlot-compatible dialect name.
|
|
87
87
|
"""
|
|
88
88
|
normalized = dialect.lower().strip()
|
|
89
89
|
return DIALECT_ALIASES.get(normalized, normalized)
|
|
@@ -125,7 +125,7 @@ class SQLFile:
|
|
|
125
125
|
"""Initialize SQLFile.
|
|
126
126
|
|
|
127
127
|
Args:
|
|
128
|
-
content:
|
|
128
|
+
content: Raw SQL content from the file.
|
|
129
129
|
path: Path where the SQL file was loaded from.
|
|
130
130
|
metadata: Optional metadata associated with the SQL file.
|
|
131
131
|
loaded_at: Timestamp when the file was loaded.
|
|
@@ -150,7 +150,7 @@ class CachedSQLFile:
|
|
|
150
150
|
"""Initialize cached SQL file.
|
|
151
151
|
|
|
152
152
|
Args:
|
|
153
|
-
sql_file:
|
|
153
|
+
sql_file: Original SQLFile with content and metadata.
|
|
154
154
|
parsed_statements: Named statements from the file.
|
|
155
155
|
"""
|
|
156
156
|
self.sql_file = sql_file
|
|
@@ -291,15 +291,15 @@ class SQLFileLoader:
|
|
|
291
291
|
"""Parse SQL content and extract named statements with dialect specifications.
|
|
292
292
|
|
|
293
293
|
Args:
|
|
294
|
-
content: Raw SQL file content to parse
|
|
295
|
-
file_path: File path for error reporting
|
|
294
|
+
content: Raw SQL file content to parse.
|
|
295
|
+
file_path: File path for error reporting.
|
|
296
296
|
|
|
297
297
|
Returns:
|
|
298
|
-
Dictionary mapping normalized statement names to NamedStatement objects
|
|
298
|
+
Dictionary mapping normalized statement names to NamedStatement objects.
|
|
299
299
|
|
|
300
300
|
Raises:
|
|
301
301
|
SQLFileParseError: If no named statements found, duplicate names exist,
|
|
302
|
-
or invalid dialect names are specified
|
|
302
|
+
or invalid dialect names are specified.
|
|
303
303
|
"""
|
|
304
304
|
statements: dict[str, NamedStatement] = {}
|
|
305
305
|
|
sqlspec/migrations/loaders.py
CHANGED
|
@@ -1,4 +1,7 @@
|
|
|
1
|
-
"""Migration loader
|
|
1
|
+
"""Migration loader implementations for SQLSpec.
|
|
2
|
+
|
|
3
|
+
This module provides loader classes for different migration file formats.
|
|
4
|
+
"""
|
|
2
5
|
|
|
3
6
|
import abc
|
|
4
7
|
import inspect
|
|
@@ -341,7 +344,7 @@ class PythonFileLoader(BaseMigrationLoader):
|
|
|
341
344
|
return module
|
|
342
345
|
|
|
343
346
|
def _normalize_and_validate_sql(self, sql: Any, migration_path: Path) -> list[str]:
|
|
344
|
-
"""Validate
|
|
347
|
+
"""Validate and normalize SQL return value to list of strings.
|
|
345
348
|
|
|
346
349
|
Args:
|
|
347
350
|
sql: Return value from migration function.
|
sqlspec/migrations/utils.py
CHANGED
|
@@ -42,7 +42,7 @@ Author: {get_author()}
|
|
|
42
42
|
|
|
43
43
|
Migration functions can use either naming convention:
|
|
44
44
|
- Preferred: up()/down()
|
|
45
|
-
-
|
|
45
|
+
- Alternate: migrate_up()/migrate_down()
|
|
46
46
|
|
|
47
47
|
Both can be synchronous or asynchronous:
|
|
48
48
|
- def up(): ...
|
|
@@ -71,7 +71,7 @@ def up() -> Union[str, List[str]]:
|
|
|
71
71
|
|
|
72
72
|
|
|
73
73
|
def down() -> Union[str, List[str]]:
|
|
74
|
-
"""Reverse the migration
|
|
74
|
+
"""Reverse the migration.
|
|
75
75
|
|
|
76
76
|
Returns:
|
|
77
77
|
SQL statement(s) to execute for downgrade.
|
|
sqlspec/storage/backends/obstore.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Object storage backend using obstore.
|
|
2
2
|
|
|
3
3
|
Implements the ObjectStoreProtocol using obstore for S3, GCS, Azure,
|
|
4
|
-
and local file storage
|
|
4
|
+
and local file storage.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
from __future__ import annotations
|
|
@@ -58,8 +58,6 @@ class ObStoreBackend(ObjectStoreBase, HasStorageCapabilities):
|
|
|
58
58
|
Uses obstore's Rust-based implementation for storage operations.
|
|
59
59
|
Supports AWS S3, Google Cloud Storage, Azure Blob Storage,
|
|
60
60
|
local filesystem, and HTTP endpoints.
|
|
61
|
-
|
|
62
|
-
Includes Arrow support.
|
|
63
61
|
"""
|
|
64
62
|
|
|
65
63
|
capabilities: ClassVar[StorageCapabilities] = StorageCapabilities(
|
sqlspec/storage/registry.py
CHANGED
sqlspec/utils/__init__.py
CHANGED
|
@@ -1,3 +1,10 @@
|
|
|
1
|
+
"""Utility functions and classes for SQLSpec.
|
|
2
|
+
|
|
3
|
+
This package provides various utility modules for deprecation handling,
|
|
4
|
+
fixture loading, logging, module loading, singleton patterns, sync/async
|
|
5
|
+
conversion, text processing, and type guards.
|
|
6
|
+
"""
|
|
7
|
+
|
|
1
8
|
from sqlspec.utils import deprecation, fixtures, logging, module_loader, singleton, sync_tools, text, type_guards
|
|
2
9
|
|
|
3
10
|
__all__ = ("deprecation", "fixtures", "logging", "module_loader", "singleton", "sync_tools", "text", "type_guards")
|
sqlspec/utils/deprecation.py
CHANGED
|
@@ -1,3 +1,9 @@
|
|
|
1
|
+
"""Deprecation utilities for SQLSpec.
|
|
2
|
+
|
|
3
|
+
Provides decorators and warning functions for marking deprecated functionality.
|
|
4
|
+
Used to communicate API changes and migration paths to users.
|
|
5
|
+
"""
|
|
6
|
+
|
|
1
7
|
import inspect
|
|
2
8
|
from functools import wraps
|
|
3
9
|
from typing import Callable, Literal, Optional
|
sqlspec/utils/fixtures.py
CHANGED
|
@@ -1,58 +1,267 @@
|
|
|
1
|
+
"""Fixture loading utilities for SQLSpec.
|
|
2
|
+
|
|
3
|
+
Provides functions for writing, loading and parsing JSON fixture files
|
|
4
|
+
used in testing and development. Supports both sync and async operations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import gzip
|
|
8
|
+
import zipfile
|
|
1
9
|
from pathlib import Path
|
|
2
|
-
from typing import Any
|
|
10
|
+
from typing import TYPE_CHECKING, Any, Union
|
|
11
|
+
|
|
12
|
+
from sqlspec.storage import storage_registry
|
|
13
|
+
from sqlspec.utils.serializers import from_json as decode_json
|
|
14
|
+
from sqlspec.utils.serializers import to_json as encode_json
|
|
15
|
+
from sqlspec.utils.sync_tools import async_
|
|
16
|
+
from sqlspec.utils.type_guards import schema_dump
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from sqlspec.typing import ModelDictList, SupportedSchemaModel
|
|
20
|
+
|
|
21
|
+
__all__ = ("open_fixture", "open_fixture_async", "write_fixture", "write_fixture_async")
|
|
3
22
|
|
|
4
|
-
from sqlspec._serialization import decode_json
|
|
5
|
-
from sqlspec.exceptions import MissingDependencyError
|
|
6
23
|
|
|
7
|
-
|
|
24
|
+
def _read_compressed_file(file_path: Path) -> str:
|
|
25
|
+
"""Read and decompress a file based on its extension.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
file_path: Path to the file to read
|
|
29
|
+
|
|
30
|
+
Returns:
|
|
31
|
+
The decompressed file content as a string
|
|
32
|
+
|
|
33
|
+
Raises:
|
|
34
|
+
ValueError: If the file format is not supported
|
|
35
|
+
"""
|
|
36
|
+
if file_path.suffix == ".gz":
|
|
37
|
+
with gzip.open(file_path, mode="rt", encoding="utf-8") as f:
|
|
38
|
+
return f.read()
|
|
39
|
+
elif file_path.suffix == ".zip":
|
|
40
|
+
with zipfile.ZipFile(file_path, "r") as zf:
|
|
41
|
+
# Assume the JSON file inside has the same name without .zip
|
|
42
|
+
json_name = file_path.stem + ".json"
|
|
43
|
+
if json_name in zf.namelist():
|
|
44
|
+
with zf.open(json_name) as f:
|
|
45
|
+
return f.read().decode("utf-8")
|
|
46
|
+
# If not found, try the first JSON file in the archive
|
|
47
|
+
json_files = [name for name in zf.namelist() if name.endswith(".json")]
|
|
48
|
+
if json_files:
|
|
49
|
+
with zf.open(json_files[0]) as f:
|
|
50
|
+
return f.read().decode("utf-8")
|
|
51
|
+
msg = f"No JSON file found in ZIP archive: {file_path}"
|
|
52
|
+
raise ValueError(msg)
|
|
53
|
+
else:
|
|
54
|
+
msg = f"Unsupported compression format: {file_path.suffix}"
|
|
55
|
+
raise ValueError(msg)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _find_fixture_file(fixtures_path: Any, fixture_name: str) -> Path:
|
|
59
|
+
"""Find a fixture file with various extensions.
|
|
60
|
+
|
|
61
|
+
Args:
|
|
62
|
+
fixtures_path: The path to look for fixtures
|
|
63
|
+
fixture_name: The fixture name to load
|
|
64
|
+
|
|
65
|
+
Returns:
|
|
66
|
+
Path to the found fixture file
|
|
67
|
+
|
|
68
|
+
Raises:
|
|
69
|
+
FileNotFoundError: If no fixture file is found
|
|
70
|
+
"""
|
|
71
|
+
base_path = Path(fixtures_path)
|
|
72
|
+
|
|
73
|
+
# Try different file extensions in order of preference
|
|
74
|
+
for extension in [".json", ".json.gz", ".json.zip"]:
|
|
75
|
+
fixture_path = base_path / f"{fixture_name}{extension}"
|
|
76
|
+
if fixture_path.exists():
|
|
77
|
+
return fixture_path
|
|
78
|
+
|
|
79
|
+
# If no file found, raise error
|
|
80
|
+
msg = f"Could not find the {fixture_name} fixture"
|
|
81
|
+
raise FileNotFoundError(msg)
|
|
8
82
|
|
|
9
83
|
|
|
10
84
|
def open_fixture(fixtures_path: Any, fixture_name: str) -> Any:
|
|
11
|
-
"""Load and parse a JSON fixture file.
|
|
85
|
+
"""Load and parse a JSON fixture file with compression support.
|
|
86
|
+
|
|
87
|
+
Supports reading from:
|
|
88
|
+
- Regular JSON files (.json)
|
|
89
|
+
- Gzipped JSON files (.json.gz)
|
|
90
|
+
- Zipped JSON files (.json.zip)
|
|
12
91
|
|
|
13
92
|
Args:
|
|
14
|
-
fixtures_path: The path to look for fixtures (pathlib.Path
|
|
93
|
+
fixtures_path: The path to look for fixtures (pathlib.Path)
|
|
15
94
|
fixture_name: The fixture name to load.
|
|
16
95
|
|
|
17
|
-
Raises:
|
|
18
|
-
FileNotFoundError: Fixtures not found.
|
|
19
96
|
|
|
20
97
|
Returns:
|
|
21
98
|
The parsed JSON data
|
|
22
99
|
"""
|
|
100
|
+
fixture_path = _find_fixture_file(fixtures_path, fixture_name)
|
|
23
101
|
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
102
|
+
if fixture_path.suffix in {".gz", ".zip"}:
|
|
103
|
+
f_data = _read_compressed_file(fixture_path)
|
|
104
|
+
else:
|
|
105
|
+
# Regular JSON file
|
|
106
|
+
with fixture_path.open(mode="r", encoding="utf-8") as f:
|
|
27
107
|
f_data = f.read()
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
raise FileNotFoundError(msg)
|
|
108
|
+
|
|
109
|
+
return decode_json(f_data)
|
|
31
110
|
|
|
32
111
|
|
|
33
112
|
async def open_fixture_async(fixtures_path: Any, fixture_name: str) -> Any:
|
|
34
|
-
"""Load and parse a JSON fixture file asynchronously.
|
|
113
|
+
"""Load and parse a JSON fixture file asynchronously with compression support.
|
|
114
|
+
|
|
115
|
+
Supports reading from:
|
|
116
|
+
- Regular JSON files (.json)
|
|
117
|
+
- Gzipped JSON files (.json.gz)
|
|
118
|
+
- Zipped JSON files (.json.zip)
|
|
119
|
+
|
|
120
|
+
For compressed files, uses sync reading in a thread pool since gzip and zipfile
|
|
121
|
+
don't have native async equivalents.
|
|
35
122
|
|
|
36
123
|
Args:
|
|
37
|
-
fixtures_path: The path to look for fixtures (pathlib.Path
|
|
124
|
+
fixtures_path: The path to look for fixtures (pathlib.Path)
|
|
38
125
|
fixture_name: The fixture name to load.
|
|
39
126
|
|
|
40
|
-
Raises:
|
|
41
|
-
FileNotFoundError: Fixtures not found.
|
|
42
|
-
MissingDependencyError: The `anyio` library is required to use this function.
|
|
43
127
|
|
|
44
128
|
Returns:
|
|
45
129
|
The parsed JSON data
|
|
46
130
|
"""
|
|
131
|
+
# Use sync path finding since it's fast
|
|
132
|
+
fixture_path = _find_fixture_file(fixtures_path, fixture_name)
|
|
133
|
+
|
|
134
|
+
if fixture_path.suffix in {".gz", ".zip"}:
|
|
135
|
+
# For compressed files, run in thread pool since they don't have async equivalents
|
|
136
|
+
read_func = async_(_read_compressed_file)
|
|
137
|
+
f_data = await read_func(fixture_path)
|
|
138
|
+
else:
|
|
139
|
+
# For regular JSON files, use async file reading
|
|
140
|
+
async_read = async_(lambda p: p.read_text(encoding="utf-8"))
|
|
141
|
+
f_data = await async_read(fixture_path)
|
|
142
|
+
|
|
143
|
+
return decode_json(f_data)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _serialize_data(data: Any) -> str:
|
|
147
|
+
"""Serialize data to JSON string, handling different input types.
|
|
148
|
+
|
|
149
|
+
Args:
|
|
150
|
+
data: Data to serialize. Can be dict, list, or SQLSpec model types
|
|
151
|
+
|
|
152
|
+
Returns:
|
|
153
|
+
JSON string representation of the data
|
|
154
|
+
"""
|
|
155
|
+
if isinstance(data, (list, tuple)):
|
|
156
|
+
# List of models or dicts - convert each item, handling primitives
|
|
157
|
+
serialized_items: list[Any] = []
|
|
158
|
+
for item in data:
|
|
159
|
+
# Use schema_dump for structured data, pass primitives through
|
|
160
|
+
if isinstance(item, (str, int, float, bool, type(None))):
|
|
161
|
+
serialized_items.append(item)
|
|
162
|
+
else:
|
|
163
|
+
serialized_items.append(schema_dump(item))
|
|
164
|
+
return encode_json(serialized_items)
|
|
165
|
+
# Single model, dict, or other type - try schema_dump first, fallback for primitives
|
|
166
|
+
if isinstance(data, (str, int, float, bool, type(None))):
|
|
167
|
+
return encode_json(data)
|
|
168
|
+
return encode_json(schema_dump(data))
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def write_fixture(
|
|
172
|
+
fixtures_path: str,
|
|
173
|
+
table_name: str,
|
|
174
|
+
data: "Union[ModelDictList, list[dict[str, Any]], SupportedSchemaModel]",
|
|
175
|
+
storage_backend: str = "local",
|
|
176
|
+
compress: bool = False,
|
|
177
|
+
**storage_kwargs: Any,
|
|
178
|
+
) -> None:
|
|
179
|
+
"""Write fixture data to storage using SQLSpec storage backend.
|
|
180
|
+
|
|
181
|
+
Args:
|
|
182
|
+
fixtures_path: Base path where fixtures should be stored
|
|
183
|
+
table_name: Name of the table/fixture (used as filename)
|
|
184
|
+
data: Data to write - can be list of dicts, models, or single model
|
|
185
|
+
storage_backend: Storage backend to use (default: "local")
|
|
186
|
+
compress: Whether to gzip compress the output
|
|
187
|
+
**storage_kwargs: Additional arguments for the storage backend
|
|
188
|
+
|
|
189
|
+
Raises:
|
|
190
|
+
ValueError: If storage backend is not found
|
|
191
|
+
"""
|
|
192
|
+
# Get the storage backend using URI-based registration
|
|
193
|
+
# For "local" backend, use file:// URI with base_path parameter
|
|
194
|
+
if storage_backend == "local":
|
|
195
|
+
uri = "file://"
|
|
196
|
+
storage_kwargs["base_path"] = str(Path(fixtures_path).resolve())
|
|
197
|
+
else:
|
|
198
|
+
uri = storage_backend
|
|
199
|
+
|
|
47
200
|
try:
|
|
48
|
-
|
|
49
|
-
except
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
201
|
+
storage = storage_registry.get(uri, **storage_kwargs)
|
|
202
|
+
except Exception as exc:
|
|
203
|
+
msg = f"Failed to get storage backend for '{storage_backend}': {exc}"
|
|
204
|
+
raise ValueError(msg) from exc
|
|
205
|
+
|
|
206
|
+
# Serialize the data
|
|
207
|
+
json_content = _serialize_data(data)
|
|
208
|
+
|
|
209
|
+
# Determine file path and content - use relative path from the base path
|
|
210
|
+
if compress:
|
|
211
|
+
file_path = f"{table_name}.json.gz"
|
|
212
|
+
content = gzip.compress(json_content.encode("utf-8"))
|
|
213
|
+
storage.write_bytes(file_path, content)
|
|
214
|
+
else:
|
|
215
|
+
file_path = f"{table_name}.json"
|
|
216
|
+
storage.write_text(file_path, json_content)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
async def write_fixture_async(
|
|
220
|
+
fixtures_path: str,
|
|
221
|
+
table_name: str,
|
|
222
|
+
data: "Union[ModelDictList, list[dict[str, Any]], SupportedSchemaModel]",
|
|
223
|
+
storage_backend: str = "local",
|
|
224
|
+
compress: bool = False,
|
|
225
|
+
**storage_kwargs: Any,
|
|
226
|
+
) -> None:
|
|
227
|
+
"""Write fixture data to storage using SQLSpec storage backend asynchronously.
|
|
228
|
+
|
|
229
|
+
Args:
|
|
230
|
+
fixtures_path: Base path where fixtures should be stored
|
|
231
|
+
table_name: Name of the table/fixture (used as filename)
|
|
232
|
+
data: Data to write - can be list of dicts, models, or single model
|
|
233
|
+
storage_backend: Storage backend to use (default: "local")
|
|
234
|
+
compress: Whether to gzip compress the output
|
|
235
|
+
**storage_kwargs: Additional arguments for the storage backend
|
|
236
|
+
|
|
237
|
+
Raises:
|
|
238
|
+
ValueError: If storage backend is not found
|
|
239
|
+
"""
|
|
240
|
+
# Get the storage backend using URI-based registration
|
|
241
|
+
# For "local" backend, use file:// URI with base_path parameter
|
|
242
|
+
if storage_backend == "local":
|
|
243
|
+
uri = "file://"
|
|
244
|
+
storage_kwargs["base_path"] = str(Path(fixtures_path).resolve())
|
|
245
|
+
else:
|
|
246
|
+
uri = storage_backend
|
|
247
|
+
|
|
248
|
+
try:
|
|
249
|
+
storage = storage_registry.get(uri, **storage_kwargs)
|
|
250
|
+
except Exception as exc:
|
|
251
|
+
msg = f"Failed to get storage backend for '{storage_backend}': {exc}"
|
|
252
|
+
raise ValueError(msg) from exc
|
|
253
|
+
|
|
254
|
+
# Serialize the data in a thread pool since it might be CPU intensive
|
|
255
|
+
serialize_func = async_(_serialize_data)
|
|
256
|
+
json_content = await serialize_func(data)
|
|
257
|
+
|
|
258
|
+
# Determine file path and content
|
|
259
|
+
if compress:
|
|
260
|
+
file_path = f"{table_name}.json.gz"
|
|
261
|
+
# Compress in thread pool since gzip is CPU intensive
|
|
262
|
+
compress_func = async_(lambda content: gzip.compress(content.encode("utf-8")))
|
|
263
|
+
content = await compress_func(json_content)
|
|
264
|
+
await storage.write_bytes_async(file_path, content)
|
|
265
|
+
else:
|
|
266
|
+
file_path = f"{table_name}.json"
|
|
267
|
+
await storage.write_text_async(file_path, json_content)
|
sqlspec/utils/module_loader.py
CHANGED
|
@@ -1,4 +1,8 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Module loading utilities for SQLSpec.
|
|
2
|
+
|
|
3
|
+
Provides functions for dynamic module imports and path resolution.
|
|
4
|
+
Used for loading modules from dotted paths and converting module paths to filesystem paths.
|
|
5
|
+
"""
|
|
2
6
|
|
|
3
7
|
import importlib
|
|
4
8
|
from importlib.util import find_spec
|