sqlspec 0.17.0__py3-none-any.whl → 0.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +1 -1
- sqlspec/_sql.py +188 -234
- sqlspec/adapters/adbc/config.py +24 -30
- sqlspec/adapters/adbc/driver.py +42 -61
- sqlspec/adapters/aiosqlite/config.py +5 -10
- sqlspec/adapters/aiosqlite/driver.py +9 -25
- sqlspec/adapters/aiosqlite/pool.py +43 -35
- sqlspec/adapters/asyncmy/config.py +10 -7
- sqlspec/adapters/asyncmy/driver.py +18 -39
- sqlspec/adapters/asyncpg/config.py +4 -0
- sqlspec/adapters/asyncpg/driver.py +32 -79
- sqlspec/adapters/bigquery/config.py +12 -65
- sqlspec/adapters/bigquery/driver.py +39 -133
- sqlspec/adapters/duckdb/config.py +11 -15
- sqlspec/adapters/duckdb/driver.py +61 -85
- sqlspec/adapters/duckdb/pool.py +2 -5
- sqlspec/adapters/oracledb/_types.py +8 -1
- sqlspec/adapters/oracledb/config.py +55 -38
- sqlspec/adapters/oracledb/driver.py +35 -92
- sqlspec/adapters/oracledb/migrations.py +257 -0
- sqlspec/adapters/psqlpy/config.py +13 -9
- sqlspec/adapters/psqlpy/driver.py +28 -103
- sqlspec/adapters/psycopg/config.py +9 -5
- sqlspec/adapters/psycopg/driver.py +107 -175
- sqlspec/adapters/sqlite/config.py +7 -5
- sqlspec/adapters/sqlite/driver.py +37 -73
- sqlspec/adapters/sqlite/pool.py +3 -12
- sqlspec/base.py +1 -8
- sqlspec/builder/__init__.py +1 -1
- sqlspec/builder/_base.py +34 -20
- sqlspec/builder/_column.py +5 -1
- sqlspec/builder/_ddl.py +407 -183
- sqlspec/builder/_expression_wrappers.py +46 -0
- sqlspec/builder/_insert.py +2 -4
- sqlspec/builder/_update.py +5 -5
- sqlspec/builder/mixins/_insert_operations.py +26 -6
- sqlspec/builder/mixins/_merge_operations.py +1 -1
- sqlspec/builder/mixins/_order_limit_operations.py +16 -4
- sqlspec/builder/mixins/_select_operations.py +3 -7
- sqlspec/builder/mixins/_update_operations.py +4 -4
- sqlspec/config.py +32 -13
- sqlspec/core/__init__.py +89 -14
- sqlspec/core/cache.py +57 -104
- sqlspec/core/compiler.py +57 -112
- sqlspec/core/filters.py +1 -21
- sqlspec/core/hashing.py +13 -47
- sqlspec/core/parameters.py +272 -261
- sqlspec/core/result.py +12 -27
- sqlspec/core/splitter.py +17 -21
- sqlspec/core/statement.py +150 -159
- sqlspec/driver/_async.py +2 -15
- sqlspec/driver/_common.py +16 -95
- sqlspec/driver/_sync.py +2 -15
- sqlspec/driver/mixins/_result_tools.py +8 -29
- sqlspec/driver/mixins/_sql_translator.py +6 -8
- sqlspec/exceptions.py +1 -2
- sqlspec/loader.py +43 -115
- sqlspec/migrations/__init__.py +1 -1
- sqlspec/migrations/base.py +34 -45
- sqlspec/migrations/commands.py +34 -15
- sqlspec/migrations/loaders.py +1 -1
- sqlspec/migrations/runner.py +104 -19
- sqlspec/migrations/tracker.py +49 -2
- sqlspec/protocols.py +13 -6
- sqlspec/storage/__init__.py +4 -4
- sqlspec/storage/backends/fsspec.py +5 -6
- sqlspec/storage/backends/obstore.py +7 -8
- sqlspec/storage/registry.py +3 -3
- sqlspec/utils/__init__.py +2 -2
- sqlspec/utils/logging.py +6 -10
- sqlspec/utils/sync_tools.py +27 -4
- sqlspec/utils/text.py +6 -1
- {sqlspec-0.17.0.dist-info → sqlspec-0.18.0.dist-info}/METADATA +1 -1
- sqlspec-0.18.0.dist-info/RECORD +138 -0
- sqlspec/builder/_ddl_utils.py +0 -103
- sqlspec-0.17.0.dist-info/RECORD +0 -137
- {sqlspec-0.17.0.dist-info → sqlspec-0.18.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.17.0.dist-info → sqlspec-0.18.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.17.0.dist-info → sqlspec-0.18.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.17.0.dist-info → sqlspec-0.18.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/driver/_common.py
CHANGED
|
@@ -1,8 +1,4 @@
|
|
|
1
|
-
"""Common driver attributes and utilities.
|
|
2
|
-
|
|
3
|
-
This module provides core driver infrastructure including execution result handling,
|
|
4
|
-
common driver attributes, parameter processing, and SQL compilation utilities.
|
|
5
|
-
"""
|
|
1
|
+
"""Common driver attributes and utilities."""
|
|
6
2
|
|
|
7
3
|
from typing import TYPE_CHECKING, Any, Final, NamedTuple, Optional, Union, cast
|
|
8
4
|
|
|
@@ -10,7 +6,7 @@ from mypy_extensions import trait
|
|
|
10
6
|
from sqlglot import exp
|
|
11
7
|
|
|
12
8
|
from sqlspec.builder import QueryBuilder
|
|
13
|
-
from sqlspec.core import SQL,
|
|
9
|
+
from sqlspec.core import SQL, ParameterStyle, SQLResult, Statement, StatementConfig, TypedParameter
|
|
14
10
|
from sqlspec.core.cache import get_cache_config, sql_cache
|
|
15
11
|
from sqlspec.core.splitter import split_sql_script
|
|
16
12
|
from sqlspec.exceptions import ImproperConfigurationError
|
|
@@ -38,19 +34,7 @@ logger = get_logger("driver")
|
|
|
38
34
|
|
|
39
35
|
|
|
40
36
|
class ScriptExecutionResult(NamedTuple):
|
|
41
|
-
"""Result from script execution with statement count information.
|
|
42
|
-
|
|
43
|
-
This named tuple eliminates the need for redundant script splitting
|
|
44
|
-
by providing statement count information during execution rather than
|
|
45
|
-
requiring re-parsing after execution.
|
|
46
|
-
|
|
47
|
-
Attributes:
|
|
48
|
-
cursor_result: The result returned by the database cursor/driver
|
|
49
|
-
rowcount_override: Optional override for the number of affected rows
|
|
50
|
-
special_data: Any special metadata or additional information
|
|
51
|
-
statement_count: Total number of statements in the script
|
|
52
|
-
successful_statements: Number of statements that executed successfully
|
|
53
|
-
"""
|
|
37
|
+
"""Result from script execution with statement count information."""
|
|
54
38
|
|
|
55
39
|
cursor_result: Any
|
|
56
40
|
rowcount_override: Optional[int]
|
|
@@ -60,24 +44,7 @@ class ScriptExecutionResult(NamedTuple):
|
|
|
60
44
|
|
|
61
45
|
|
|
62
46
|
class ExecutionResult(NamedTuple):
|
|
63
|
-
"""
|
|
64
|
-
|
|
65
|
-
This named tuple consolidates all execution result data to eliminate the need
|
|
66
|
-
for additional data extraction calls and script re-parsing in build_statement_result.
|
|
67
|
-
|
|
68
|
-
Attributes:
|
|
69
|
-
cursor_result: The raw result returned by the database cursor/driver
|
|
70
|
-
rowcount_override: Optional override for the number of affected rows
|
|
71
|
-
special_data: Any special metadata or additional information from execution
|
|
72
|
-
selected_data: For SELECT operations, the extracted row data
|
|
73
|
-
column_names: For SELECT operations, the column names
|
|
74
|
-
data_row_count: For SELECT operations, the number of rows returned
|
|
75
|
-
statement_count: For script operations, total number of statements
|
|
76
|
-
successful_statements: For script operations, number of successful statements
|
|
77
|
-
is_script_result: Whether this result is from script execution
|
|
78
|
-
is_select_result: Whether this result is from a SELECT operation
|
|
79
|
-
is_many_result: Whether this result is from an execute_many operation
|
|
80
|
-
"""
|
|
47
|
+
"""Execution result containing all data needed for SQLResult building."""
|
|
81
48
|
|
|
82
49
|
cursor_result: Any
|
|
83
50
|
rowcount_override: Optional[int]
|
|
@@ -93,20 +60,15 @@ class ExecutionResult(NamedTuple):
|
|
|
93
60
|
last_inserted_id: Optional[Union[int, str]] = None
|
|
94
61
|
|
|
95
62
|
|
|
96
|
-
EXEC_CURSOR_RESULT = 0
|
|
97
|
-
EXEC_ROWCOUNT_OVERRIDE = 1
|
|
98
|
-
EXEC_SPECIAL_DATA = 2
|
|
63
|
+
EXEC_CURSOR_RESULT: Final[int] = 0
|
|
64
|
+
EXEC_ROWCOUNT_OVERRIDE: Final[int] = 1
|
|
65
|
+
EXEC_SPECIAL_DATA: Final[int] = 2
|
|
99
66
|
DEFAULT_EXECUTION_RESULT: Final[tuple[Any, Optional[int], Any]] = (None, None, None)
|
|
100
67
|
|
|
101
68
|
|
|
102
69
|
@trait
|
|
103
70
|
class CommonDriverAttributesMixin:
|
|
104
|
-
"""Common attributes and methods for driver adapters.
|
|
105
|
-
|
|
106
|
-
This mixin provides the foundation for all SQLSpec drivers, including
|
|
107
|
-
connection and configuration management, parameter processing, caching,
|
|
108
|
-
and SQL compilation.
|
|
109
|
-
"""
|
|
71
|
+
"""Common attributes and methods for driver adapters."""
|
|
110
72
|
|
|
111
73
|
__slots__ = ("connection", "driver_features", "statement_config")
|
|
112
74
|
connection: "Any"
|
|
@@ -180,9 +142,6 @@ class CommonDriverAttributesMixin:
|
|
|
180
142
|
def build_statement_result(self, statement: "SQL", execution_result: ExecutionResult) -> "SQLResult":
|
|
181
143
|
"""Build and return the SQLResult from ExecutionResult data.
|
|
182
144
|
|
|
183
|
-
Creates SQLResult objects from ExecutionResult data without requiring
|
|
184
|
-
additional data extraction calls or script re-parsing.
|
|
185
|
-
|
|
186
145
|
Args:
|
|
187
146
|
statement: SQL statement that was executed
|
|
188
147
|
execution_result: ExecutionResult containing all necessary data
|
|
@@ -215,51 +174,11 @@ class CommonDriverAttributesMixin:
|
|
|
215
174
|
statement=statement,
|
|
216
175
|
data=[],
|
|
217
176
|
rows_affected=execution_result.rowcount_override or 0,
|
|
218
|
-
operation_type=
|
|
177
|
+
operation_type=statement.operation_type,
|
|
219
178
|
last_inserted_id=execution_result.last_inserted_id,
|
|
220
179
|
metadata=execution_result.special_data or {"status_message": "OK"},
|
|
221
180
|
)
|
|
222
181
|
|
|
223
|
-
def _determine_operation_type(self, statement: "Any") -> OperationType:
|
|
224
|
-
"""Determine operation type from SQL statement expression.
|
|
225
|
-
|
|
226
|
-
Examines the statement's expression type to determine if it's
|
|
227
|
-
INSERT, UPDATE, DELETE, SELECT, SCRIPT, or generic EXECUTE.
|
|
228
|
-
|
|
229
|
-
Args:
|
|
230
|
-
statement: SQL statement object with expression attribute
|
|
231
|
-
|
|
232
|
-
Returns:
|
|
233
|
-
OperationType literal value
|
|
234
|
-
"""
|
|
235
|
-
if statement.is_script:
|
|
236
|
-
return "SCRIPT"
|
|
237
|
-
|
|
238
|
-
try:
|
|
239
|
-
expression = statement.expression
|
|
240
|
-
except AttributeError:
|
|
241
|
-
return "EXECUTE"
|
|
242
|
-
|
|
243
|
-
if not expression:
|
|
244
|
-
return "EXECUTE"
|
|
245
|
-
|
|
246
|
-
expr_type = type(expression).__name__.upper()
|
|
247
|
-
|
|
248
|
-
if "ANONYMOUS" in expr_type and statement.is_script:
|
|
249
|
-
return "SCRIPT"
|
|
250
|
-
|
|
251
|
-
if "INSERT" in expr_type:
|
|
252
|
-
return "INSERT"
|
|
253
|
-
if "UPDATE" in expr_type:
|
|
254
|
-
return "UPDATE"
|
|
255
|
-
if "DELETE" in expr_type:
|
|
256
|
-
return "DELETE"
|
|
257
|
-
if "SELECT" in expr_type:
|
|
258
|
-
return "SELECT"
|
|
259
|
-
if "COPY" in expr_type:
|
|
260
|
-
return "COPY"
|
|
261
|
-
return "EXECUTE"
|
|
262
|
-
|
|
263
182
|
def prepare_statement(
|
|
264
183
|
self,
|
|
265
184
|
statement: "Union[Statement, QueryBuilder]",
|
|
@@ -489,7 +408,8 @@ class CommonDriverAttributesMixin:
|
|
|
489
408
|
if cached_result is not None:
|
|
490
409
|
return cached_result
|
|
491
410
|
|
|
492
|
-
|
|
411
|
+
prepared_statement = self.prepare_statement(statement, statement_config=statement_config)
|
|
412
|
+
compiled_sql, execution_parameters = prepared_statement.compile()
|
|
493
413
|
|
|
494
414
|
prepared_parameters = self.prepare_driver_parameters(
|
|
495
415
|
execution_parameters, statement_config, is_many=statement.is_many
|
|
@@ -590,7 +510,7 @@ class CommonDriverAttributesMixin:
|
|
|
590
510
|
def find_filter(
|
|
591
511
|
filter_type: "type[FilterTypeT]",
|
|
592
512
|
filters: "Sequence[StatementFilter | StatementParameters] | Sequence[StatementFilter]",
|
|
593
|
-
) -> "FilterTypeT
|
|
513
|
+
) -> "Optional[FilterTypeT]":
|
|
594
514
|
"""Get the filter specified by filter type from the filters.
|
|
595
515
|
|
|
596
516
|
Args:
|
|
@@ -600,9 +520,10 @@ class CommonDriverAttributesMixin:
|
|
|
600
520
|
Returns:
|
|
601
521
|
The match filter instance or None
|
|
602
522
|
"""
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
523
|
+
for filter_ in filters:
|
|
524
|
+
if isinstance(filter_, filter_type):
|
|
525
|
+
return filter_
|
|
526
|
+
return None
|
|
606
527
|
|
|
607
528
|
def _create_count_query(self, original_sql: "SQL") -> "SQL":
|
|
608
529
|
"""Create a COUNT query from the original SQL statement.
|
sqlspec/driver/_sync.py
CHANGED
|
@@ -1,8 +1,4 @@
|
|
|
1
|
-
"""Synchronous driver protocol implementation.
|
|
2
|
-
|
|
3
|
-
This module provides the sync driver infrastructure for database adapters,
|
|
4
|
-
including connection management, transaction support, and result processing.
|
|
5
|
-
"""
|
|
1
|
+
"""Synchronous driver protocol implementation."""
|
|
6
2
|
|
|
7
3
|
from abc import abstractmethod
|
|
8
4
|
from typing import TYPE_CHECKING, Any, Final, NoReturn, Optional, Union, cast, overload
|
|
@@ -32,22 +28,13 @@ EMPTY_FILTERS: Final["list[StatementFilter]"] = []
|
|
|
32
28
|
|
|
33
29
|
|
|
34
30
|
class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToSchemaMixin):
|
|
35
|
-
"""Base class for synchronous database drivers.
|
|
36
|
-
|
|
37
|
-
Provides the foundation for sync database adapters, including connection management,
|
|
38
|
-
transaction support, and SQL execution methods. All database operations are performed
|
|
39
|
-
synchronously and support context manager patterns for proper resource cleanup.
|
|
40
|
-
"""
|
|
31
|
+
"""Base class for synchronous database drivers."""
|
|
41
32
|
|
|
42
33
|
__slots__ = ()
|
|
43
34
|
|
|
44
35
|
def dispatch_statement_execution(self, statement: "SQL", connection: "Any") -> "SQLResult":
|
|
45
36
|
"""Central execution dispatcher using the Template Method Pattern.
|
|
46
37
|
|
|
47
|
-
Orchestrates the common execution flow, delegating database-specific steps
|
|
48
|
-
to abstract methods that concrete adapters must implement.
|
|
49
|
-
All database operations are wrapped in exception handling.
|
|
50
|
-
|
|
51
38
|
Args:
|
|
52
39
|
statement: The SQL statement to execute
|
|
53
40
|
connection: The database connection to use
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
# pyright: reportCallIssue=false, reportAttributeAccessIssue=false, reportArgumentType=false
|
|
2
1
|
import datetime
|
|
3
2
|
import logging
|
|
4
3
|
from collections.abc import Sequence
|
|
@@ -28,10 +27,8 @@ __all__ = ("_DEFAULT_TYPE_DECODERS", "_default_msgspec_deserializer")
|
|
|
28
27
|
|
|
29
28
|
logger = logging.getLogger(__name__)
|
|
30
29
|
|
|
31
|
-
# Constants for performance optimization
|
|
32
|
-
_DATETIME_TYPES: Final[set[type]] = {datetime.datetime, datetime.date, datetime.time}
|
|
33
|
-
_PATH_TYPES: Final[tuple[type, ...]] = (Path, PurePath, UUID)
|
|
34
30
|
|
|
31
|
+
_DATETIME_TYPES: Final[set[type]] = {datetime.datetime, datetime.date, datetime.time}
|
|
35
32
|
_DEFAULT_TYPE_DECODERS: Final[list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]] = [
|
|
36
33
|
(lambda x: x is UUID, lambda t, v: t(v.hex)),
|
|
37
34
|
(lambda x: x is datetime.datetime, lambda t, v: t(v.isoformat())),
|
|
@@ -44,21 +41,15 @@ _DEFAULT_TYPE_DECODERS: Final[list[tuple[Callable[[Any], bool], Callable[[Any, A
|
|
|
44
41
|
def _default_msgspec_deserializer(
|
|
45
42
|
target_type: Any, value: Any, type_decoders: "Optional[Sequence[tuple[Any, Any]]]" = None
|
|
46
43
|
) -> Any:
|
|
47
|
-
"""Default msgspec deserializer with type conversion support.
|
|
48
|
-
|
|
49
|
-
Converts values to appropriate types for msgspec deserialization, including
|
|
50
|
-
UUID, datetime, date, time, Enum, Path, and PurePath types.
|
|
51
|
-
"""
|
|
44
|
+
"""Default msgspec deserializer with type conversion support."""
|
|
52
45
|
if type_decoders:
|
|
53
46
|
for predicate, decoder in type_decoders:
|
|
54
47
|
if predicate(target_type):
|
|
55
48
|
return decoder(target_type, value)
|
|
56
49
|
|
|
57
|
-
# Fast path checks using type identity and isinstance
|
|
58
50
|
if target_type is UUID and isinstance(value, UUID):
|
|
59
51
|
return value.hex
|
|
60
52
|
|
|
61
|
-
# Use pre-computed set for faster lookup
|
|
62
53
|
if target_type in _DATETIME_TYPES:
|
|
63
54
|
try:
|
|
64
55
|
return value.isoformat()
|
|
@@ -71,7 +62,6 @@ def _default_msgspec_deserializer(
|
|
|
71
62
|
if isinstance(value, target_type):
|
|
72
63
|
return value
|
|
73
64
|
|
|
74
|
-
# Check for path types using pre-computed tuple
|
|
75
65
|
if isinstance(target_type, type):
|
|
76
66
|
try:
|
|
77
67
|
if issubclass(target_type, (Path, PurePath)) or issubclass(target_type, UUID):
|
|
@@ -86,7 +76,6 @@ def _default_msgspec_deserializer(
|
|
|
86
76
|
class ToSchemaMixin:
|
|
87
77
|
__slots__ = ()
|
|
88
78
|
|
|
89
|
-
# Schema conversion overloads - handle common cases first
|
|
90
79
|
@overload
|
|
91
80
|
@staticmethod
|
|
92
81
|
def to_schema(data: "list[dict[str, Any]]") -> "list[dict[str, Any]]": ...
|
|
@@ -125,15 +114,11 @@ class ToSchemaMixin:
|
|
|
125
114
|
def to_schema(data: Any, *, schema_type: "Optional[type[ModelDTOT]]" = None) -> Any:
|
|
126
115
|
"""Convert data to a specified schema type.
|
|
127
116
|
|
|
128
|
-
Supports conversion to dataclasses, msgspec structs, Pydantic models, and attrs classes.
|
|
129
|
-
Handles both single objects and sequences.
|
|
130
|
-
|
|
131
117
|
Raises:
|
|
132
118
|
SQLSpecError if `schema_type` is not a valid type.
|
|
133
119
|
|
|
134
120
|
Returns:
|
|
135
121
|
Converted data in the specified schema type.
|
|
136
|
-
|
|
137
122
|
"""
|
|
138
123
|
if schema_type is None:
|
|
139
124
|
return data
|
|
@@ -152,30 +137,24 @@ class ToSchemaMixin:
|
|
|
152
137
|
return schema_type(**data) # type: ignore[operator]
|
|
153
138
|
return data
|
|
154
139
|
if is_msgspec_struct(schema_type):
|
|
155
|
-
# Cache the deserializer to avoid repeated partial() calls
|
|
156
140
|
deserializer = partial(_default_msgspec_deserializer, type_decoders=_DEFAULT_TYPE_DECODERS)
|
|
157
141
|
if not isinstance(data, Sequence):
|
|
158
142
|
return convert(obj=data, type=schema_type, from_attributes=True, dec_hook=deserializer)
|
|
159
|
-
return convert(
|
|
160
|
-
obj=data,
|
|
161
|
-
type=list[schema_type], # type: ignore[valid-type] # pyright: ignore
|
|
162
|
-
from_attributes=True,
|
|
163
|
-
dec_hook=deserializer,
|
|
164
|
-
)
|
|
143
|
+
return convert(obj=data, type=list[schema_type], from_attributes=True, dec_hook=deserializer) # type: ignore[valid-type]
|
|
165
144
|
if is_pydantic_model(schema_type):
|
|
166
145
|
if not isinstance(data, Sequence):
|
|
167
146
|
adapter = get_type_adapter(schema_type)
|
|
168
|
-
return adapter.validate_python(data, from_attributes=True)
|
|
169
|
-
list_adapter = get_type_adapter(list[schema_type]) # type: ignore[valid-type]
|
|
147
|
+
return adapter.validate_python(data, from_attributes=True)
|
|
148
|
+
list_adapter = get_type_adapter(list[schema_type]) # type: ignore[valid-type]
|
|
170
149
|
return list_adapter.validate_python(data, from_attributes=True)
|
|
171
150
|
if is_attrs_schema(schema_type):
|
|
172
151
|
if CATTRS_INSTALLED:
|
|
173
152
|
if isinstance(data, Sequence):
|
|
174
|
-
return cattrs_structure(data, list[schema_type]) # type: ignore[valid-type]
|
|
153
|
+
return cattrs_structure(data, list[schema_type]) # type: ignore[valid-type]
|
|
175
154
|
if hasattr(data, "__attrs_attrs__"):
|
|
176
155
|
unstructured_data = cattrs_unstructure(data)
|
|
177
|
-
return cattrs_structure(unstructured_data, schema_type)
|
|
178
|
-
return cattrs_structure(data, schema_type)
|
|
156
|
+
return cattrs_structure(unstructured_data, schema_type)
|
|
157
|
+
return cattrs_structure(data, schema_type)
|
|
179
158
|
if isinstance(data, list):
|
|
180
159
|
attrs_result: list[Any] = []
|
|
181
160
|
for item in data:
|
|
@@ -9,7 +9,7 @@ from sqlspec.exceptions import SQLConversionError
|
|
|
9
9
|
|
|
10
10
|
__all__ = ("SQLTranslatorMixin",)
|
|
11
11
|
|
|
12
|
-
|
|
12
|
+
|
|
13
13
|
_DEFAULT_PRETTY: Final[bool] = True
|
|
14
14
|
|
|
15
15
|
|
|
@@ -18,6 +18,7 @@ class SQLTranslatorMixin:
|
|
|
18
18
|
"""Mixin for drivers supporting SQL translation."""
|
|
19
19
|
|
|
20
20
|
__slots__ = ()
|
|
21
|
+
dialect: "Optional[DialectType]"
|
|
21
22
|
|
|
22
23
|
def convert_to_dialect(
|
|
23
24
|
self, statement: "Statement", to_dialect: "Optional[DialectType]" = None, pretty: bool = _DEFAULT_PRETTY
|
|
@@ -35,7 +36,7 @@ class SQLTranslatorMixin:
|
|
|
35
36
|
Raises:
|
|
36
37
|
SQLConversionError: If parsing or conversion fails
|
|
37
38
|
"""
|
|
38
|
-
|
|
39
|
+
|
|
39
40
|
parsed_expression: Optional[exp.Expression] = None
|
|
40
41
|
|
|
41
42
|
if statement is not None and isinstance(statement, SQL):
|
|
@@ -47,19 +48,16 @@ class SQLTranslatorMixin:
|
|
|
47
48
|
else:
|
|
48
49
|
parsed_expression = self._parse_statement_safely(statement)
|
|
49
50
|
|
|
50
|
-
|
|
51
|
-
target_dialect = to_dialect or self.dialect # type: ignore[attr-defined]
|
|
51
|
+
target_dialect = to_dialect or self.dialect
|
|
52
52
|
|
|
53
|
-
# Generate SQL with error handling
|
|
54
53
|
return self._generate_sql_safely(parsed_expression, target_dialect, pretty)
|
|
55
54
|
|
|
56
55
|
def _parse_statement_safely(self, statement: "Statement") -> "exp.Expression":
|
|
57
56
|
"""Parse statement with copy=False optimization and proper error handling."""
|
|
58
57
|
try:
|
|
59
|
-
# Convert statement to string if needed
|
|
60
58
|
sql_string = str(statement)
|
|
61
|
-
|
|
62
|
-
return parse_one(sql_string, dialect=self.dialect, copy=False)
|
|
59
|
+
|
|
60
|
+
return parse_one(sql_string, dialect=self.dialect, copy=False)
|
|
63
61
|
except Exception as e:
|
|
64
62
|
self._raise_parse_error(e)
|
|
65
63
|
|
sqlspec/exceptions.py
CHANGED
|
@@ -181,9 +181,8 @@ def wrap_exceptions(
|
|
|
181
181
|
(isinstance(suppress, type) and isinstance(exc, suppress))
|
|
182
182
|
or (isinstance(suppress, tuple) and isinstance(exc, suppress))
|
|
183
183
|
):
|
|
184
|
-
return
|
|
184
|
+
return
|
|
185
185
|
|
|
186
|
-
# If it's already a SQLSpec exception, don't wrap it
|
|
187
186
|
if isinstance(exc, SQLSpecError):
|
|
188
187
|
raise
|
|
189
188
|
|
sqlspec/loader.py
CHANGED
|
@@ -1,17 +1,15 @@
|
|
|
1
|
-
"""SQL file loader
|
|
1
|
+
"""SQL file loader for managing SQL statements from files.
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
Provides functionality to load, cache, and manage SQL statements
|
|
4
4
|
from files using aiosql-style named queries.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
import hashlib
|
|
8
8
|
import re
|
|
9
9
|
import time
|
|
10
|
-
from dataclasses import dataclass, field
|
|
11
10
|
from datetime import datetime, timezone
|
|
12
|
-
from difflib import get_close_matches
|
|
13
11
|
from pathlib import Path
|
|
14
|
-
from typing import Any, Optional, Union
|
|
12
|
+
from typing import TYPE_CHECKING, Any, Final, Optional, Union
|
|
15
13
|
|
|
16
14
|
from sqlspec.core.cache import CacheKey, get_cache_config, get_default_cache
|
|
17
15
|
from sqlspec.core.statement import SQL
|
|
@@ -21,11 +19,13 @@ from sqlspec.exceptions import (
|
|
|
21
19
|
SQLFileParseError,
|
|
22
20
|
StorageOperationFailedError,
|
|
23
21
|
)
|
|
24
|
-
from sqlspec.storage import storage_registry
|
|
25
|
-
from sqlspec.storage.registry import StorageRegistry
|
|
22
|
+
from sqlspec.storage.registry import storage_registry as default_storage_registry
|
|
26
23
|
from sqlspec.utils.correlation import CorrelationContext
|
|
27
24
|
from sqlspec.utils.logging import get_logger
|
|
28
25
|
|
|
26
|
+
if TYPE_CHECKING:
|
|
27
|
+
from sqlspec.storage.registry import StorageRegistry
|
|
28
|
+
|
|
29
29
|
__all__ = ("CachedSQLFile", "NamedStatement", "SQLFile", "SQLFileLoader")
|
|
30
30
|
|
|
31
31
|
logger = get_logger("loader")
|
|
@@ -38,48 +38,8 @@ TRIM_SPECIAL_CHARS = re.compile(r"[^\w.-]")
|
|
|
38
38
|
# Matches: -- dialect: dialect_name (optional dialect specification)
|
|
39
39
|
DIALECT_PATTERN = re.compile(r"^\s*--\s*dialect\s*:\s*(?P<dialect>[a-zA-Z0-9_]+)\s*$", re.IGNORECASE | re.MULTILINE)
|
|
40
40
|
|
|
41
|
-
# Supported SQL dialects (based on SQLGlot's available dialects)
|
|
42
|
-
SUPPORTED_DIALECTS = {
|
|
43
|
-
# Core databases
|
|
44
|
-
"sqlite",
|
|
45
|
-
"postgresql",
|
|
46
|
-
"postgres",
|
|
47
|
-
"mysql",
|
|
48
|
-
"oracle",
|
|
49
|
-
"mssql",
|
|
50
|
-
"tsql",
|
|
51
|
-
# Cloud platforms
|
|
52
|
-
"bigquery",
|
|
53
|
-
"snowflake",
|
|
54
|
-
"redshift",
|
|
55
|
-
"athena",
|
|
56
|
-
"fabric",
|
|
57
|
-
# Analytics engines
|
|
58
|
-
"clickhouse",
|
|
59
|
-
"duckdb",
|
|
60
|
-
"databricks",
|
|
61
|
-
"spark",
|
|
62
|
-
"spark2",
|
|
63
|
-
"trino",
|
|
64
|
-
"presto",
|
|
65
|
-
# Specialized
|
|
66
|
-
"hive",
|
|
67
|
-
"drill",
|
|
68
|
-
"druid",
|
|
69
|
-
"materialize",
|
|
70
|
-
"teradata",
|
|
71
|
-
"dremio",
|
|
72
|
-
"doris",
|
|
73
|
-
"risingwave",
|
|
74
|
-
"singlestore",
|
|
75
|
-
"starrocks",
|
|
76
|
-
"tableau",
|
|
77
|
-
"exasol",
|
|
78
|
-
"dune",
|
|
79
|
-
}
|
|
80
41
|
|
|
81
|
-
|
|
82
|
-
DIALECT_ALIASES = {
|
|
42
|
+
DIALECT_ALIASES: Final = {
|
|
83
43
|
"postgresql": "postgres",
|
|
84
44
|
"pg": "postgres",
|
|
85
45
|
"pgplsql": "postgres",
|
|
@@ -88,7 +48,7 @@ DIALECT_ALIASES = {
|
|
|
88
48
|
"tsql": "mssql",
|
|
89
49
|
}
|
|
90
50
|
|
|
91
|
-
MIN_QUERY_PARTS = 3
|
|
51
|
+
MIN_QUERY_PARTS: Final = 3
|
|
92
52
|
|
|
93
53
|
|
|
94
54
|
def _normalize_query_name(name: str) -> str:
|
|
@@ -129,19 +89,6 @@ def _normalize_dialect_for_sqlglot(dialect: str) -> str:
|
|
|
129
89
|
return DIALECT_ALIASES.get(normalized, normalized)
|
|
130
90
|
|
|
131
91
|
|
|
132
|
-
def _get_dialect_suggestions(invalid_dialect: str) -> "list[str]":
|
|
133
|
-
"""Get dialect suggestions using fuzzy matching.
|
|
134
|
-
|
|
135
|
-
Args:
|
|
136
|
-
invalid_dialect: Invalid dialect name that was provided
|
|
137
|
-
|
|
138
|
-
Returns:
|
|
139
|
-
List of suggested dialect names (up to 3 suggestions)
|
|
140
|
-
"""
|
|
141
|
-
|
|
142
|
-
return get_close_matches(invalid_dialect, SUPPORTED_DIALECTS, n=3, cutoff=0.6)
|
|
143
|
-
|
|
144
|
-
|
|
145
92
|
class NamedStatement:
|
|
146
93
|
"""Represents a parsed SQL statement with metadata.
|
|
147
94
|
|
|
@@ -159,7 +106,6 @@ class NamedStatement:
|
|
|
159
106
|
self.start_line = start_line
|
|
160
107
|
|
|
161
108
|
|
|
162
|
-
@dataclass
|
|
163
109
|
class SQLFile:
|
|
164
110
|
"""Represents a loaded SQL file with metadata.
|
|
165
111
|
|
|
@@ -167,28 +113,32 @@ class SQLFile:
|
|
|
167
113
|
timestamps, and content hash.
|
|
168
114
|
"""
|
|
169
115
|
|
|
170
|
-
content
|
|
171
|
-
"""The raw SQL content from the file."""
|
|
116
|
+
__slots__ = ("checksum", "content", "loaded_at", "metadata", "path")
|
|
172
117
|
|
|
173
|
-
|
|
174
|
-
|
|
118
|
+
def __init__(
|
|
119
|
+
self,
|
|
120
|
+
content: str,
|
|
121
|
+
path: str,
|
|
122
|
+
metadata: "Optional[dict[str, Any]]" = None,
|
|
123
|
+
loaded_at: "Optional[datetime]" = None,
|
|
124
|
+
) -> None:
|
|
125
|
+
"""Initialize SQLFile.
|
|
175
126
|
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
"""Calculate checksum after initialization."""
|
|
127
|
+
Args:
|
|
128
|
+
content: The raw SQL content from the file.
|
|
129
|
+
path: Path where the SQL file was loaded from.
|
|
130
|
+
metadata: Optional metadata associated with the SQL file.
|
|
131
|
+
loaded_at: Timestamp when the file was loaded.
|
|
132
|
+
"""
|
|
133
|
+
self.content = content
|
|
134
|
+
self.path = path
|
|
135
|
+
self.metadata = metadata or {}
|
|
136
|
+
self.loaded_at = loaded_at or datetime.now(timezone.utc)
|
|
187
137
|
self.checksum = hashlib.md5(self.content.encode(), usedforsecurity=False).hexdigest()
|
|
188
138
|
|
|
189
139
|
|
|
190
140
|
class CachedSQLFile:
|
|
191
|
-
"""Cached SQL file with parsed statements
|
|
141
|
+
"""Cached SQL file with parsed statements.
|
|
192
142
|
|
|
193
143
|
Stored in the file cache to avoid re-parsing SQL files when their
|
|
194
144
|
content hasn't changed.
|
|
@@ -205,17 +155,19 @@ class CachedSQLFile:
|
|
|
205
155
|
"""
|
|
206
156
|
self.sql_file = sql_file
|
|
207
157
|
self.parsed_statements = parsed_statements
|
|
208
|
-
self.statement_names =
|
|
158
|
+
self.statement_names = tuple(parsed_statements.keys())
|
|
209
159
|
|
|
210
160
|
|
|
211
161
|
class SQLFileLoader:
|
|
212
162
|
"""Loads and parses SQL files with aiosql-style named queries.
|
|
213
163
|
|
|
214
|
-
|
|
215
|
-
|
|
164
|
+
Loads SQL files containing named queries (using -- name: syntax)
|
|
165
|
+
and retrieves them by name.
|
|
216
166
|
"""
|
|
217
167
|
|
|
218
|
-
|
|
168
|
+
__slots__ = ("_files", "_queries", "_query_to_file", "encoding", "storage_registry")
|
|
169
|
+
|
|
170
|
+
def __init__(self, *, encoding: str = "utf-8", storage_registry: "Optional[StorageRegistry]" = None) -> None:
|
|
219
171
|
"""Initialize the SQL file loader.
|
|
220
172
|
|
|
221
173
|
Args:
|
|
@@ -223,7 +175,8 @@ class SQLFileLoader:
|
|
|
223
175
|
storage_registry: Storage registry for handling file URIs.
|
|
224
176
|
"""
|
|
225
177
|
self.encoding = encoding
|
|
226
|
-
|
|
178
|
+
|
|
179
|
+
self.storage_registry = storage_registry or default_storage_registry
|
|
227
180
|
self._queries: dict[str, NamedStatement] = {}
|
|
228
181
|
self._files: dict[str, SQLFile] = {}
|
|
229
182
|
self._query_to_file: dict[str, str] = {}
|
|
@@ -309,7 +262,6 @@ class SQLFileLoader:
|
|
|
309
262
|
except KeyError as e:
|
|
310
263
|
raise SQLFileNotFoundError(path_str) from e
|
|
311
264
|
except MissingDependencyError:
|
|
312
|
-
# Fall back to standard file reading when no storage backend is available
|
|
313
265
|
try:
|
|
314
266
|
return path.read_text(encoding=self.encoding) # type: ignore[union-attr]
|
|
315
267
|
except FileNotFoundError as e:
|
|
@@ -350,7 +302,6 @@ class SQLFileLoader:
|
|
|
350
302
|
or invalid dialect names are specified
|
|
351
303
|
"""
|
|
352
304
|
statements: dict[str, NamedStatement] = {}
|
|
353
|
-
content.splitlines()
|
|
354
305
|
|
|
355
306
|
name_matches = list(QUERY_NAME_PATTERN.finditer(content))
|
|
356
307
|
if not name_matches:
|
|
@@ -379,20 +330,7 @@ class SQLFileLoader:
|
|
|
379
330
|
if dialect_match:
|
|
380
331
|
declared_dialect = dialect_match.group("dialect").lower()
|
|
381
332
|
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
if normalized_dialect not in SUPPORTED_DIALECTS:
|
|
385
|
-
suggestions = _get_dialect_suggestions(normalized_dialect)
|
|
386
|
-
warning_msg = f"Unknown dialect '{declared_dialect}' at line {statement_start_line + 1}"
|
|
387
|
-
if suggestions:
|
|
388
|
-
warning_msg += f". Did you mean: {', '.join(suggestions)}?"
|
|
389
|
-
warning_msg += (
|
|
390
|
-
f". Supported dialects: {', '.join(sorted(SUPPORTED_DIALECTS))}. Using dialect as-is."
|
|
391
|
-
)
|
|
392
|
-
logger.warning(warning_msg)
|
|
393
|
-
dialect = declared_dialect.lower()
|
|
394
|
-
else:
|
|
395
|
-
dialect = normalized_dialect
|
|
333
|
+
dialect = _normalize_dialect(declared_dialect)
|
|
396
334
|
remaining_lines = section_lines[1:]
|
|
397
335
|
statement_sql = "\n".join(remaining_lines)
|
|
398
336
|
|
|
@@ -473,7 +411,7 @@ class SQLFileLoader:
|
|
|
473
411
|
raise
|
|
474
412
|
|
|
475
413
|
def _load_directory(self, dir_path: Path) -> int:
|
|
476
|
-
"""Load all SQL files from a directory
|
|
414
|
+
"""Load all SQL files from a directory."""
|
|
477
415
|
sql_files = list(dir_path.rglob("*.sql"))
|
|
478
416
|
if not sql_files:
|
|
479
417
|
return 0
|
|
@@ -486,7 +424,7 @@ class SQLFileLoader:
|
|
|
486
424
|
return len(sql_files)
|
|
487
425
|
|
|
488
426
|
def _load_single_file(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
|
|
489
|
-
"""Load a single SQL file with optional namespace
|
|
427
|
+
"""Load a single SQL file with optional namespace.
|
|
490
428
|
|
|
491
429
|
Args:
|
|
492
430
|
file_path: Path to the SQL file.
|
|
@@ -543,7 +481,7 @@ class SQLFileLoader:
|
|
|
543
481
|
unified_cache.put(cache_key, cached_file_data)
|
|
544
482
|
|
|
545
483
|
def _load_file_without_cache(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
|
|
546
|
-
"""Load a single SQL file without
|
|
484
|
+
"""Load a single SQL file without using cache.
|
|
547
485
|
|
|
548
486
|
Args:
|
|
549
487
|
file_path: Path to the SQL file.
|
|
@@ -580,7 +518,7 @@ class SQLFileLoader:
|
|
|
580
518
|
Raises:
|
|
581
519
|
ValueError: If query name already exists.
|
|
582
520
|
"""
|
|
583
|
-
|
|
521
|
+
|
|
584
522
|
normalized_name = _normalize_query_name(name)
|
|
585
523
|
|
|
586
524
|
if normalized_name in self._queries:
|
|
@@ -589,17 +527,7 @@ class SQLFileLoader:
|
|
|
589
527
|
raise ValueError(msg)
|
|
590
528
|
|
|
591
529
|
if dialect is not None:
|
|
592
|
-
|
|
593
|
-
if normalized_dialect not in SUPPORTED_DIALECTS:
|
|
594
|
-
suggestions = _get_dialect_suggestions(normalized_dialect)
|
|
595
|
-
warning_msg = f"Unknown dialect '{dialect}'"
|
|
596
|
-
if suggestions:
|
|
597
|
-
warning_msg += f". Did you mean: {', '.join(suggestions)}?"
|
|
598
|
-
warning_msg += f". Supported dialects: {', '.join(sorted(SUPPORTED_DIALECTS))}. Using dialect as-is."
|
|
599
|
-
logger.warning(warning_msg)
|
|
600
|
-
dialect = dialect.lower()
|
|
601
|
-
else:
|
|
602
|
-
dialect = normalized_dialect
|
|
530
|
+
dialect = _normalize_dialect(dialect)
|
|
603
531
|
|
|
604
532
|
statement = NamedStatement(name=normalized_name, sql=sql.strip(), dialect=dialect, start_line=0)
|
|
605
533
|
self._queries[normalized_name] = statement
|