sqlspec 0.12.1__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic; see the registry's notice for this release for more details.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +3 -3
- sqlspec/adapters/aiosqlite/driver.py +116 -141
- sqlspec/adapters/asyncmy/config.py +3 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +3 -7
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +231 -181
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +132 -124
- sqlspec/adapters/oracledb/config.py +6 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +3 -7
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +34 -30
- sqlspec/adapters/psycopg/driver.py +342 -214
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +150 -104
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +149 -216
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +31 -118
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +70 -23
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +102 -65
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +22 -25
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +667 -43
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +885 -379
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +82 -35
- sqlspec/storage/backends/obstore.py +66 -49
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
- sqlspec-0.13.0.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -170
- sqlspec-0.12.1.dist-info/RECORD +0 -145
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/driver/_common.py
CHANGED
|
@@ -9,9 +9,10 @@ import sqlglot
|
|
|
9
9
|
from sqlglot import exp
|
|
10
10
|
from sqlglot.tokens import TokenType
|
|
11
11
|
|
|
12
|
+
from sqlspec.driver.parameters import normalize_parameter_sequence
|
|
12
13
|
from sqlspec.exceptions import NotFoundError
|
|
13
14
|
from sqlspec.statement import SQLConfig
|
|
14
|
-
from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
|
|
15
|
+
from sqlspec.statement.parameters import ParameterStyle, ParameterValidator, TypedParameter
|
|
15
16
|
from sqlspec.statement.splitter import split_sql_script
|
|
16
17
|
from sqlspec.typing import ConnectionT, DictRow, RowT, T
|
|
17
18
|
from sqlspec.utils.logging import get_logger
|
|
@@ -84,7 +85,6 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
|
|
|
84
85
|
return self.returns_rows(expression.expressions[-1])
|
|
85
86
|
if isinstance(expression, (exp.Insert, exp.Update, exp.Delete)):
|
|
86
87
|
return bool(expression.find(exp.Returning))
|
|
87
|
-
# Handle Anonymous expressions (failed to parse) using a robust approach
|
|
88
88
|
if isinstance(expression, exp.Anonymous):
|
|
89
89
|
return self._check_anonymous_returns_rows(expression)
|
|
90
90
|
return False
|
|
@@ -113,13 +113,11 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
|
|
|
113
113
|
|
|
114
114
|
# Approach 1: Try to re-parse with placeholders replaced
|
|
115
115
|
try:
|
|
116
|
-
# Replace placeholders with a dummy literal that sqlglot can parse
|
|
117
116
|
sanitized_sql = placeholder_regex.sub("1", sql_text)
|
|
118
117
|
|
|
119
118
|
# If we replaced any placeholders, try parsing again
|
|
120
119
|
if sanitized_sql != sql_text:
|
|
121
120
|
parsed = sqlglot.parse_one(sanitized_sql, read=None)
|
|
122
|
-
# Check if it's a query type that returns rows
|
|
123
121
|
if isinstance(
|
|
124
122
|
parsed, (exp.Select, exp.Values, exp.Table, exp.Show, exp.Describe, exp.Pragma, exp.Command)
|
|
125
123
|
):
|
|
@@ -193,15 +191,12 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
|
|
|
193
191
|
if parameters is None:
|
|
194
192
|
return None
|
|
195
193
|
|
|
196
|
-
# Extract parameter info from the SQL
|
|
197
194
|
validator = ParameterValidator()
|
|
198
195
|
param_info_list = validator.extract_parameters(sql)
|
|
199
196
|
|
|
200
197
|
if not param_info_list:
|
|
201
|
-
# No parameters in SQL, return None
|
|
202
198
|
return None
|
|
203
199
|
|
|
204
|
-
# Determine the target style from the SQL if not provided
|
|
205
200
|
if target_style is None:
|
|
206
201
|
target_style = self.default_parameter_style
|
|
207
202
|
|
|
@@ -220,7 +215,6 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
|
|
|
220
215
|
ParameterStyle.NAMED_PYFORMAT,
|
|
221
216
|
}
|
|
222
217
|
|
|
223
|
-
# Check if parameters are already in the correct format
|
|
224
218
|
params_are_dict = isinstance(parameters, (dict, Mapping))
|
|
225
219
|
params_are_sequence = isinstance(parameters, (list, tuple, Sequence)) and not isinstance(
|
|
226
220
|
parameters, (str, bytes)
|
|
@@ -229,7 +223,6 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
|
|
|
229
223
|
# Single scalar parameter
|
|
230
224
|
if len(param_info_list) == 1 and not params_are_dict and not params_are_sequence:
|
|
231
225
|
if driver_expects_dict:
|
|
232
|
-
# Convert scalar to dict
|
|
233
226
|
param_info = param_info_list[0]
|
|
234
227
|
if param_info.name:
|
|
235
228
|
return {param_info.name: parameters}
|
|
@@ -242,7 +235,6 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
|
|
|
242
235
|
):
|
|
243
236
|
# If all parameters are numeric but named, convert to dict
|
|
244
237
|
# SQL has numeric placeholders but params might have named keys
|
|
245
|
-
# Only convert if keys don't match
|
|
246
238
|
numeric_keys_expected = {p.name for p in param_info_list if p.name}
|
|
247
239
|
if not numeric_keys_expected.issubset(parameters.keys()):
|
|
248
240
|
# Need to convert named keys to numeric positions
|
|
@@ -255,7 +247,6 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
|
|
|
255
247
|
|
|
256
248
|
# Special case: Auto-generated param_N style when SQL expects specific names
|
|
257
249
|
if all(key.startswith("param_") and key[6:].isdigit() for key in parameters):
|
|
258
|
-
# Check if SQL has different parameter names
|
|
259
250
|
sql_param_names = {p.name for p in param_info_list if p.name}
|
|
260
251
|
if sql_param_names and not any(name.startswith("param_") for name in sql_param_names):
|
|
261
252
|
# SQL has specific names, not param_N style - don't use these params as-is
|
|
@@ -263,7 +254,6 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
|
|
|
263
254
|
# For now, pass through and let validation catch it
|
|
264
255
|
pass
|
|
265
256
|
|
|
266
|
-
# Otherwise, dict format matches - return as-is
|
|
267
257
|
return parameters
|
|
268
258
|
|
|
269
259
|
if not driver_expects_dict and params_are_sequence:
|
|
@@ -272,11 +262,9 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
|
|
|
272
262
|
|
|
273
263
|
# Formats don't match - need conversion
|
|
274
264
|
if driver_expects_dict and params_are_sequence:
|
|
275
|
-
# Convert positional to dict
|
|
276
265
|
dict_result: dict[str, Any] = {}
|
|
277
266
|
for i, (param_info, value) in enumerate(zip(param_info_list, parameters)):
|
|
278
267
|
if param_info.name:
|
|
279
|
-
# Use the name from SQL
|
|
280
268
|
if param_info.style == ParameterStyle.POSITIONAL_COLON and param_info.name.isdigit():
|
|
281
269
|
# Oracle uses string keys even for numeric placeholders
|
|
282
270
|
dict_result[param_info.name] = value
|
|
@@ -288,10 +276,8 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
|
|
|
288
276
|
return dict_result
|
|
289
277
|
|
|
290
278
|
if not driver_expects_dict and params_are_dict:
|
|
291
|
-
# Convert dict to positional
|
|
292
279
|
# First check if it's already in param_N format
|
|
293
280
|
if all(key.startswith("param_") and key[6:].isdigit() for key in parameters):
|
|
294
|
-
# Extract values in order
|
|
295
281
|
positional_result: list[Any] = []
|
|
296
282
|
for i in range(len(param_info_list)):
|
|
297
283
|
key = f"param_{i}"
|
|
@@ -299,7 +285,6 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
|
|
|
299
285
|
positional_result.append(parameters[key])
|
|
300
286
|
return positional_result
|
|
301
287
|
|
|
302
|
-
# Convert named dict to positional based on parameter order in SQL
|
|
303
288
|
positional_params: list[Any] = []
|
|
304
289
|
for param_info in param_info_list:
|
|
305
290
|
if param_info.name and param_info.name in parameters:
|
|
@@ -336,3 +321,53 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
|
|
|
336
321
|
"""
|
|
337
322
|
# The split_sql_script function already handles dialect mapping and fallback
|
|
338
323
|
return split_sql_script(script, dialect=str(self.dialect), strip_trailing_semicolon=strip_trailing_semicolon)
|
|
324
|
+
|
|
325
|
+
def _prepare_driver_parameters(self, parameters: Any) -> Any:
    """Flatten ``parameters`` into a list the database driver can consume.

    ``parameters`` is first passed through ``normalize_parameter_sequence``.
    A falsy result (for example ``None`` or an empty sequence) produces ``[]``.
    Otherwise, each top-level element that is a ``TypedParameter`` is replaced
    by ``self._coerce_parameter(element)``, and every other element is kept
    as-is.

    Args:
        parameters: Parameters in any format (dict, list, tuple, scalar, TypedParameter)

    Returns:
        A list of parameter values with top-level TypedParameter objects unwrapped.
    """

    # NOTE(review): only the top-level elements of the normalized sequence are
    # unwrapped. If normalization wraps a mapping as a single element, any
    # TypedParameter values inside it pass through untouched. Confirm against
    # normalize_parameter_sequence.
    def unwrap(item: Any) -> Any:
        if isinstance(item, TypedParameter):
            return self._coerce_parameter(item)
        return item

    normalized = normalize_parameter_sequence(parameters)
    return list(map(unwrap, normalized)) if normalized else []
|
|
343
|
+
|
|
344
|
+
def _prepare_driver_parameters_many(self, parameters: Any) -> "list[Any]":
    """Prepare every parameter set of an ``executemany`` batch.

    Each element of ``parameters`` is passed individually to
    ``self._prepare_driver_parameters``, and the results are collected in
    order. A falsy ``parameters`` value (``None`` or an empty sequence) yields
    an empty list.

    Args:
        parameters: Sequence of parameter sets for executemany

    Returns:
        List of prepared parameter sets, one per input set.
    """
    if parameters:
        return list(map(self._prepare_driver_parameters, parameters))
    return []
|
|
359
|
+
|
|
360
|
+
def _coerce_parameter(self, param: "TypedParameter") -> Any:
    """Convert a ``TypedParameter`` into the plain value handed to the driver.

    This base implementation performs no conversion; it only unwraps
    ``param.value``. Individual drivers can override it to apply
    driver-specific coercion based on the type information that
    ``TypedParameter`` carries.

    Args:
        param: TypedParameter object with value and type information

    Returns:
        The wrapped value, unchanged.
    """
    unwrapped = param.value
    return unwrapped
|
sqlspec/driver/_sync.py
CHANGED
|
@@ -1,21 +1,22 @@
|
|
|
1
1
|
"""Synchronous driver protocol implementation."""
|
|
2
2
|
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
|
-
from
|
|
4
|
+
from dataclasses import replace
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Optional, Union, overload
|
|
5
6
|
|
|
6
7
|
from sqlspec.driver._common import CommonDriverAttributesMixin
|
|
7
|
-
from sqlspec.
|
|
8
|
-
from sqlspec.statement.
|
|
8
|
+
from sqlspec.driver.parameters import process_execute_many_parameters
|
|
9
|
+
from sqlspec.statement.builder import Delete, Insert, QueryBuilder, Select, Update
|
|
9
10
|
from sqlspec.statement.result import SQLResult
|
|
10
11
|
from sqlspec.statement.sql import SQL, SQLConfig, Statement
|
|
11
12
|
from sqlspec.typing import ConnectionT, DictRow, ModelDTOT, RowT, StatementParameters
|
|
12
13
|
from sqlspec.utils.logging import get_logger
|
|
13
|
-
|
|
14
|
-
logger = get_logger("sqlspec")
|
|
15
|
-
|
|
14
|
+
from sqlspec.utils.type_guards import can_convert_to_schema
|
|
16
15
|
|
|
17
16
|
if TYPE_CHECKING:
|
|
18
|
-
from sqlspec.statement.
|
|
17
|
+
from sqlspec.statement.filters import StatementFilter
|
|
18
|
+
|
|
19
|
+
logger = get_logger("sqlspec")
|
|
19
20
|
|
|
20
21
|
__all__ = ("SyncDriverAdapterProtocol",)
|
|
21
22
|
|
|
@@ -39,7 +40,6 @@ class SyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
|
|
|
39
40
|
config: SQL statement configuration
|
|
40
41
|
default_row_type: Default row type for results (DictRow, TupleRow, etc.)
|
|
41
42
|
"""
|
|
42
|
-
# Initialize CommonDriverAttributes part
|
|
43
43
|
super().__init__(connection=connection, config=config, default_row_type=default_row_type)
|
|
44
44
|
|
|
45
45
|
def _build_statement(
|
|
@@ -57,41 +57,61 @@ class SyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
|
|
|
57
57
|
# If statement is already a SQL object, handle additional parameters
|
|
58
58
|
if isinstance(statement, SQL):
|
|
59
59
|
if parameters or kwargs:
|
|
60
|
-
|
|
61
|
-
|
|
60
|
+
new_config = _config
|
|
61
|
+
if self.dialect and not new_config.dialect:
|
|
62
|
+
new_config = replace(new_config, dialect=self.dialect)
|
|
63
|
+
# Use raw SQL if available to ensure proper parsing with dialect
|
|
64
|
+
sql_source = statement._raw_sql or statement._statement
|
|
65
|
+
# Preserve filters and state when creating new SQL object
|
|
66
|
+
existing_state = {
|
|
67
|
+
"is_many": statement._is_many,
|
|
68
|
+
"is_script": statement._is_script,
|
|
69
|
+
"original_parameters": statement._original_parameters,
|
|
70
|
+
"filters": statement._filters,
|
|
71
|
+
"positional_params": statement._positional_params,
|
|
72
|
+
"named_params": statement._named_params,
|
|
73
|
+
}
|
|
74
|
+
return SQL(sql_source, *parameters, config=new_config, _existing_state=existing_state, **kwargs)
|
|
75
|
+
# Even without additional parameters, ensure dialect is set
|
|
76
|
+
if self.dialect and (not statement._config.dialect or statement._config.dialect != self.dialect):
|
|
77
|
+
new_config = replace(statement._config, dialect=self.dialect)
|
|
78
|
+
# Use raw SQL if available to ensure proper parsing with dialect
|
|
79
|
+
sql_source = statement._raw_sql or statement._statement
|
|
80
|
+
# Preserve parameters and state when creating new SQL object
|
|
81
|
+
# Use the public parameters property which always has the right value
|
|
82
|
+
existing_state = {
|
|
83
|
+
"is_many": statement._is_many,
|
|
84
|
+
"is_script": statement._is_script,
|
|
85
|
+
"original_parameters": statement._original_parameters,
|
|
86
|
+
"filters": statement._filters,
|
|
87
|
+
"positional_params": statement._positional_params,
|
|
88
|
+
"named_params": statement._named_params,
|
|
89
|
+
}
|
|
90
|
+
if statement.parameters:
|
|
91
|
+
return SQL(
|
|
92
|
+
sql_source, parameters=statement.parameters, config=new_config, _existing_state=existing_state
|
|
93
|
+
)
|
|
94
|
+
return SQL(sql_source, config=new_config, _existing_state=existing_state)
|
|
62
95
|
return statement
|
|
63
|
-
|
|
96
|
+
new_config = _config
|
|
97
|
+
if self.dialect and not new_config.dialect:
|
|
98
|
+
new_config = replace(new_config, dialect=self.dialect)
|
|
99
|
+
return SQL(statement, *parameters, config=new_config, **kwargs)
|
|
64
100
|
|
|
65
101
|
@abstractmethod
|
|
66
102
|
def _execute_statement(
|
|
67
103
|
self, statement: "SQL", connection: "Optional[ConnectionT]" = None, **kwargs: Any
|
|
68
|
-
) -> "
|
|
104
|
+
) -> "SQLResult[RowT]":
|
|
69
105
|
"""Actual execution implementation by concrete drivers, using the raw connection.
|
|
70
106
|
|
|
71
|
-
Returns
|
|
107
|
+
Returns SQLResult directly based on the statement type.
|
|
72
108
|
"""
|
|
73
109
|
raise NotImplementedError
|
|
74
110
|
|
|
75
|
-
@abstractmethod
|
|
76
|
-
def _wrap_select_result(
|
|
77
|
-
self,
|
|
78
|
-
statement: "SQL",
|
|
79
|
-
result: "SelectResultDict",
|
|
80
|
-
schema_type: "Optional[type[ModelDTOT]]" = None,
|
|
81
|
-
**kwargs: Any,
|
|
82
|
-
) -> "Union[SQLResult[ModelDTOT], SQLResult[RowT]]":
|
|
83
|
-
raise NotImplementedError
|
|
84
|
-
|
|
85
|
-
@abstractmethod
|
|
86
|
-
def _wrap_execute_result(
|
|
87
|
-
self, statement: "SQL", result: "Union[DMLResultDict, ScriptResultDict]", **kwargs: Any
|
|
88
|
-
) -> "SQLResult[RowT]":
|
|
89
|
-
raise NotImplementedError
|
|
90
|
-
|
|
91
111
|
@overload
|
|
92
112
|
def execute(
|
|
93
113
|
self,
|
|
94
|
-
statement: "
|
|
114
|
+
statement: "Select",
|
|
95
115
|
/,
|
|
96
116
|
*parameters: "Union[StatementParameters, StatementFilter]",
|
|
97
117
|
schema_type: "type[ModelDTOT]",
|
|
@@ -103,7 +123,7 @@ class SyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
|
|
|
103
123
|
@overload
|
|
104
124
|
def execute(
|
|
105
125
|
self,
|
|
106
|
-
statement: "
|
|
126
|
+
statement: "Select",
|
|
107
127
|
/,
|
|
108
128
|
*parameters: "Union[StatementParameters, StatementFilter]",
|
|
109
129
|
schema_type: None = None,
|
|
@@ -115,7 +135,7 @@ class SyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
|
|
|
115
135
|
@overload
|
|
116
136
|
def execute(
|
|
117
137
|
self,
|
|
118
|
-
statement: "Union[
|
|
138
|
+
statement: "Union[Insert, Update, Delete]",
|
|
119
139
|
/,
|
|
120
140
|
*parameters: "Union[StatementParameters, StatementFilter]",
|
|
121
141
|
_connection: "Optional[ConnectionT]" = None,
|
|
@@ -126,7 +146,7 @@ class SyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
|
|
|
126
146
|
@overload
|
|
127
147
|
def execute(
|
|
128
148
|
self,
|
|
129
|
-
statement: "
|
|
149
|
+
statement: "Union[str, SQL]",
|
|
130
150
|
/,
|
|
131
151
|
*parameters: "Union[StatementParameters, StatementFilter]",
|
|
132
152
|
schema_type: "type[ModelDTOT]",
|
|
@@ -160,13 +180,21 @@ class SyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
|
|
|
160
180
|
sql_statement = self._build_statement(statement, *parameters, _config=_config or self.config, **kwargs)
|
|
161
181
|
result = self._execute_statement(statement=sql_statement, connection=self._connection(_connection), **kwargs)
|
|
162
182
|
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
183
|
+
# If schema_type is provided and we have data, convert it
|
|
184
|
+
if schema_type and result.data and can_convert_to_schema(self):
|
|
185
|
+
converted_data = list(self.to_schema(data=result.data, schema_type=schema_type))
|
|
186
|
+
return SQLResult[ModelDTOT](
|
|
187
|
+
statement=result.statement,
|
|
188
|
+
data=converted_data,
|
|
189
|
+
column_names=result.column_names,
|
|
190
|
+
rows_affected=result.rows_affected,
|
|
191
|
+
operation_type=result.operation_type,
|
|
192
|
+
last_inserted_id=result.last_inserted_id,
|
|
193
|
+
execution_time=result.execution_time,
|
|
194
|
+
metadata=result.metadata,
|
|
166
195
|
)
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
)
|
|
196
|
+
|
|
197
|
+
return result
|
|
170
198
|
|
|
171
199
|
def execute_many(
|
|
172
200
|
self,
|
|
@@ -177,37 +205,19 @@ class SyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
|
|
|
177
205
|
_config: "Optional[SQLConfig]" = None,
|
|
178
206
|
**kwargs: Any,
|
|
179
207
|
) -> "SQLResult[RowT]":
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
for
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
208
|
+
_filters, param_sequence = process_execute_many_parameters(parameters)
|
|
209
|
+
|
|
210
|
+
# For execute_many, disable transformations to prevent literal extraction
|
|
211
|
+
# since the SQL already has placeholders for bulk operations
|
|
212
|
+
many_config = _config or self.config
|
|
213
|
+
if many_config.enable_transformations:
|
|
214
|
+
from dataclasses import replace
|
|
215
|
+
|
|
216
|
+
many_config = replace(many_config, enable_transformations=False)
|
|
188
217
|
|
|
189
|
-
|
|
190
|
-
param_sequence = param_sequences[0] if param_sequences else None
|
|
191
|
-
# Convert tuple to list if needed
|
|
192
|
-
if isinstance(param_sequence, tuple):
|
|
193
|
-
param_sequence = list(param_sequence)
|
|
194
|
-
# Ensure param_sequence is a list or None
|
|
195
|
-
if param_sequence is not None and not isinstance(param_sequence, list):
|
|
196
|
-
param_sequence = list(param_sequence) if hasattr(param_sequence, "__iter__") else None
|
|
197
|
-
sql_statement = self._build_statement(statement, _config=_config or self.config, **kwargs).as_many(
|
|
198
|
-
param_sequence
|
|
199
|
-
)
|
|
218
|
+
sql_statement = self._build_statement(statement, _config=many_config, **kwargs).as_many(param_sequence)
|
|
200
219
|
|
|
201
|
-
|
|
202
|
-
statement=sql_statement,
|
|
203
|
-
connection=self._connection(_connection),
|
|
204
|
-
parameters=param_sequence,
|
|
205
|
-
is_many=True,
|
|
206
|
-
**kwargs,
|
|
207
|
-
)
|
|
208
|
-
return self._wrap_execute_result(
|
|
209
|
-
sql_statement, cast("Union[DMLResultDict, ScriptResultDict]", result), **kwargs
|
|
210
|
-
)
|
|
220
|
+
return self._execute_statement(statement=sql_statement, connection=self._connection(_connection), **kwargs)
|
|
211
221
|
|
|
212
222
|
def execute_script(
|
|
213
223
|
self,
|
|
@@ -218,44 +228,10 @@ class SyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
|
|
|
218
228
|
_config: "Optional[SQLConfig]" = None,
|
|
219
229
|
**kwargs: Any,
|
|
220
230
|
) -> "SQLResult[RowT]":
|
|
221
|
-
# Separate parameters from filters
|
|
222
|
-
param_values = []
|
|
223
|
-
filters = []
|
|
224
|
-
for param in parameters:
|
|
225
|
-
if isinstance(param, StatementFilter):
|
|
226
|
-
filters.append(param)
|
|
227
|
-
else:
|
|
228
|
-
param_values.append(param)
|
|
229
|
-
|
|
230
|
-
# Use first parameter as the primary parameter value, or None if no parameters
|
|
231
|
-
primary_params = param_values[0] if param_values else None
|
|
232
|
-
|
|
233
231
|
script_config = _config or self.config
|
|
234
232
|
if script_config.enable_validation:
|
|
235
|
-
script_config =
|
|
236
|
-
enable_parsing=script_config.enable_parsing,
|
|
237
|
-
enable_validation=False,
|
|
238
|
-
enable_transformations=script_config.enable_transformations,
|
|
239
|
-
enable_analysis=script_config.enable_analysis,
|
|
240
|
-
strict_mode=False,
|
|
241
|
-
cache_parsed_expression=script_config.cache_parsed_expression,
|
|
242
|
-
parameter_converter=script_config.parameter_converter,
|
|
243
|
-
parameter_validator=script_config.parameter_validator,
|
|
244
|
-
analysis_cache_size=script_config.analysis_cache_size,
|
|
245
|
-
allowed_parameter_styles=script_config.allowed_parameter_styles,
|
|
246
|
-
target_parameter_style=script_config.target_parameter_style,
|
|
247
|
-
allow_mixed_parameter_styles=script_config.allow_mixed_parameter_styles,
|
|
248
|
-
)
|
|
233
|
+
script_config = replace(script_config, enable_validation=False, strict_mode=False)
|
|
249
234
|
|
|
250
|
-
sql_statement =
|
|
235
|
+
sql_statement = self._build_statement(statement, *parameters, _config=script_config, **kwargs)
|
|
251
236
|
sql_statement = sql_statement.as_script()
|
|
252
|
-
|
|
253
|
-
statement=sql_statement, connection=self._connection(_connection), is_script=True, **kwargs
|
|
254
|
-
)
|
|
255
|
-
if isinstance(script_output, str):
|
|
256
|
-
result = SQLResult[RowT](statement=sql_statement, data=[], operation_type="SCRIPT")
|
|
257
|
-
result.total_statements = 1
|
|
258
|
-
result.successful_statements = 1
|
|
259
|
-
return result
|
|
260
|
-
# Wrap the ScriptResultDict using the driver's wrapper
|
|
261
|
-
return self._wrap_execute_result(sql_statement, cast("ScriptResultDict", script_output), **kwargs)
|
|
237
|
+
return self._execute_statement(statement=sql_statement, connection=self._connection(_connection), **kwargs)
|
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
"""Consolidated connection management utilities for database drivers.
|
|
2
|
+
|
|
3
|
+
This module provides centralized connection handling to avoid duplication
|
|
4
|
+
across database adapter implementations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from contextlib import asynccontextmanager, contextmanager
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from collections.abc import AsyncIterator, Iterator
|
|
12
|
+
|
|
13
|
+
from sqlspec.utils.type_guards import is_async_transaction_capable, is_sync_transaction_capable
|
|
14
|
+
|
|
15
|
+
# Public helpers exported from this module.
__all__ = (
    "get_connection_info",
    "managed_connection_async",
    "managed_connection_sync",
    "managed_transaction_async",
    "managed_transaction_sync",
    "validate_pool_config",
)


# Type variable for the connection objects passed through these helpers.
ConnectionT = TypeVar("ConnectionT")
# Type variable for connection-pool objects.
PoolT = TypeVar("PoolT")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@contextmanager
def managed_connection_sync(config: Any, provided_connection: Optional[ConnectionT] = None) -> "Iterator[ConnectionT]":
    """Yield a database connection, obtaining one from ``config`` only when needed.

    Args:
        config: Database configuration with provide_connection method
        provided_connection: Optional existing connection to use

    Yields:
        ``provided_connection`` when it is not None; otherwise the connection
        produced by ``config.provide_connection()``.
    """
    if provided_connection is None:
        # No caller-supplied connection: defer to the config's own
        # provide_connection() context manager for acquisition and cleanup.
        with config.provide_connection() as acquired:
            yield acquired
    else:
        # Pass the caller's connection through; nothing here closes it.
        yield provided_connection
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@contextmanager
def managed_transaction_sync(connection: ConnectionT, auto_commit: bool = True) -> "Iterator[ConnectionT]":
    """Context manager that commits on success and rolls back on failure.

    Transaction handling is skipped entirely, and the connection is yielded
    untouched, in three cases:

    - ``auto_commit`` is False;
    - ``is_sync_transaction_capable(connection)`` is False;
    - the connection's ``autocommit`` attribute is exactly ``True``.

    Args:
        connection: Database connection
        auto_commit: Whether to commit on success (and roll back on error)

    Yields:
        The same database connection.

    Raises:
        Exception: Any error raised inside the ``with`` body or by ``commit()``
            is re-raised after a rollback attempt. Errors raised by
            ``rollback()`` itself are re-raised unless their message indicates
            that no transaction was active.
    """
    # Only a genuine boolean True means "autocommit is on". Some drivers expose
    # ``autocommit`` as a method (e.g. PyMySQL). On Python 3.12+, sqlite3
    # defaults it to the truthy sentinel ``LEGACY_TRANSACTION_CONTROL`` (-1).
    # A plain truthiness test treats both as autocommit and silently skips
    # the commit.
    has_autocommit = getattr(connection, "autocommit", False) is True

    if not auto_commit or not is_sync_transaction_capable(connection) or has_autocommit:
        yield connection
        return

    try:
        yield connection
        cast("Any", connection).commit()
    except Exception:
        try:
            cast("Any", connection).rollback()
        except Exception as rollback_error:
            # Some databases (like DuckDB) raise if rollback is called while no
            # transaction is active; tolerate only that case.
            error_msg = str(rollback_error).lower()
            if "no transaction" not in error_msg and "transaction context error" not in error_msg:
                raise
        # Re-raise the original error from the body or from commit().
        raise
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@asynccontextmanager
async def managed_connection_async(
    config: Any, provided_connection: Optional[ConnectionT] = None
) -> "AsyncIterator[ConnectionT]":
    """Async variant: yield a connection, obtaining one from ``config`` only when needed.

    Args:
        config: Database configuration with provide_connection method
        provided_connection: Optional existing connection to use

    Yields:
        ``provided_connection`` when it is not None; otherwise the connection
        produced by ``config.provide_connection()`` (used as an async context
        manager).
    """
    if provided_connection is None:
        # No caller-supplied connection: defer to the config's async
        # provide_connection() context manager for acquisition and cleanup.
        async with config.provide_connection() as acquired:
            yield acquired
    else:
        # Pass the caller's connection through; nothing here closes it.
        yield provided_connection
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@asynccontextmanager
async def managed_transaction_async(connection: ConnectionT, auto_commit: bool = True) -> "AsyncIterator[ConnectionT]":
    """Async context manager that commits on success and rolls back on failure.

    Transaction handling is skipped entirely, and the connection is yielded
    untouched, in three cases:

    - ``auto_commit`` is False;
    - ``is_async_transaction_capable(connection)`` is False;
    - the connection's ``autocommit`` attribute is exactly ``True``.

    Args:
        connection: Database connection
        auto_commit: Whether to commit on success (and roll back on error)

    Yields:
        The same database connection.

    Raises:
        Exception: Any error raised inside the ``async with`` body or by the
            awaited ``commit()`` is re-raised after a rollback attempt. Errors
            raised by ``rollback()`` itself are re-raised unless their message
            indicates that no transaction was active.
    """
    # Only a genuine boolean True means "autocommit is on". Some drivers expose
    # ``autocommit`` as a (coroutine) method rather than a flag. A plain
    # truthiness test treats a bound method, or any other truthy non-bool,
    # as autocommit and silently skips the commit.
    has_autocommit = getattr(connection, "autocommit", False) is True

    if not auto_commit or not is_async_transaction_capable(connection) or has_autocommit:
        yield connection
        return

    try:
        yield connection
        await cast("Any", connection).commit()
    except Exception:
        try:
            await cast("Any", connection).rollback()
        except Exception as rollback_error:
            # Some databases (like DuckDB) raise if rollback is called while no
            # transaction is active; tolerate only that case.
            error_msg = str(rollback_error).lower()
            if "no transaction" not in error_msg and "transaction context error" not in error_msg:
                raise
        # Re-raise the original error from the body or from commit().
        raise
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def get_connection_info(connection: Any) -> dict[str, Any]:
    """Summarise a connection object for logging and debugging.

    The result always contains ``"type"`` and ``"module"``: the name and
    defining module of the connection's class. Two further entries are added
    when available, and omitted when every candidate is missing or None:

    - ``"database"``: the first non-None attribute among ``database``,
      ``dbname``, ``db`` and ``catalog``;
    - ``"host"``: the first non-None attribute among ``host``, ``hostname``
      and ``server``.

    Args:
        connection: Database connection object

    Returns:
        Dictionary of connection information
    """
    connection_cls = type(connection)
    info: dict[str, Any] = {"type": connection_cls.__name__, "module": connection_cls.__module__}

    lookups = (
        ("database", ("database", "dbname", "db", "catalog")),
        ("host", ("host", "hostname", "server")),
    )
    for info_key, attr_names in lookups:
        # Lazily probe attributes in priority order, stopping at the first hit.
        candidates = (getattr(connection, name, None) for name in attr_names)
        first_found = next((value for value in candidates if value is not None), None)
        if first_found is not None:
            info[info_key] = first_found

    return info
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def validate_pool_config(
    min_size: int, max_size: int, max_idle_time: Optional[int] = None, max_lifetime: Optional[int] = None
) -> None:
    """Check connection-pool sizing and timing settings, raising on the first violation.

    The checks run in this order:

    1. ``min_size >= 0``
    2. ``max_size >= 1``
    3. ``min_size <= max_size``
    4. ``max_idle_time >= 0`` (skipped when None)
    5. ``max_lifetime >= 0`` (skipped when None)

    Args:
        min_size: Minimum pool size
        max_size: Maximum pool size
        max_idle_time: Maximum idle time in seconds
        max_lifetime: Maximum connection lifetime in seconds

    Raises:
        ValueError: If configuration is invalid
    """
    msg: Optional[str] = None
    if min_size < 0:
        msg = f"min_size must be >= 0, got {min_size}"
    elif max_size < 1:
        msg = f"max_size must be >= 1, got {max_size}"
    elif min_size > max_size:
        msg = f"min_size ({min_size}) cannot be greater than max_size ({max_size})"
    elif max_idle_time is not None and max_idle_time < 0:
        msg = f"max_idle_time must be >= 0, got {max_idle_time}"
    elif max_lifetime is not None and max_lifetime < 0:
        msg = f"max_lifetime must be >= 0, got {max_lifetime}"

    if msg is not None:
        raise ValueError(msg)
|