sqlspec 0.12.2__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. See the registry's advisory page for this release for more details.

Files changed (113)
  1. sqlspec/_sql.py +21 -180
  2. sqlspec/adapters/adbc/config.py +10 -12
  3. sqlspec/adapters/adbc/driver.py +120 -118
  4. sqlspec/adapters/aiosqlite/config.py +3 -3
  5. sqlspec/adapters/aiosqlite/driver.py +100 -130
  6. sqlspec/adapters/asyncmy/config.py +3 -4
  7. sqlspec/adapters/asyncmy/driver.py +123 -135
  8. sqlspec/adapters/asyncpg/config.py +3 -7
  9. sqlspec/adapters/asyncpg/driver.py +98 -140
  10. sqlspec/adapters/bigquery/config.py +4 -5
  11. sqlspec/adapters/bigquery/driver.py +125 -167
  12. sqlspec/adapters/duckdb/config.py +3 -6
  13. sqlspec/adapters/duckdb/driver.py +114 -111
  14. sqlspec/adapters/oracledb/config.py +6 -5
  15. sqlspec/adapters/oracledb/driver.py +242 -259
  16. sqlspec/adapters/psqlpy/config.py +3 -7
  17. sqlspec/adapters/psqlpy/driver.py +118 -93
  18. sqlspec/adapters/psycopg/config.py +18 -31
  19. sqlspec/adapters/psycopg/driver.py +283 -236
  20. sqlspec/adapters/sqlite/config.py +3 -3
  21. sqlspec/adapters/sqlite/driver.py +103 -97
  22. sqlspec/config.py +0 -4
  23. sqlspec/driver/_async.py +89 -98
  24. sqlspec/driver/_common.py +52 -17
  25. sqlspec/driver/_sync.py +81 -105
  26. sqlspec/driver/connection.py +207 -0
  27. sqlspec/driver/mixins/_csv_writer.py +91 -0
  28. sqlspec/driver/mixins/_pipeline.py +38 -49
  29. sqlspec/driver/mixins/_result_utils.py +27 -9
  30. sqlspec/driver/mixins/_storage.py +67 -181
  31. sqlspec/driver/mixins/_type_coercion.py +3 -4
  32. sqlspec/driver/parameters.py +138 -0
  33. sqlspec/exceptions.py +10 -2
  34. sqlspec/extensions/aiosql/adapter.py +0 -10
  35. sqlspec/extensions/litestar/handlers.py +0 -1
  36. sqlspec/extensions/litestar/plugin.py +0 -3
  37. sqlspec/extensions/litestar/providers.py +0 -14
  38. sqlspec/loader.py +25 -90
  39. sqlspec/protocols.py +542 -0
  40. sqlspec/service/__init__.py +3 -2
  41. sqlspec/service/_util.py +147 -0
  42. sqlspec/service/base.py +1116 -9
  43. sqlspec/statement/builder/__init__.py +42 -32
  44. sqlspec/statement/builder/_ddl_utils.py +0 -10
  45. sqlspec/statement/builder/_parsing_utils.py +10 -4
  46. sqlspec/statement/builder/base.py +67 -22
  47. sqlspec/statement/builder/column.py +283 -0
  48. sqlspec/statement/builder/ddl.py +91 -67
  49. sqlspec/statement/builder/delete.py +23 -7
  50. sqlspec/statement/builder/insert.py +29 -15
  51. sqlspec/statement/builder/merge.py +4 -4
  52. sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
  53. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
  54. sqlspec/statement/builder/mixins/_delete_from.py +1 -1
  55. sqlspec/statement/builder/mixins/_from.py +10 -8
  56. sqlspec/statement/builder/mixins/_group_by.py +0 -1
  57. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
  58. sqlspec/statement/builder/mixins/_insert_values.py +0 -2
  59. sqlspec/statement/builder/mixins/_join.py +20 -13
  60. sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
  61. sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
  62. sqlspec/statement/builder/mixins/_order_by.py +2 -2
  63. sqlspec/statement/builder/mixins/_pivot.py +4 -7
  64. sqlspec/statement/builder/mixins/_select_columns.py +6 -5
  65. sqlspec/statement/builder/mixins/_unpivot.py +6 -9
  66. sqlspec/statement/builder/mixins/_update_from.py +2 -1
  67. sqlspec/statement/builder/mixins/_update_set.py +11 -8
  68. sqlspec/statement/builder/mixins/_where.py +61 -34
  69. sqlspec/statement/builder/select.py +32 -17
  70. sqlspec/statement/builder/update.py +25 -11
  71. sqlspec/statement/filters.py +39 -14
  72. sqlspec/statement/parameter_manager.py +220 -0
  73. sqlspec/statement/parameters.py +210 -79
  74. sqlspec/statement/pipelines/__init__.py +166 -23
  75. sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
  76. sqlspec/statement/pipelines/context.py +35 -39
  77. sqlspec/statement/pipelines/transformers/__init__.py +2 -3
  78. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
  79. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
  80. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
  81. sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
  82. sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
  83. sqlspec/statement/pipelines/validators/_performance.py +38 -23
  84. sqlspec/statement/pipelines/validators/_security.py +39 -62
  85. sqlspec/statement/result.py +37 -129
  86. sqlspec/statement/splitter.py +0 -12
  87. sqlspec/statement/sql.py +863 -391
  88. sqlspec/statement/sql_compiler.py +140 -0
  89. sqlspec/storage/__init__.py +10 -2
  90. sqlspec/storage/backends/fsspec.py +53 -8
  91. sqlspec/storage/backends/obstore.py +15 -19
  92. sqlspec/storage/capabilities.py +101 -0
  93. sqlspec/storage/registry.py +56 -83
  94. sqlspec/typing.py +6 -434
  95. sqlspec/utils/cached_property.py +25 -0
  96. sqlspec/utils/correlation.py +0 -2
  97. sqlspec/utils/logging.py +0 -6
  98. sqlspec/utils/sync_tools.py +0 -4
  99. sqlspec/utils/text.py +0 -5
  100. sqlspec/utils/type_guards.py +892 -0
  101. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
  102. sqlspec-0.13.0.dist-info/RECORD +150 -0
  103. sqlspec/statement/builder/protocols.py +0 -20
  104. sqlspec/statement/pipelines/base.py +0 -315
  105. sqlspec/statement/pipelines/result_types.py +0 -41
  106. sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
  107. sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
  108. sqlspec/statement/pipelines/validators/base.py +0 -67
  109. sqlspec/storage/protocol.py +0 -173
  110. sqlspec-0.12.2.dist-info/RECORD +0 -145
  111. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
  112. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
  113. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/driver/_common.py CHANGED
@@ -9,9 +9,10 @@ import sqlglot
9
9
  from sqlglot import exp
10
10
  from sqlglot.tokens import TokenType
11
11
 
12
+ from sqlspec.driver.parameters import normalize_parameter_sequence
12
13
  from sqlspec.exceptions import NotFoundError
13
14
  from sqlspec.statement import SQLConfig
14
- from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
15
+ from sqlspec.statement.parameters import ParameterStyle, ParameterValidator, TypedParameter
15
16
  from sqlspec.statement.splitter import split_sql_script
16
17
  from sqlspec.typing import ConnectionT, DictRow, RowT, T
17
18
  from sqlspec.utils.logging import get_logger
@@ -84,7 +85,6 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
84
85
  return self.returns_rows(expression.expressions[-1])
85
86
  if isinstance(expression, (exp.Insert, exp.Update, exp.Delete)):
86
87
  return bool(expression.find(exp.Returning))
87
- # Handle Anonymous expressions (failed to parse) using a robust approach
88
88
  if isinstance(expression, exp.Anonymous):
89
89
  return self._check_anonymous_returns_rows(expression)
90
90
  return False
@@ -113,13 +113,11 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
113
113
 
114
114
  # Approach 1: Try to re-parse with placeholders replaced
115
115
  try:
116
- # Replace placeholders with a dummy literal that sqlglot can parse
117
116
  sanitized_sql = placeholder_regex.sub("1", sql_text)
118
117
 
119
118
  # If we replaced any placeholders, try parsing again
120
119
  if sanitized_sql != sql_text:
121
120
  parsed = sqlglot.parse_one(sanitized_sql, read=None)
122
- # Check if it's a query type that returns rows
123
121
  if isinstance(
124
122
  parsed, (exp.Select, exp.Values, exp.Table, exp.Show, exp.Describe, exp.Pragma, exp.Command)
125
123
  ):
@@ -193,15 +191,12 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
193
191
  if parameters is None:
194
192
  return None
195
193
 
196
- # Extract parameter info from the SQL
197
194
  validator = ParameterValidator()
198
195
  param_info_list = validator.extract_parameters(sql)
199
196
 
200
197
  if not param_info_list:
201
- # No parameters in SQL, return None
202
198
  return None
203
199
 
204
- # Determine the target style from the SQL if not provided
205
200
  if target_style is None:
206
201
  target_style = self.default_parameter_style
207
202
 
@@ -220,7 +215,6 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
220
215
  ParameterStyle.NAMED_PYFORMAT,
221
216
  }
222
217
 
223
- # Check if parameters are already in the correct format
224
218
  params_are_dict = isinstance(parameters, (dict, Mapping))
225
219
  params_are_sequence = isinstance(parameters, (list, tuple, Sequence)) and not isinstance(
226
220
  parameters, (str, bytes)
@@ -229,7 +223,6 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
229
223
  # Single scalar parameter
230
224
  if len(param_info_list) == 1 and not params_are_dict and not params_are_sequence:
231
225
  if driver_expects_dict:
232
- # Convert scalar to dict
233
226
  param_info = param_info_list[0]
234
227
  if param_info.name:
235
228
  return {param_info.name: parameters}
@@ -242,7 +235,6 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
242
235
  ):
243
236
  # If all parameters are numeric but named, convert to dict
244
237
  # SQL has numeric placeholders but params might have named keys
245
- # Only convert if keys don't match
246
238
  numeric_keys_expected = {p.name for p in param_info_list if p.name}
247
239
  if not numeric_keys_expected.issubset(parameters.keys()):
248
240
  # Need to convert named keys to numeric positions
@@ -255,7 +247,6 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
255
247
 
256
248
  # Special case: Auto-generated param_N style when SQL expects specific names
257
249
  if all(key.startswith("param_") and key[6:].isdigit() for key in parameters):
258
- # Check if SQL has different parameter names
259
250
  sql_param_names = {p.name for p in param_info_list if p.name}
260
251
  if sql_param_names and not any(name.startswith("param_") for name in sql_param_names):
261
252
  # SQL has specific names, not param_N style - don't use these params as-is
@@ -263,7 +254,6 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
263
254
  # For now, pass through and let validation catch it
264
255
  pass
265
256
 
266
- # Otherwise, dict format matches - return as-is
267
257
  return parameters
268
258
 
269
259
  if not driver_expects_dict and params_are_sequence:
@@ -272,11 +262,9 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
272
262
 
273
263
  # Formats don't match - need conversion
274
264
  if driver_expects_dict and params_are_sequence:
275
- # Convert positional to dict
276
265
  dict_result: dict[str, Any] = {}
277
266
  for i, (param_info, value) in enumerate(zip(param_info_list, parameters)):
278
267
  if param_info.name:
279
- # Use the name from SQL
280
268
  if param_info.style == ParameterStyle.POSITIONAL_COLON and param_info.name.isdigit():
281
269
  # Oracle uses string keys even for numeric placeholders
282
270
  dict_result[param_info.name] = value
@@ -288,10 +276,8 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
288
276
  return dict_result
289
277
 
290
278
  if not driver_expects_dict and params_are_dict:
291
- # Convert dict to positional
292
279
  # First check if it's already in param_N format
293
280
  if all(key.startswith("param_") and key[6:].isdigit() for key in parameters):
294
- # Extract values in order
295
281
  positional_result: list[Any] = []
296
282
  for i in range(len(param_info_list)):
297
283
  key = f"param_{i}"
@@ -299,7 +285,6 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
299
285
  positional_result.append(parameters[key])
300
286
  return positional_result
301
287
 
302
- # Convert named dict to positional based on parameter order in SQL
303
288
  positional_params: list[Any] = []
304
289
  for param_info in param_info_list:
305
290
  if param_info.name and param_info.name in parameters:
@@ -336,3 +321,53 @@ class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
336
321
  """
337
322
  # The split_sql_script function already handles dialect mapping and fallback
338
323
  return split_sql_script(script, dialect=str(self.dialect), strip_trailing_semicolon=strip_trailing_semicolon)
324
+
325
def _prepare_driver_parameters(self, parameters: Any) -> Any:
    """Unwrap ``TypedParameter`` objects so the values can be handed to a database driver.

    The input is first normalized with ``normalize_parameter_sequence``. Each
    element of the normalized sequence that is a ``TypedParameter`` is passed
    through ``_coerce_parameter``; all other elements are kept unchanged.

    Args:
        parameters: Raw parameters in whatever form the caller supplied.

    Returns:
        A new list of driver-ready values. The list is empty when normalization
        produces a falsy value.
    """
    normalized = normalize_parameter_sequence(parameters)
    unwrapped: list[Any] = []
    # A falsy normalized value contributes nothing, yielding an empty list.
    for item in normalized or ():
        if isinstance(item, TypedParameter):
            unwrapped.append(self._coerce_parameter(item))
        else:
            unwrapped.append(item)
    return unwrapped
343
+
344
def _prepare_driver_parameters_many(self, parameters: Any) -> "list[Any]":
    """Prepare every parameter set of an executemany batch for the driver.

    Args:
        parameters: Iterable of parameter sets. A falsy value means there are no sets.

    Returns:
        One entry per input parameter set, each produced by
        ``_prepare_driver_parameters``. Returns an empty list for falsy input.
    """
    if parameters:
        return list(map(self._prepare_driver_parameters, parameters))
    return []
359
+
360
def _coerce_parameter(self, param: "TypedParameter") -> Any:
    """Return the raw value wrapped by a ``TypedParameter``.

    This base implementation ignores the attached type information and only
    exposes ``param.value``. Individual drivers can override it to apply
    driver-specific coercion based on that type information.

    Args:
        param: The typed parameter wrapper to unwrap.

    Returns:
        The wrapped value, unchanged.
    """
    raw_value = param.value
    return raw_value
sqlspec/driver/_sync.py CHANGED
@@ -1,21 +1,22 @@
1
1
  """Synchronous driver protocol implementation."""
2
2
 
3
3
  from abc import ABC, abstractmethod
4
- from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload
4
+ from dataclasses import replace
5
+ from typing import TYPE_CHECKING, Any, Optional, Union, overload
5
6
 
6
7
  from sqlspec.driver._common import CommonDriverAttributesMixin
7
- from sqlspec.statement.builder import DeleteBuilder, InsertBuilder, QueryBuilder, SelectBuilder, UpdateBuilder
8
- from sqlspec.statement.filters import StatementFilter
8
+ from sqlspec.driver.parameters import process_execute_many_parameters
9
+ from sqlspec.statement.builder import Delete, Insert, QueryBuilder, Select, Update
9
10
  from sqlspec.statement.result import SQLResult
10
11
  from sqlspec.statement.sql import SQL, SQLConfig, Statement
11
12
  from sqlspec.typing import ConnectionT, DictRow, ModelDTOT, RowT, StatementParameters
12
13
  from sqlspec.utils.logging import get_logger
13
-
14
- logger = get_logger("sqlspec")
15
-
14
+ from sqlspec.utils.type_guards import can_convert_to_schema
16
15
 
17
16
  if TYPE_CHECKING:
18
- from sqlspec.statement.result import DMLResultDict, ScriptResultDict, SelectResultDict
17
+ from sqlspec.statement.filters import StatementFilter
18
+
19
+ logger = get_logger("sqlspec")
19
20
 
20
21
  __all__ = ("SyncDriverAdapterProtocol",)
21
22
 
@@ -39,7 +40,6 @@ class SyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
39
40
  config: SQL statement configuration
40
41
  default_row_type: Default row type for results (DictRow, TupleRow, etc.)
41
42
  """
42
- # Initialize CommonDriverAttributes part
43
43
  super().__init__(connection=connection, config=config, default_row_type=default_row_type)
44
44
 
45
45
  def _build_statement(
@@ -57,41 +57,61 @@ class SyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
57
57
  # If statement is already a SQL object, handle additional parameters
58
58
  if isinstance(statement, SQL):
59
59
  if parameters or kwargs:
60
- # Create a new SQL object with the same SQL but additional parameters
61
- return SQL(statement._sql, *parameters, _dialect=self.dialect, _config=_config, **kwargs)
60
+ new_config = _config
61
+ if self.dialect and not new_config.dialect:
62
+ new_config = replace(new_config, dialect=self.dialect)
63
+ # Use raw SQL if available to ensure proper parsing with dialect
64
+ sql_source = statement._raw_sql or statement._statement
65
+ # Preserve filters and state when creating new SQL object
66
+ existing_state = {
67
+ "is_many": statement._is_many,
68
+ "is_script": statement._is_script,
69
+ "original_parameters": statement._original_parameters,
70
+ "filters": statement._filters,
71
+ "positional_params": statement._positional_params,
72
+ "named_params": statement._named_params,
73
+ }
74
+ return SQL(sql_source, *parameters, config=new_config, _existing_state=existing_state, **kwargs)
75
+ # Even without additional parameters, ensure dialect is set
76
+ if self.dialect and (not statement._config.dialect or statement._config.dialect != self.dialect):
77
+ new_config = replace(statement._config, dialect=self.dialect)
78
+ # Use raw SQL if available to ensure proper parsing with dialect
79
+ sql_source = statement._raw_sql or statement._statement
80
+ # Preserve parameters and state when creating new SQL object
81
+ # Use the public parameters property which always has the right value
82
+ existing_state = {
83
+ "is_many": statement._is_many,
84
+ "is_script": statement._is_script,
85
+ "original_parameters": statement._original_parameters,
86
+ "filters": statement._filters,
87
+ "positional_params": statement._positional_params,
88
+ "named_params": statement._named_params,
89
+ }
90
+ if statement.parameters:
91
+ return SQL(
92
+ sql_source, parameters=statement.parameters, config=new_config, _existing_state=existing_state
93
+ )
94
+ return SQL(sql_source, config=new_config, _existing_state=existing_state)
62
95
  return statement
63
- return SQL(statement, *parameters, _dialect=self.dialect, _config=_config, **kwargs)
96
+ new_config = _config
97
+ if self.dialect and not new_config.dialect:
98
+ new_config = replace(new_config, dialect=self.dialect)
99
+ return SQL(statement, *parameters, config=new_config, **kwargs)
64
100
 
65
101
  @abstractmethod
66
102
  def _execute_statement(
67
103
  self, statement: "SQL", connection: "Optional[ConnectionT]" = None, **kwargs: Any
68
- ) -> "Union[SelectResultDict, DMLResultDict, ScriptResultDict]":
104
+ ) -> "SQLResult[RowT]":
69
105
  """Actual execution implementation by concrete drivers, using the raw connection.
70
106
 
71
- Returns one of the standardized result dictionaries based on the statement type.
107
+ Returns SQLResult directly based on the statement type.
72
108
  """
73
109
  raise NotImplementedError
74
110
 
75
- @abstractmethod
76
- def _wrap_select_result(
77
- self,
78
- statement: "SQL",
79
- result: "SelectResultDict",
80
- schema_type: "Optional[type[ModelDTOT]]" = None,
81
- **kwargs: Any,
82
- ) -> "Union[SQLResult[ModelDTOT], SQLResult[RowT]]":
83
- raise NotImplementedError
84
-
85
- @abstractmethod
86
- def _wrap_execute_result(
87
- self, statement: "SQL", result: "Union[DMLResultDict, ScriptResultDict]", **kwargs: Any
88
- ) -> "SQLResult[RowT]":
89
- raise NotImplementedError
90
-
91
111
  @overload
92
112
  def execute(
93
113
  self,
94
- statement: "SelectBuilder",
114
+ statement: "Select",
95
115
  /,
96
116
  *parameters: "Union[StatementParameters, StatementFilter]",
97
117
  schema_type: "type[ModelDTOT]",
@@ -103,7 +123,7 @@ class SyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
103
123
  @overload
104
124
  def execute(
105
125
  self,
106
- statement: "SelectBuilder",
126
+ statement: "Select",
107
127
  /,
108
128
  *parameters: "Union[StatementParameters, StatementFilter]",
109
129
  schema_type: None = None,
@@ -115,7 +135,7 @@ class SyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
115
135
  @overload
116
136
  def execute(
117
137
  self,
118
- statement: "Union[InsertBuilder, UpdateBuilder, DeleteBuilder]",
138
+ statement: "Union[Insert, Update, Delete]",
119
139
  /,
120
140
  *parameters: "Union[StatementParameters, StatementFilter]",
121
141
  _connection: "Optional[ConnectionT]" = None,
@@ -126,7 +146,7 @@ class SyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
126
146
  @overload
127
147
  def execute(
128
148
  self,
129
- statement: "Statement",
149
+ statement: "Union[str, SQL]",
130
150
  /,
131
151
  *parameters: "Union[StatementParameters, StatementFilter]",
132
152
  schema_type: "type[ModelDTOT]",
@@ -160,13 +180,21 @@ class SyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
160
180
  sql_statement = self._build_statement(statement, *parameters, _config=_config or self.config, **kwargs)
161
181
  result = self._execute_statement(statement=sql_statement, connection=self._connection(_connection), **kwargs)
162
182
 
163
- if self.returns_rows(sql_statement.expression):
164
- return self._wrap_select_result(
165
- sql_statement, cast("SelectResultDict", result), schema_type=schema_type, **kwargs
183
+ # If schema_type is provided and we have data, convert it
184
+ if schema_type and result.data and can_convert_to_schema(self):
185
+ converted_data = list(self.to_schema(data=result.data, schema_type=schema_type))
186
+ return SQLResult[ModelDTOT](
187
+ statement=result.statement,
188
+ data=converted_data,
189
+ column_names=result.column_names,
190
+ rows_affected=result.rows_affected,
191
+ operation_type=result.operation_type,
192
+ last_inserted_id=result.last_inserted_id,
193
+ execution_time=result.execution_time,
194
+ metadata=result.metadata,
166
195
  )
167
- return self._wrap_execute_result(
168
- sql_statement, cast("Union[DMLResultDict, ScriptResultDict]", result), **kwargs
169
- )
196
+
197
+ return result
170
198
 
171
199
  def execute_many(
172
200
  self,
@@ -177,37 +205,19 @@ class SyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
177
205
  _config: "Optional[SQLConfig]" = None,
178
206
  **kwargs: Any,
179
207
  ) -> "SQLResult[RowT]":
180
- # Separate parameters from filters
181
- param_sequences = []
182
- filters = []
183
- for param in parameters:
184
- if isinstance(param, StatementFilter):
185
- filters.append(param)
186
- else:
187
- param_sequences.append(param)
208
+ _filters, param_sequence = process_execute_many_parameters(parameters)
209
+
210
+ # For execute_many, disable transformations to prevent literal extraction
211
+ # since the SQL already has placeholders for bulk operations
212
+ many_config = _config or self.config
213
+ if many_config.enable_transformations:
214
+ from dataclasses import replace
215
+
216
+ many_config = replace(many_config, enable_transformations=False)
188
217
 
189
- # Use first parameter as the sequence for execute_many
190
- param_sequence = param_sequences[0] if param_sequences else None
191
- # Convert tuple to list if needed
192
- if isinstance(param_sequence, tuple):
193
- param_sequence = list(param_sequence)
194
- # Ensure param_sequence is a list or None
195
- if param_sequence is not None and not isinstance(param_sequence, list):
196
- param_sequence = list(param_sequence) if hasattr(param_sequence, "__iter__") else None
197
- sql_statement = self._build_statement(statement, _config=_config or self.config, **kwargs).as_many(
198
- param_sequence
199
- )
218
+ sql_statement = self._build_statement(statement, _config=many_config, **kwargs).as_many(param_sequence)
200
219
 
201
- result = self._execute_statement(
202
- statement=sql_statement,
203
- connection=self._connection(_connection),
204
- parameters=param_sequence,
205
- is_many=True,
206
- **kwargs,
207
- )
208
- return self._wrap_execute_result(
209
- sql_statement, cast("Union[DMLResultDict, ScriptResultDict]", result), **kwargs
210
- )
220
+ return self._execute_statement(statement=sql_statement, connection=self._connection(_connection), **kwargs)
211
221
 
212
222
  def execute_script(
213
223
  self,
@@ -218,44 +228,10 @@ class SyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
218
228
  _config: "Optional[SQLConfig]" = None,
219
229
  **kwargs: Any,
220
230
  ) -> "SQLResult[RowT]":
221
- # Separate parameters from filters
222
- param_values = []
223
- filters = []
224
- for param in parameters:
225
- if isinstance(param, StatementFilter):
226
- filters.append(param)
227
- else:
228
- param_values.append(param)
229
-
230
- # Use first parameter as the primary parameter value, or None if no parameters
231
- primary_params = param_values[0] if param_values else None
232
-
233
231
  script_config = _config or self.config
234
232
  if script_config.enable_validation:
235
- script_config = SQLConfig(
236
- enable_parsing=script_config.enable_parsing,
237
- enable_validation=False,
238
- enable_transformations=script_config.enable_transformations,
239
- enable_analysis=script_config.enable_analysis,
240
- strict_mode=False,
241
- cache_parsed_expression=script_config.cache_parsed_expression,
242
- parameter_converter=script_config.parameter_converter,
243
- parameter_validator=script_config.parameter_validator,
244
- analysis_cache_size=script_config.analysis_cache_size,
245
- allowed_parameter_styles=script_config.allowed_parameter_styles,
246
- target_parameter_style=script_config.target_parameter_style,
247
- allow_mixed_parameter_styles=script_config.allow_mixed_parameter_styles,
248
- )
233
+ script_config = replace(script_config, enable_validation=False, strict_mode=False)
249
234
 
250
- sql_statement = SQL(statement, primary_params, *filters, _dialect=self.dialect, _config=script_config, **kwargs)
235
+ sql_statement = self._build_statement(statement, *parameters, _config=script_config, **kwargs)
251
236
  sql_statement = sql_statement.as_script()
252
- script_output = self._execute_statement(
253
- statement=sql_statement, connection=self._connection(_connection), is_script=True, **kwargs
254
- )
255
- if isinstance(script_output, str):
256
- result = SQLResult[RowT](statement=sql_statement, data=[], operation_type="SCRIPT")
257
- result.total_statements = 1
258
- result.successful_statements = 1
259
- return result
260
- # Wrap the ScriptResultDict using the driver's wrapper
261
- return self._wrap_execute_result(sql_statement, cast("ScriptResultDict", script_output), **kwargs)
237
+ return self._execute_statement(statement=sql_statement, connection=self._connection(_connection), **kwargs)
sqlspec/driver/connection.py ADDED
@@ -0,0 +1,207 @@
1
+ """Consolidated connection management utilities for database drivers.
2
+
3
+ This module provides centralized connection handling to avoid duplication
4
+ across database adapter implementations.
5
+ """
6
+
7
+ from contextlib import asynccontextmanager, contextmanager
8
+ from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
9
+
10
+ if TYPE_CHECKING:
11
+ from collections.abc import AsyncIterator, Iterator
12
+
13
+ from sqlspec.utils.type_guards import is_async_transaction_capable, is_sync_transaction_capable
14
+
15
+ __all__ = (
16
+ "get_connection_info",
17
+ "managed_connection_async",
18
+ "managed_connection_sync",
19
+ "managed_transaction_async",
20
+ "managed_transaction_sync",
21
+ "validate_pool_config",
22
+ )
23
+
24
+
25
+ ConnectionT = TypeVar("ConnectionT")
26
+ PoolT = TypeVar("PoolT")
27
+
28
+
29
@contextmanager
def managed_connection_sync(config: Any, provided_connection: Optional[ConnectionT] = None) -> "Iterator[ConnectionT]":
    """Yield a database connection, borrowing one from ``config`` only when needed.

    Args:
        config: Object exposing ``provide_connection()``, used as a context
            manager that yields a connection.
        provided_connection: Existing connection to reuse. When given, it is
            yielded directly and is not closed or released here.

    Yields:
        Either ``provided_connection`` or the connection produced by
        ``config.provide_connection()``.
    """
    if provided_connection is None:
        # Nothing supplied: cleanup is delegated to the config's context manager.
        with config.provide_connection() as connection:
            yield connection
    else:
        yield provided_connection
47
+
48
+
49
@contextmanager
def managed_transaction_sync(connection: ConnectionT, auto_commit: bool = True) -> "Iterator[ConnectionT]":
    """Context manager that commits on success and rolls back on failure.

    The transaction is only managed when all of these hold:

    - ``auto_commit`` is true;
    - the connection passes ``is_sync_transaction_capable``;
    - the connection is not already in driver-level autocommit mode.

    Otherwise the connection is yielded untouched and nothing is committed or
    rolled back.

    Args:
        connection: Database connection.
        auto_commit: Whether to commit on success (and roll back on error).

    Yields:
        The same database connection.

    Any exception raised inside the ``with`` body, or by ``commit()``, is
    re-raised after a rollback attempt. If the rollback itself fails with an
    error other than a "no active transaction" style error, that rollback
    error propagates instead. The original exception stays attached as its
    ``__context__``.
    """
    # Only a literal ``True`` means "driver autocommit is on". Some drivers expose
    # ``autocommit`` as a method (e.g. PyMySQL-style connections), and sqlite3 on
    # Python 3.12+ defaults it to LEGACY_TRANSACTION_CONTROL (-1). Both are
    # truthy, yet a commit is still required. Calling commit() on a connection
    # that really is in autocommit mode is harmless; skipping a needed commit
    # loses data.
    has_autocommit = getattr(connection, "autocommit", False) is True

    if not auto_commit or not is_sync_transaction_capable(connection) or has_autocommit:
        yield connection
        return

    try:
        yield cast("ConnectionT", connection)
        cast("Any", connection).commit()
    except BaseException:
        # BaseException (not just Exception) so KeyboardInterrupt/SystemExit raised
        # inside the body still roll back instead of leaving the transaction open.
        try:
            cast("Any", connection).rollback()
        except Exception as rollback_error:
            # Some databases (like DuckDB) raise when rollback is called with no
            # active transaction; those specific errors are ignored.
            error_msg = str(rollback_error).lower()
            if "no transaction" not in error_msg and "transaction context error" not in error_msg:
                raise
        raise
85
+
86
+
87
@asynccontextmanager
async def managed_connection_async(
    config: Any, provided_connection: Optional[ConnectionT] = None
) -> "AsyncIterator[ConnectionT]":
    """Async-yield a database connection, borrowing one from ``config`` only when needed.

    Args:
        config: Object exposing ``provide_connection()``, used as an async
            context manager that yields a connection.
        provided_connection: Existing connection to reuse. When given, it is
            yielded directly and is not closed or released here.

    Yields:
        Either ``provided_connection`` or the connection produced by
        ``config.provide_connection()``.
    """
    if provided_connection is None:
        # Nothing supplied: cleanup is delegated to the config's async context manager.
        async with config.provide_connection() as connection:
            yield connection
    else:
        yield provided_connection
107
+
108
+
109
@asynccontextmanager
async def managed_transaction_async(connection: ConnectionT, auto_commit: bool = True) -> "AsyncIterator[ConnectionT]":
    """Async context manager that commits on success and rolls back on failure.

    The transaction is only managed when all of these hold:

    - ``auto_commit`` is true;
    - the connection passes ``is_async_transaction_capable``;
    - the connection is not already in driver-level autocommit mode.

    Otherwise the connection is yielded untouched and nothing is committed or
    rolled back.

    Args:
        connection: Database connection whose ``commit``/``rollback`` are awaited.
        auto_commit: Whether to commit on success (and roll back on error).

    Yields:
        The same database connection.

    Any exception raised inside the ``async with`` body, or by ``commit()``,
    including task cancellation, is re-raised after a rollback attempt. If the
    rollback itself fails with an error other than a "no active transaction"
    style error, that rollback error propagates instead. The original
    exception stays attached as its ``__context__``.
    """
    # Only a literal ``True`` means "driver autocommit is on". Some async drivers
    # expose ``autocommit`` as a coroutine method (e.g. asyncmy/aiomysql-style
    # connections). A bound method is truthy, which previously caused the commit
    # to be skipped. Calling commit() on a connection that really is in
    # autocommit mode is harmless.
    has_autocommit = getattr(connection, "autocommit", False) is True

    if not auto_commit or not is_async_transaction_capable(connection) or has_autocommit:
        yield connection
        return

    try:
        yield cast("ConnectionT", connection)
        await cast("Any", connection).commit()
    except BaseException:
        # BaseException so asyncio.CancelledError (a BaseException since Python 3.8)
        # and KeyboardInterrupt still roll back instead of leaving an open
        # transaction on the connection.
        try:
            await cast("Any", connection).rollback()
        except Exception as rollback_error:
            # Some databases (like DuckDB) raise when rollback is called with no
            # active transaction; those specific errors are ignored.
            error_msg = str(rollback_error).lower()
            if "no transaction" not in error_msg and "transaction context error" not in error_msg:
                raise
        raise
145
+
146
+
147
def get_connection_info(connection: Any) -> dict[str, Any]:
    """Summarize a connection object for logging and debugging.

    The result always has ``type`` and ``module`` (taken from the connection's
    class). It also has ``database`` and/or ``host`` when one of the candidate
    attributes listed below is present and not ``None``. Candidates are probed
    in order and the first non-``None`` value wins.

    Args:
        connection: Database connection object.

    Returns:
        Dictionary of connection information.
    """
    conn_cls = type(connection)
    summary: dict[str, Any] = {"type": conn_cls.__name__, "module": conn_cls.__module__}

    probes = (
        ("database", ("database", "dbname", "db", "catalog")),
        ("host", ("host", "hostname", "server")),
    )
    for key, candidates in probes:
        # Lazily probe attributes so lookup stops at the first non-None value.
        found = next(
            (value for value in (getattr(connection, name, None) for name in candidates) if value is not None),
            None,
        )
        if found is not None:
            summary[key] = found

    return summary
173
+
174
+
175
def validate_pool_config(
    min_size: int, max_size: int, max_idle_time: Optional[int] = None, max_lifetime: Optional[int] = None
) -> None:
    """Check that connection-pool sizing and timing values are coherent.

    Rules are checked in a fixed order and the first violation wins:

    1. ``min_size`` is non-negative.
    2. ``max_size`` is at least 1.
    3. ``min_size`` does not exceed ``max_size``.
    4. ``max_idle_time`` and ``max_lifetime``, when given, are non-negative.

    Args:
        min_size: Minimum pool size.
        max_size: Maximum pool size.
        max_idle_time: Maximum idle time in seconds, or ``None`` to skip the check.
        max_lifetime: Maximum connection lifetime in seconds, or ``None`` to skip the check.

    Raises:
        ValueError: Describing the first rule that is violated.
    """
    # Each rule is a pair of thunks (violated?, message). Later rules are only
    # evaluated once the earlier ones pass, exactly like sequential ``if`` checks.
    rules = (
        (lambda: min_size < 0, lambda: f"min_size must be >= 0, got {min_size}"),
        (lambda: max_size < 1, lambda: f"max_size must be >= 1, got {max_size}"),
        (
            lambda: min_size > max_size,
            lambda: f"min_size ({min_size}) cannot be greater than max_size ({max_size})",
        ),
        (
            lambda: max_idle_time is not None and max_idle_time < 0,
            lambda: f"max_idle_time must be >= 0, got {max_idle_time}",
        ),
        (
            lambda: max_lifetime is not None and max_lifetime < 0,
            lambda: f"max_lifetime must be >= 0, got {max_lifetime}",
        ),
    )
    for is_violated, describe in rules:
        if is_violated():
            raise ValueError(describe())