sqlspec 0.11.1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +16 -3
- sqlspec/_serialization.py +3 -10
- sqlspec/_sql.py +1147 -0
- sqlspec/_typing.py +343 -41
- sqlspec/adapters/adbc/__init__.py +2 -6
- sqlspec/adapters/adbc/config.py +474 -149
- sqlspec/adapters/adbc/driver.py +330 -621
- sqlspec/adapters/aiosqlite/__init__.py +2 -6
- sqlspec/adapters/aiosqlite/config.py +143 -57
- sqlspec/adapters/aiosqlite/driver.py +269 -431
- sqlspec/adapters/asyncmy/__init__.py +3 -8
- sqlspec/adapters/asyncmy/config.py +247 -202
- sqlspec/adapters/asyncmy/driver.py +218 -436
- sqlspec/adapters/asyncpg/__init__.py +4 -7
- sqlspec/adapters/asyncpg/config.py +329 -176
- sqlspec/adapters/asyncpg/driver.py +417 -487
- sqlspec/adapters/bigquery/__init__.py +2 -2
- sqlspec/adapters/bigquery/config.py +407 -0
- sqlspec/adapters/bigquery/driver.py +600 -553
- sqlspec/adapters/duckdb/__init__.py +4 -1
- sqlspec/adapters/duckdb/config.py +432 -321
- sqlspec/adapters/duckdb/driver.py +392 -406
- sqlspec/adapters/oracledb/__init__.py +3 -8
- sqlspec/adapters/oracledb/config.py +625 -0
- sqlspec/adapters/oracledb/driver.py +548 -921
- sqlspec/adapters/psqlpy/__init__.py +4 -7
- sqlspec/adapters/psqlpy/config.py +372 -203
- sqlspec/adapters/psqlpy/driver.py +197 -533
- sqlspec/adapters/psycopg/__init__.py +3 -8
- sqlspec/adapters/psycopg/config.py +741 -0
- sqlspec/adapters/psycopg/driver.py +734 -694
- sqlspec/adapters/sqlite/__init__.py +2 -6
- sqlspec/adapters/sqlite/config.py +146 -81
- sqlspec/adapters/sqlite/driver.py +242 -405
- sqlspec/base.py +220 -784
- sqlspec/config.py +354 -0
- sqlspec/driver/__init__.py +22 -0
- sqlspec/driver/_async.py +252 -0
- sqlspec/driver/_common.py +338 -0
- sqlspec/driver/_sync.py +261 -0
- sqlspec/driver/mixins/__init__.py +17 -0
- sqlspec/driver/mixins/_pipeline.py +523 -0
- sqlspec/driver/mixins/_result_utils.py +122 -0
- sqlspec/driver/mixins/_sql_translator.py +35 -0
- sqlspec/driver/mixins/_storage.py +993 -0
- sqlspec/driver/mixins/_type_coercion.py +131 -0
- sqlspec/exceptions.py +299 -7
- sqlspec/extensions/aiosql/__init__.py +10 -0
- sqlspec/extensions/aiosql/adapter.py +474 -0
- sqlspec/extensions/litestar/__init__.py +1 -6
- sqlspec/extensions/litestar/_utils.py +1 -5
- sqlspec/extensions/litestar/config.py +5 -6
- sqlspec/extensions/litestar/handlers.py +13 -12
- sqlspec/extensions/litestar/plugin.py +22 -24
- sqlspec/extensions/litestar/providers.py +37 -55
- sqlspec/loader.py +528 -0
- sqlspec/service/__init__.py +3 -0
- sqlspec/service/base.py +24 -0
- sqlspec/service/pagination.py +26 -0
- sqlspec/statement/__init__.py +21 -0
- sqlspec/statement/builder/__init__.py +54 -0
- sqlspec/statement/builder/_ddl_utils.py +119 -0
- sqlspec/statement/builder/_parsing_utils.py +135 -0
- sqlspec/statement/builder/base.py +328 -0
- sqlspec/statement/builder/ddl.py +1379 -0
- sqlspec/statement/builder/delete.py +80 -0
- sqlspec/statement/builder/insert.py +274 -0
- sqlspec/statement/builder/merge.py +95 -0
- sqlspec/statement/builder/mixins/__init__.py +65 -0
- sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
- sqlspec/statement/builder/mixins/_case_builder.py +91 -0
- sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
- sqlspec/statement/builder/mixins/_delete_from.py +34 -0
- sqlspec/statement/builder/mixins/_from.py +61 -0
- sqlspec/statement/builder/mixins/_group_by.py +119 -0
- sqlspec/statement/builder/mixins/_having.py +35 -0
- sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
- sqlspec/statement/builder/mixins/_insert_into.py +36 -0
- sqlspec/statement/builder/mixins/_insert_values.py +69 -0
- sqlspec/statement/builder/mixins/_join.py +110 -0
- sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
- sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
- sqlspec/statement/builder/mixins/_order_by.py +46 -0
- sqlspec/statement/builder/mixins/_pivot.py +82 -0
- sqlspec/statement/builder/mixins/_returning.py +37 -0
- sqlspec/statement/builder/mixins/_select_columns.py +60 -0
- sqlspec/statement/builder/mixins/_set_ops.py +122 -0
- sqlspec/statement/builder/mixins/_unpivot.py +80 -0
- sqlspec/statement/builder/mixins/_update_from.py +54 -0
- sqlspec/statement/builder/mixins/_update_set.py +91 -0
- sqlspec/statement/builder/mixins/_update_table.py +29 -0
- sqlspec/statement/builder/mixins/_where.py +374 -0
- sqlspec/statement/builder/mixins/_window_functions.py +86 -0
- sqlspec/statement/builder/protocols.py +20 -0
- sqlspec/statement/builder/select.py +206 -0
- sqlspec/statement/builder/update.py +178 -0
- sqlspec/statement/filters.py +571 -0
- sqlspec/statement/parameters.py +736 -0
- sqlspec/statement/pipelines/__init__.py +67 -0
- sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
- sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
- sqlspec/statement/pipelines/base.py +315 -0
- sqlspec/statement/pipelines/context.py +119 -0
- sqlspec/statement/pipelines/result_types.py +41 -0
- sqlspec/statement/pipelines/transformers/__init__.py +8 -0
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
- sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
- sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
- sqlspec/statement/pipelines/validators/__init__.py +23 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
- sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
- sqlspec/statement/pipelines/validators/_performance.py +703 -0
- sqlspec/statement/pipelines/validators/_security.py +990 -0
- sqlspec/statement/pipelines/validators/base.py +67 -0
- sqlspec/statement/result.py +527 -0
- sqlspec/statement/splitter.py +701 -0
- sqlspec/statement/sql.py +1198 -0
- sqlspec/storage/__init__.py +15 -0
- sqlspec/storage/backends/__init__.py +0 -0
- sqlspec/storage/backends/base.py +166 -0
- sqlspec/storage/backends/fsspec.py +315 -0
- sqlspec/storage/backends/obstore.py +464 -0
- sqlspec/storage/protocol.py +170 -0
- sqlspec/storage/registry.py +315 -0
- sqlspec/typing.py +157 -36
- sqlspec/utils/correlation.py +155 -0
- sqlspec/utils/deprecation.py +3 -6
- sqlspec/utils/fixtures.py +6 -11
- sqlspec/utils/logging.py +135 -0
- sqlspec/utils/module_loader.py +45 -43
- sqlspec/utils/serializers.py +4 -0
- sqlspec/utils/singleton.py +6 -8
- sqlspec/utils/sync_tools.py +15 -27
- sqlspec/utils/text.py +58 -26
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.0.dist-info}/METADATA +97 -26
- sqlspec-0.12.0.dist-info/RECORD +145 -0
- sqlspec/adapters/bigquery/config/__init__.py +0 -3
- sqlspec/adapters/bigquery/config/_common.py +0 -40
- sqlspec/adapters/bigquery/config/_sync.py +0 -87
- sqlspec/adapters/oracledb/config/__init__.py +0 -9
- sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
- sqlspec/adapters/oracledb/config/_common.py +0 -131
- sqlspec/adapters/oracledb/config/_sync.py +0 -186
- sqlspec/adapters/psycopg/config/__init__.py +0 -19
- sqlspec/adapters/psycopg/config/_async.py +0 -169
- sqlspec/adapters/psycopg/config/_common.py +0 -56
- sqlspec/adapters/psycopg/config/_sync.py +0 -168
- sqlspec/filters.py +0 -331
- sqlspec/mixins.py +0 -305
- sqlspec/statement.py +0 -378
- sqlspec-0.11.1.dist-info/RECORD +0 -69
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.0.dist-info}/licenses/NOTICE +0 -0
|
sqlspec/driver/_common.py
ADDED
|
@@ -0,0 +1,338 @@
|
|
|
1
|
+
"""Common driver attributes and utilities."""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from abc import ABC
|
|
5
|
+
from collections.abc import Mapping, Sequence
|
|
6
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional
|
|
7
|
+
|
|
8
|
+
import sqlglot
|
|
9
|
+
from sqlglot import exp
|
|
10
|
+
from sqlglot.tokens import TokenType
|
|
11
|
+
|
|
12
|
+
from sqlspec.exceptions import NotFoundError
|
|
13
|
+
from sqlspec.statement import SQLConfig
|
|
14
|
+
from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
|
|
15
|
+
from sqlspec.statement.splitter import split_sql_script
|
|
16
|
+
from sqlspec.typing import ConnectionT, DictRow, RowT, T
|
|
17
|
+
from sqlspec.utils.logging import get_logger
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from sqlglot.dialects.dialect import DialectType
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Public API of this module.
__all__ = ("CommonDriverAttributesMixin",)


# Module logger; used for debug output when the placeholder-sanitising re-parse
# in ``_check_anonymous_returns_rows`` fails.
logger = get_logger("driver")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
    """Common attributes and methods for driver adapters.

    Concrete drivers declare ``dialect`` and their parameter-style support as class
    attributes; the ``supports_native_*`` capability flags default to ``False``.
    """

    __slots__ = ("config", "connection", "default_row_type")

    dialect: "DialectType"
    """The SQL dialect supported by the underlying database driver."""
    supported_parameter_styles: "tuple[ParameterStyle, ...]"
    """The parameter styles supported by this driver."""
    default_parameter_style: "ParameterStyle"
    """The default parameter style to convert to when unsupported style is detected."""
    supports_native_parquet_export: "ClassVar[bool]" = False
    """Indicates if the driver supports native Parquet export operations."""
    supports_native_parquet_import: "ClassVar[bool]" = False
    """Indicates if the driver supports native Parquet import operations."""
    supports_native_arrow_export: "ClassVar[bool]" = False
    """Indicates if the driver supports native Arrow export operations."""
    supports_native_arrow_import: "ClassVar[bool]" = False
    """Indicates if the driver supports native Arrow import operations."""

    # Expression node types that ``returns_rows`` treats as producing a result set.
    _ROW_RETURNING_EXPRESSIONS: "ClassVar[tuple[type[exp.Expression], ...]]" = (
        exp.Select,
        exp.Values,
        exp.Table,
        exp.Show,
        exp.Describe,
        exp.Pragma,
        exp.Command,
    )
    # Common driver placeholders: ?, %s, $1, :name and %(name)s. Compiled once.
    _PLACEHOLDER_PATTERN: "ClassVar[re.Pattern[str]]" = re.compile(r"(\?|%s|\$\d+|:\w+|%\(\w+\)s)")

    def __init__(
        self,
        connection: "ConnectionT",
        config: "Optional[SQLConfig]" = None,
        default_row_type: "type[DictRow]" = dict[str, Any],
    ) -> None:
        """Initialize with connection, config, and default_row_type.

        Args:
            connection: The database connection
            config: SQL statement configuration; a default ``SQLConfig()`` is used when omitted
            default_row_type: Default row type for results (DictRow, TupleRow, etc.)
        """
        self.connection = connection
        self.config = config or SQLConfig()
        self.default_row_type = default_row_type or dict[str, Any]

    def _connection(self, connection: "Optional[ConnectionT]" = None) -> "ConnectionT":
        """Return ``connection`` if one was given, otherwise the driver's own connection.

        ``None`` is the only "not given" marker, so a connection object that happens to
        evaluate as falsy is still honoured.
        """
        return connection if connection is not None else self.connection

    def returns_rows(self, expression: "Optional[exp.Expression]") -> bool:
        """Check if the SQL expression is expected to return rows.

        Args:
            expression: The parsed SQL expression, or ``None``.

        Returns:
            True for SELECT, VALUES, TABLE, SHOW, DESCRIBE, PRAGMA and generic command
            nodes, and for INSERT/UPDATE/DELETE statements containing a RETURNING clause.
            For a bare ``exp.With`` node the decision is delegated to the last entry of
            its ``expressions``. Unparsed (``exp.Anonymous``) SQL is checked heuristically.
            ``None`` and all other expressions return False.
        """
        if expression is None:
            return False
        if isinstance(expression, self._ROW_RETURNING_EXPRESSIONS):
            return True
        if isinstance(expression, exp.With) and expression.expressions:
            return self.returns_rows(expression.expressions[-1])
        if isinstance(expression, (exp.Insert, exp.Update, exp.Delete)):
            return bool(expression.find(exp.Returning))
        # Anonymous expressions come from SQL that failed to parse; use a best-effort check.
        if isinstance(expression, exp.Anonymous):
            return self._check_anonymous_returns_rows(expression)
        return False

    def _check_anonymous_returns_rows(self, expression: "exp.Anonymous") -> bool:
        """Check if an Anonymous expression returns rows using best-effort methods.

        SQL that failed to parse (often because of database-specific placeholders) is
        handled in two steps:

        1. Replace recognised placeholders with the literal ``1`` and re-parse; if that
           yields a real (non-Anonymous) expression, classify it with ``returns_rows``.
        2. Otherwise tokenize the raw text and inspect the first token that is neither
           a ``COMMENT`` nor a ``SEMICOLON`` token.

        Args:
            expression: The Anonymous expression to check

        Returns:
            True if the expression likely returns rows. Any failure in step 2 yields False.
        """
        sql_text = str(expression.this) if expression.this else ""
        if not sql_text.strip():
            return False

        # Approach 1: re-parse with placeholders replaced by a literal sqlglot accepts.
        try:
            sanitized_sql = self._PLACEHOLDER_PATTERN.sub("1", sql_text)
            # Only worth re-parsing if a placeholder was actually replaced.
            if sanitized_sql != sql_text:
                parsed = sqlglot.parse_one(sanitized_sql, read=None)
                if not isinstance(parsed, exp.Anonymous):
                    return self.returns_rows(parsed)
        except Exception:
            logger.debug("Could not parse using placeholders. Using tokenizer. %s", sql_text)

        # Approach 2: keyword detection on the leading token. The token set is built
        # inside the try so an unexpected sqlglot version only disables this fallback.
        try:
            row_returning_tokens = {
                TokenType.SELECT,
                TokenType.WITH,
                TokenType.VALUES,
                TokenType.TABLE,
                TokenType.SHOW,
                TokenType.DESCRIBE,
                TokenType.PRAGMA,
            }
            skipped_tokens = {TokenType.COMMENT, TokenType.SEMICOLON}
            for token in sqlglot.tokenize(sql_text, read=None):
                if token.token_type in skipped_tokens:
                    continue
                return token.token_type in row_returning_tokens
        except Exception:
            return False

        return False

    @staticmethod
    def check_not_found(item_or_none: "Optional[T]" = None) -> "T":
        """Raise :exc:`sqlspec.exceptions.NotFoundError` if ``item_or_none`` is ``None``.

        Args:
            item_or_none: Item to be tested for existence.

        Raises:
            NotFoundError: If ``item_or_none`` is ``None``

        Returns:
            The item, if it exists.
        """
        if item_or_none is None:
            msg = "No result found when one was expected"
            raise NotFoundError(msg)
        return item_or_none

    def _convert_parameters_to_driver_format(  # noqa: C901
        self, sql: str, parameters: Any, target_style: "Optional[ParameterStyle]" = None
    ) -> Any:
        """Convert parameters to the format expected by the driver, but only when necessary.

        The SQL is scanned for placeholders. Parameters are reshaped (mapping <-> sequence)
        only when their shape does not match what the placeholder style expects.

        Args:
            sql: The SQL string with placeholders
            parameters: The parameters in any format (dict, list, tuple, scalar)
            target_style: Optional override for the target parameter style. When omitted,
                the driver default is used unless the SQL consistently uses a single style.

        Returns:
            Parameters in the format expected by the database driver, or ``None`` when
            ``parameters`` is ``None`` or the SQL contains no placeholders.
        """
        if parameters is None:
            return None

        param_info_list = ParameterValidator().extract_parameters(sql)
        if not param_info_list:
            # No placeholders in the SQL, so there is nothing to bind.
            return None

        if target_style is None:
            target_style = self.default_parameter_style
            # A single style actually used in the SQL wins over the driver default.
            actual_styles = {p.style for p in param_info_list if p.style}
            if len(actual_styles) == 1:
                target_style = actual_styles.pop()

        driver_expects_dict = target_style in {
            ParameterStyle.NAMED_COLON,
            ParameterStyle.POSITIONAL_COLON,
            ParameterStyle.NAMED_AT,
            ParameterStyle.NAMED_DOLLAR,
            ParameterStyle.NAMED_PYFORMAT,
        }

        params_are_dict = isinstance(parameters, Mapping)
        params_are_sequence = isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes))

        # Single scalar value for a single placeholder.
        if len(param_info_list) == 1 and not params_are_dict and not params_are_sequence:
            if driver_expects_dict:
                param_info = param_info_list[0]
                if param_info.name:
                    return {param_info.name: parameters}
                return {f"param_{param_info.ordinal}": parameters}
            return [parameters]

        if driver_expects_dict and params_are_dict:
            if target_style == ParameterStyle.POSITIONAL_COLON and all(
                p.name and p.name.isdigit() for p in param_info_list
            ):
                # Numeric placeholders (:1, :2, ...): remap by position only when the
                # supplied keys do not already cover the placeholder names.
                numeric_keys_expected = {p.name for p in param_info_list if p.name}
                if not numeric_keys_expected.issubset(parameters.keys()):
                    values = list(parameters.values())
                    return {
                        p.name: values[p.ordinal]
                        for p in param_info_list
                        if p.name and p.ordinal < len(values)
                    }
            # Shapes already match - pass through unchanged.
            return parameters

        if not driver_expects_dict and params_are_sequence:
            # Shapes already match - pass through unchanged.
            return parameters

        if driver_expects_dict and params_are_sequence:
            # Bind positional values to the placeholder names from the SQL; unnamed
            # placeholders get generated ``param_<index>`` keys.
            return {
                (param_info.name if param_info.name else f"param_{i}"): value
                for i, (param_info, value) in enumerate(zip(param_info_list, parameters))
            }

        if not driver_expects_dict and params_are_dict:
            # Generated ``param_N`` keys can be read back in index order. Non-string keys
            # are not generated keys and fall through to the general mapping below.
            if all(isinstance(key, str) and key.startswith("param_") and key[6:].isdigit() for key in parameters):
                return [parameters[f"param_{i}"] for i in range(len(param_info_list)) if f"param_{i}" in parameters]

            # Order named values by their placeholder order in the SQL.
            values = list(parameters.values())
            positional_params: list[Any] = []
            for param_info in param_info_list:
                generated_key = f"param_{param_info.ordinal}"
                if param_info.name and param_info.name in parameters:
                    positional_params.append(parameters[param_info.name])
                elif generated_key in parameters:
                    positional_params.append(parameters[generated_key])
                elif param_info.ordinal < len(values):
                    # Last resort: match by position in the mapping's insertion order.
                    positional_params.append(values[param_info.ordinal])
            return positional_params or values

        # Unrecognised combination (e.g. several placeholders with a scalar): pass through.
        return parameters

    def _split_script_statements(self, script: str, strip_trailing_semicolon: bool = False) -> list[str]:
        """Split a SQL script into individual statements.

        Delegates to ``split_sql_script`` with this driver's dialect (stringified),
        which is described as handling PL/SQL blocks, T-SQL batches and nested blocks.

        Args:
            script: The SQL script to split
            strip_trailing_semicolon: If True, remove trailing semicolons from statements

        Returns:
            A list of individual SQL statements

        Note:
            This is particularly useful for databases that don't natively
            support multi-statement execution (e.g., Oracle, some async drivers).
        """
        # split_sql_script handles dialect mapping and fallback itself.
        return split_sql_script(script, dialect=str(self.dialect), strip_trailing_semicolon=strip_trailing_semicolon)
|
sqlspec/driver/_sync.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
1
|
+
"""Synchronous driver protocol implementation."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload
|
|
5
|
+
|
|
6
|
+
from sqlspec.driver._common import CommonDriverAttributesMixin
|
|
7
|
+
from sqlspec.statement.builder import DeleteBuilder, InsertBuilder, QueryBuilder, SelectBuilder, UpdateBuilder
|
|
8
|
+
from sqlspec.statement.filters import StatementFilter
|
|
9
|
+
from sqlspec.statement.result import SQLResult
|
|
10
|
+
from sqlspec.statement.sql import SQL, SQLConfig, Statement
|
|
11
|
+
from sqlspec.typing import ConnectionT, DictRow, ModelDTOT, RowT, StatementParameters
|
|
12
|
+
from sqlspec.utils.logging import get_logger
|
|
13
|
+
|
|
14
|
+
# Module logger. NOTE(review): not referenced in the code visible in this module.
logger = get_logger("sqlspec")


# The result TypedDict names are only used in annotations and string-form ``cast``
# calls, so they are imported for type checkers only.
if TYPE_CHECKING:
    from sqlspec.statement.result import DMLResultDict, ScriptResultDict, SelectResultDict

# Public API of this module.
__all__ = ("SyncDriverAdapterProtocol",)


# Shared empty filter list. It is a mutable module-level object, so it must never be
# mutated in place. NOTE(review): not referenced in the code visible in this module.
EMPTY_FILTERS: "list[StatementFilter]" = []
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class SyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT], ABC):
    """Base class for synchronous driver adapters.

    Turns strings, ``SQL`` objects and query builders into ``SQL`` statements,
    delegates execution to the driver-specific ``_execute_statement`` and wraps the
    returned result dictionaries via ``_wrap_select_result`` / ``_wrap_execute_result``.
    """

    __slots__ = ()

    def __init__(
        self,
        connection: "ConnectionT",
        config: "Optional[SQLConfig]" = None,
        default_row_type: "type[DictRow]" = DictRow,
    ) -> None:
        """Initialize sync driver adapter.

        Args:
            connection: The database connection
            config: SQL statement configuration
            default_row_type: Default row type for results (DictRow, TupleRow, etc.)
        """
        super().__init__(connection=connection, config=config, default_row_type=default_row_type)

    def _build_statement(
        self,
        statement: "Union[Statement, QueryBuilder[Any]]",
        *parameters: "Union[StatementParameters, StatementFilter]",
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQL":
        """Coerce ``statement`` into a :class:`SQL` object.

        - A ``QueryBuilder`` is converted with ``to_statement(config=...)``.
        - An existing ``SQL`` object is returned unchanged unless extra parameters or
          kwargs are supplied; in that case a new ``SQL`` is built from its SQL text.
        - Anything else is wrapped in a new ``SQL`` using the driver's dialect.

        Args:
            statement: Raw SQL, a ``SQL`` object or a query builder.
            *parameters: Positional parameters and/or ``StatementFilter`` objects.
            _config: Statement config; defaults to the driver's config.
            **kwargs: Named parameters forwarded to ``SQL``.
        """
        _config = _config or self.config

        if isinstance(statement, QueryBuilder):
            # NOTE(review): ``parameters`` and ``kwargs`` are not applied to builder
            # statements on this path - confirm this is intended.
            return statement.to_statement(config=_config)
        if isinstance(statement, SQL):
            if parameters or kwargs:
                # Rebuild from the SQL text so the additional parameters/filters apply.
                return SQL(statement._sql, *parameters, _dialect=self.dialect, _config=_config, **kwargs)
            return statement
        return SQL(statement, *parameters, _dialect=self.dialect, _config=_config, **kwargs)

    @abstractmethod
    def _execute_statement(
        self, statement: "SQL", connection: "Optional[ConnectionT]" = None, **kwargs: Any
    ) -> "Union[SelectResultDict, DMLResultDict, ScriptResultDict]":
        """Actual execution implementation by concrete drivers, using the raw connection.

        Returns one of the standardized result dictionaries based on the statement type.
        """
        raise NotImplementedError

    @abstractmethod
    def _wrap_select_result(
        self,
        statement: "SQL",
        result: "SelectResultDict",
        schema_type: "Optional[type[ModelDTOT]]" = None,
        **kwargs: Any,
    ) -> "Union[SQLResult[ModelDTOT], SQLResult[RowT]]":
        """Wrap a row-returning result dict in an ``SQLResult``.

        When ``schema_type`` is given the result is typed as ``SQLResult[schema_type]``,
        otherwise as ``SQLResult[RowT]``.
        """
        raise NotImplementedError

    @abstractmethod
    def _wrap_execute_result(
        self, statement: "SQL", result: "Union[DMLResultDict, ScriptResultDict]", **kwargs: Any
    ) -> "SQLResult[RowT]":
        """Wrap a DML or script result dict in an ``SQLResult``."""
        raise NotImplementedError

    @overload
    def execute(
        self,
        statement: "SelectBuilder",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "type[ModelDTOT]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[ModelDTOT]": ...

    @overload
    def execute(
        self,
        statement: "SelectBuilder",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: None = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[RowT]": ...

    @overload
    def execute(
        self,
        statement: "Union[InsertBuilder, UpdateBuilder, DeleteBuilder]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[RowT]": ...

    @overload
    def execute(
        self,
        statement: "Statement",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "type[ModelDTOT]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[ModelDTOT]": ...

    @overload
    def execute(
        self,
        statement: "Union[str, SQL]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: None = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[RowT]": ...

    def execute(
        self,
        statement: "Union[SQL, Statement, QueryBuilder[Any]]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "Optional[type[ModelDTOT]]" = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "Union[SQLResult[ModelDTOT], SQLResult[RowT]]":
        """Execute a single statement.

        If ``returns_rows`` reports that the parsed expression yields rows, the result
        goes through ``_wrap_select_result`` (honouring ``schema_type``); otherwise
        it goes through ``_wrap_execute_result``.
        """
        sql_statement = self._build_statement(statement, *parameters, _config=_config or self.config, **kwargs)
        result = self._execute_statement(statement=sql_statement, connection=self._connection(_connection), **kwargs)

        if self.returns_rows(sql_statement.expression):
            return self._wrap_select_result(
                sql_statement, cast("SelectResultDict", result), schema_type=schema_type, **kwargs
            )
        return self._wrap_execute_result(
            sql_statement, cast("Union[DMLResultDict, ScriptResultDict]", result), **kwargs
        )

    def execute_many(
        self,
        statement: "Union[SQL, Statement, QueryBuilder[Any]]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[RowT]":
        """Execute one statement against a sequence of parameter sets.

        ``StatementFilter`` arguments are applied to the statement. The first non-filter
        argument is the parameter sequence; it is normalised to a list (or ``None`` if
        it is not iterable) and bound via ``SQL.as_many``.
        """
        filters: list[StatementFilter] = []
        param_sequences: list[Any] = []
        for param in parameters:
            if isinstance(param, StatementFilter):
                filters.append(param)
            else:
                param_sequences.append(param)

        param_sequence = param_sequences[0] if param_sequences else None
        # Materialise tuples and other iterables as a list; drop non-iterables.
        if param_sequence is not None and not isinstance(param_sequence, list):
            param_sequence = list(param_sequence) if hasattr(param_sequence, "__iter__") else None

        # Filters were previously collected and then discarded; apply them to the
        # statement, mirroring how execute_script forwards its filters to SQL.
        sql_statement = self._build_statement(statement, *filters, _config=_config or self.config, **kwargs).as_many(
            param_sequence
        )

        result = self._execute_statement(
            statement=sql_statement,
            connection=self._connection(_connection),
            parameters=param_sequence,
            is_many=True,
            **kwargs,
        )
        return self._wrap_execute_result(
            sql_statement, cast("Union[DMLResultDict, ScriptResultDict]", result), **kwargs
        )

    def execute_script(
        self,
        statement: "Union[str, SQL]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[RowT]":
        """Execute a multi-statement script.

        The first non-filter argument (if any) is used as the script's parameters.
        Filters are forwarded to ``SQL``. When the effective config has validation
        enabled, a copy with ``enable_validation=False`` and ``strict_mode=False`` is
        used instead. If the driver returns a plain string, a ``SCRIPT`` result
        reporting one successful statement is synthesised.
        """
        param_values = []
        filters = []
        for param in parameters:
            if isinstance(param, StatementFilter):
                filters.append(param)
            else:
                param_values.append(param)

        primary_params = param_values[0] if param_values else None

        script_config = _config or self.config
        if script_config.enable_validation:
            script_config = SQLConfig(
                enable_parsing=script_config.enable_parsing,
                enable_validation=False,
                enable_transformations=script_config.enable_transformations,
                enable_analysis=script_config.enable_analysis,
                strict_mode=False,
                cache_parsed_expression=script_config.cache_parsed_expression,
                parameter_converter=script_config.parameter_converter,
                parameter_validator=script_config.parameter_validator,
                analysis_cache_size=script_config.analysis_cache_size,
                allowed_parameter_styles=script_config.allowed_parameter_styles,
                target_parameter_style=script_config.target_parameter_style,
                allow_mixed_parameter_styles=script_config.allow_mixed_parameter_styles,
            )

        sql_statement = SQL(statement, primary_params, *filters, _dialect=self.dialect, _config=script_config, **kwargs)
        sql_statement = sql_statement.as_script()
        script_output = self._execute_statement(
            statement=sql_statement, connection=self._connection(_connection), is_script=True, **kwargs
        )
        if isinstance(script_output, str):
            result = SQLResult[RowT](statement=sql_statement, data=[], operation_type="SCRIPT")
            result.total_statements = 1
            result.successful_statements = 1
            return result
        return self._wrap_execute_result(sql_statement, cast("ScriptResultDict", script_output), **kwargs)
|
|
sqlspec/driver/mixins/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""Driver mixins for instrumentation, storage, and utilities."""
|
|
2
|
+
|
|
3
|
+
from sqlspec.driver.mixins._pipeline import AsyncPipelinedExecutionMixin, SyncPipelinedExecutionMixin
|
|
4
|
+
from sqlspec.driver.mixins._result_utils import ToSchemaMixin
|
|
5
|
+
from sqlspec.driver.mixins._sql_translator import SQLTranslatorMixin
|
|
6
|
+
from sqlspec.driver.mixins._storage import AsyncStorageMixin, SyncStorageMixin
|
|
7
|
+
from sqlspec.driver.mixins._type_coercion import TypeCoercionMixin
|
|
8
|
+
|
|
9
|
+
# Public re-exports of the driver mixins imported above (sync/async pipelined
# execution, storage, SQL translation, schema conversion and type coercion).
__all__ = (
    "AsyncPipelinedExecutionMixin",
    "AsyncStorageMixin",
    "SQLTranslatorMixin",
    "SyncPipelinedExecutionMixin",
    "SyncStorageMixin",
    "ToSchemaMixin",
    "TypeCoercionMixin",
)
|