sqlspec-0.12.2-py3-none-any.whl → sqlspec-0.13.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic; see the registry's release page for this version for more details.

Files changed (113):
  1. sqlspec/_sql.py +21 -180
  2. sqlspec/adapters/adbc/config.py +10 -12
  3. sqlspec/adapters/adbc/driver.py +120 -118
  4. sqlspec/adapters/aiosqlite/config.py +3 -3
  5. sqlspec/adapters/aiosqlite/driver.py +100 -130
  6. sqlspec/adapters/asyncmy/config.py +3 -4
  7. sqlspec/adapters/asyncmy/driver.py +123 -135
  8. sqlspec/adapters/asyncpg/config.py +3 -7
  9. sqlspec/adapters/asyncpg/driver.py +98 -140
  10. sqlspec/adapters/bigquery/config.py +4 -5
  11. sqlspec/adapters/bigquery/driver.py +125 -167
  12. sqlspec/adapters/duckdb/config.py +3 -6
  13. sqlspec/adapters/duckdb/driver.py +114 -111
  14. sqlspec/adapters/oracledb/config.py +6 -5
  15. sqlspec/adapters/oracledb/driver.py +242 -259
  16. sqlspec/adapters/psqlpy/config.py +3 -7
  17. sqlspec/adapters/psqlpy/driver.py +118 -93
  18. sqlspec/adapters/psycopg/config.py +18 -31
  19. sqlspec/adapters/psycopg/driver.py +283 -236
  20. sqlspec/adapters/sqlite/config.py +3 -3
  21. sqlspec/adapters/sqlite/driver.py +103 -97
  22. sqlspec/config.py +0 -4
  23. sqlspec/driver/_async.py +89 -98
  24. sqlspec/driver/_common.py +52 -17
  25. sqlspec/driver/_sync.py +81 -105
  26. sqlspec/driver/connection.py +207 -0
  27. sqlspec/driver/mixins/_csv_writer.py +91 -0
  28. sqlspec/driver/mixins/_pipeline.py +38 -49
  29. sqlspec/driver/mixins/_result_utils.py +27 -9
  30. sqlspec/driver/mixins/_storage.py +67 -181
  31. sqlspec/driver/mixins/_type_coercion.py +3 -4
  32. sqlspec/driver/parameters.py +138 -0
  33. sqlspec/exceptions.py +10 -2
  34. sqlspec/extensions/aiosql/adapter.py +0 -10
  35. sqlspec/extensions/litestar/handlers.py +0 -1
  36. sqlspec/extensions/litestar/plugin.py +0 -3
  37. sqlspec/extensions/litestar/providers.py +0 -14
  38. sqlspec/loader.py +25 -90
  39. sqlspec/protocols.py +542 -0
  40. sqlspec/service/__init__.py +3 -2
  41. sqlspec/service/_util.py +147 -0
  42. sqlspec/service/base.py +1116 -9
  43. sqlspec/statement/builder/__init__.py +42 -32
  44. sqlspec/statement/builder/_ddl_utils.py +0 -10
  45. sqlspec/statement/builder/_parsing_utils.py +10 -4
  46. sqlspec/statement/builder/base.py +67 -22
  47. sqlspec/statement/builder/column.py +283 -0
  48. sqlspec/statement/builder/ddl.py +91 -67
  49. sqlspec/statement/builder/delete.py +23 -7
  50. sqlspec/statement/builder/insert.py +29 -15
  51. sqlspec/statement/builder/merge.py +4 -4
  52. sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
  53. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
  54. sqlspec/statement/builder/mixins/_delete_from.py +1 -1
  55. sqlspec/statement/builder/mixins/_from.py +10 -8
  56. sqlspec/statement/builder/mixins/_group_by.py +0 -1
  57. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
  58. sqlspec/statement/builder/mixins/_insert_values.py +0 -2
  59. sqlspec/statement/builder/mixins/_join.py +20 -13
  60. sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
  61. sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
  62. sqlspec/statement/builder/mixins/_order_by.py +2 -2
  63. sqlspec/statement/builder/mixins/_pivot.py +4 -7
  64. sqlspec/statement/builder/mixins/_select_columns.py +6 -5
  65. sqlspec/statement/builder/mixins/_unpivot.py +6 -9
  66. sqlspec/statement/builder/mixins/_update_from.py +2 -1
  67. sqlspec/statement/builder/mixins/_update_set.py +11 -8
  68. sqlspec/statement/builder/mixins/_where.py +61 -34
  69. sqlspec/statement/builder/select.py +32 -17
  70. sqlspec/statement/builder/update.py +25 -11
  71. sqlspec/statement/filters.py +39 -14
  72. sqlspec/statement/parameter_manager.py +220 -0
  73. sqlspec/statement/parameters.py +210 -79
  74. sqlspec/statement/pipelines/__init__.py +166 -23
  75. sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
  76. sqlspec/statement/pipelines/context.py +35 -39
  77. sqlspec/statement/pipelines/transformers/__init__.py +2 -3
  78. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
  79. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
  80. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
  81. sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
  82. sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
  83. sqlspec/statement/pipelines/validators/_performance.py +38 -23
  84. sqlspec/statement/pipelines/validators/_security.py +39 -62
  85. sqlspec/statement/result.py +37 -129
  86. sqlspec/statement/splitter.py +0 -12
  87. sqlspec/statement/sql.py +863 -391
  88. sqlspec/statement/sql_compiler.py +140 -0
  89. sqlspec/storage/__init__.py +10 -2
  90. sqlspec/storage/backends/fsspec.py +53 -8
  91. sqlspec/storage/backends/obstore.py +15 -19
  92. sqlspec/storage/capabilities.py +101 -0
  93. sqlspec/storage/registry.py +56 -83
  94. sqlspec/typing.py +6 -434
  95. sqlspec/utils/cached_property.py +25 -0
  96. sqlspec/utils/correlation.py +0 -2
  97. sqlspec/utils/logging.py +0 -6
  98. sqlspec/utils/sync_tools.py +0 -4
  99. sqlspec/utils/text.py +0 -5
  100. sqlspec/utils/type_guards.py +892 -0
  101. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
  102. sqlspec-0.13.0.dist-info/RECORD +150 -0
  103. sqlspec/statement/builder/protocols.py +0 -20
  104. sqlspec/statement/pipelines/base.py +0 -315
  105. sqlspec/statement/pipelines/result_types.py +0 -41
  106. sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
  107. sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
  108. sqlspec/statement/pipelines/validators/base.py +0 -67
  109. sqlspec/storage/protocol.py +0 -173
  110. sqlspec-0.12.2.dist-info/RECORD +0 -145
  111. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
  112. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
  113. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/adapters/sqlite/config.py CHANGED
@@ -3,7 +3,6 @@
3
3
  import logging
4
4
  import sqlite3
5
5
  from contextlib import contextmanager
6
- from dataclasses import replace
7
6
  from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
8
7
 
9
8
  from sqlspec.adapters.sqlite.driver import SqliteConnection, SqliteDriver
@@ -88,7 +87,6 @@ class SqliteConfig(NoPoolSyncConfig[SqliteConnection, SqliteDriver]):
88
87
  uri: Whether to interpret database as URI
89
88
  **kwargs: Additional parameters (stored in extras)
90
89
  """
91
- # Validate required parameters
92
90
  if database is None:
93
91
  msg = "database parameter cannot be None"
94
92
  raise TypeError(msg)
@@ -164,11 +162,13 @@ class SqliteConfig(NoPoolSyncConfig[SqliteConnection, SqliteDriver]):
164
162
  """
165
163
  with self.provide_connection(*args, **kwargs) as connection:
166
164
  statement_config = self.statement_config
165
+ # Inject parameter style info if not already set
167
166
  if statement_config.allowed_parameter_styles is None:
167
+ from dataclasses import replace
168
+
168
169
  statement_config = replace(
169
170
  statement_config,
170
171
  allowed_parameter_styles=self.supported_parameter_styles,
171
172
  target_parameter_style=self.preferred_parameter_style,
172
173
  )
173
-
174
174
  yield self.driver_type(connection=connection, config=statement_config)
sqlspec/adapters/sqlite/driver.py CHANGED
@@ -4,11 +4,12 @@ import sqlite3
4
4
  from collections.abc import Iterator
5
5
  from contextlib import contextmanager
6
6
  from pathlib import Path
7
- from typing import TYPE_CHECKING, Any, Optional, Union, cast
7
+ from typing import TYPE_CHECKING, Any, Optional, cast
8
8
 
9
9
  from typing_extensions import TypeAlias
10
10
 
11
11
  from sqlspec.driver import SyncDriverAdapterProtocol
12
+ from sqlspec.driver.connection import managed_transaction_sync
12
13
  from sqlspec.driver.mixins import (
13
14
  SQLTranslatorMixin,
14
15
  SyncPipelinedExecutionMixin,
@@ -16,10 +17,11 @@ from sqlspec.driver.mixins import (
16
17
  ToSchemaMixin,
17
18
  TypeCoercionMixin,
18
19
  )
19
- from sqlspec.statement.parameters import ParameterStyle
20
- from sqlspec.statement.result import DMLResultDict, ScriptResultDict, SelectResultDict, SQLResult
20
+ from sqlspec.driver.parameters import normalize_parameter_sequence
21
+ from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
22
+ from sqlspec.statement.result import SQLResult
21
23
  from sqlspec.statement.sql import SQL, SQLConfig
22
- from sqlspec.typing import DictRow, ModelDTOT, RowT, is_dict_with_field
24
+ from sqlspec.typing import DictRow, RowT
23
25
  from sqlspec.utils.logging import get_logger
24
26
  from sqlspec.utils.serializers import to_json
25
27
 
@@ -102,19 +104,22 @@ class SqliteDriver(
102
104
 
103
105
  def _execute_statement(
104
106
  self, statement: SQL, connection: Optional[SqliteConnection] = None, **kwargs: Any
105
- ) -> Union[SelectResultDict, DMLResultDict, ScriptResultDict]:
107
+ ) -> SQLResult[RowT]:
106
108
  if statement.is_script:
107
109
  sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
108
- return self._execute_script(sql, connection=connection, **kwargs)
110
+ return self._execute_script(sql, connection=connection, statement=statement, **kwargs)
111
+
112
+ detected_styles = set()
113
+ sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
114
+ validator = self.config.parameter_validator if self.config else ParameterValidator()
115
+ param_infos = validator.extract_parameters(sql_str)
116
+ if param_infos:
117
+ detected_styles = {p.style for p in param_infos}
109
118
 
110
- # Determine if we need to convert parameter style
111
- detected_styles = {p.style for p in statement.parameter_info}
112
119
  target_style = self.default_parameter_style
113
120
 
114
- # Check if any detected style is not supported
115
121
  unsupported_styles = detected_styles - set(self.supported_parameter_styles)
116
122
  if unsupported_styles:
117
- # Convert to default style if we have unsupported styles
118
123
  target_style = self.default_parameter_style
119
124
  elif len(detected_styles) > 1:
120
125
  # Mixed styles detected - use default style for consistency
@@ -129,11 +134,10 @@ class SqliteDriver(
129
134
 
130
135
  if statement.is_many:
131
136
  sql, params = statement.compile(placeholder_style=target_style)
132
- return self._execute_many(sql, params, connection=connection, **kwargs)
137
+ return self._execute_many(sql, params, connection=connection, statement=statement, **kwargs)
133
138
 
134
139
  sql, params = statement.compile(placeholder_style=target_style)
135
140
 
136
- # Process parameters through type coercion
137
141
  params = self._process_parameters(params)
138
142
 
139
143
  # SQLite expects tuples for positional parameters
@@ -144,58 +148,105 @@ class SqliteDriver(
144
148
 
145
149
  def _execute(
146
150
  self, sql: str, parameters: Any, statement: SQL, connection: Optional[SqliteConnection] = None, **kwargs: Any
147
- ) -> Union[SelectResultDict, DMLResultDict]:
151
+ ) -> SQLResult[RowT]:
148
152
  """Execute a single statement with parameters."""
149
- conn = self._connection(connection)
150
- with self._get_cursor(conn) as cursor:
151
- # SQLite expects tuple or dict parameters
152
- if parameters is not None and not isinstance(parameters, (tuple, list, dict)):
153
- # Convert scalar to tuple
154
- parameters = (parameters,)
155
- cursor.execute(sql, parameters or ())
153
+ # Use provided connection or driver's default connection
154
+ conn = connection if connection is not None else self._connection(None)
155
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn, self._get_cursor(txn_conn) as cursor:
156
+ # Normalize parameters using consolidated utility
157
+ normalized_params_list = normalize_parameter_sequence(parameters)
158
+ params_for_execute: Any
159
+ if normalized_params_list and len(normalized_params_list) == 1:
160
+ # Single parameter should be tuple for SQLite
161
+ if not isinstance(normalized_params_list[0], (tuple, list, dict)):
162
+ params_for_execute = (normalized_params_list[0],)
163
+ else:
164
+ params_for_execute = normalized_params_list[0]
165
+ else:
166
+ # Multiple parameters
167
+ params_for_execute = tuple(normalized_params_list) if normalized_params_list else ()
168
+
169
+ cursor.execute(sql, params_for_execute)
156
170
  if self.returns_rows(statement.expression):
157
171
  fetched_data: list[sqlite3.Row] = cursor.fetchall()
158
- return {
159
- "data": fetched_data,
160
- "column_names": [col[0] for col in cursor.description or []],
161
- "rows_affected": len(fetched_data),
162
- }
163
- return {"rows_affected": cursor.rowcount, "status_message": "OK"}
172
+ return SQLResult(
173
+ statement=statement,
174
+ data=cast("list[RowT]", fetched_data),
175
+ column_names=[col[0] for col in cursor.description or []],
176
+ rows_affected=len(fetched_data),
177
+ operation_type="SELECT",
178
+ )
179
+ operation_type = self._determine_operation_type(statement)
180
+
181
+ return SQLResult(
182
+ statement=statement,
183
+ data=[],
184
+ rows_affected=cursor.rowcount,
185
+ operation_type=operation_type,
186
+ metadata={"status_message": "OK"},
187
+ )
164
188
 
165
189
  def _execute_many(
166
- self, sql: str, param_list: Any, connection: Optional[SqliteConnection] = None, **kwargs: Any
167
- ) -> DMLResultDict:
190
+ self,
191
+ sql: str,
192
+ param_list: Any,
193
+ connection: Optional[SqliteConnection] = None,
194
+ statement: Optional[SQL] = None,
195
+ **kwargs: Any,
196
+ ) -> SQLResult[RowT]:
168
197
  """Execute a statement many times with a list of parameter tuples."""
169
- conn = self._connection(connection)
170
- if param_list:
171
- param_list = self._process_parameters(param_list)
172
-
173
- # Convert parameter list to proper format for executemany
174
- formatted_params: list[tuple[Any, ...]] = []
175
- if param_list and isinstance(param_list, list):
176
- for param_set in cast("list[Union[list, tuple]]", param_list):
177
- if isinstance(param_set, (list, tuple)):
178
- formatted_params.append(tuple(param_set))
179
- elif param_set is None:
180
- formatted_params.append(())
181
- else:
182
- formatted_params.append((param_set,))
183
-
184
- with self._get_cursor(conn) as cursor:
185
- cursor.executemany(sql, formatted_params)
186
- return {"rows_affected": cursor.rowcount, "status_message": "OK"}
198
+ # Use provided connection or driver's default connection
199
+ conn = connection if connection is not None else self._connection(None)
200
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
201
+ # Normalize parameter list using consolidated utility
202
+ normalized_param_list = normalize_parameter_sequence(param_list)
203
+ formatted_params: list[tuple[Any, ...]] = []
204
+ if normalized_param_list:
205
+ for param_set in normalized_param_list:
206
+ if isinstance(param_set, (list, tuple)):
207
+ formatted_params.append(tuple(param_set))
208
+ elif param_set is None:
209
+ formatted_params.append(())
210
+ else:
211
+ formatted_params.append((param_set,))
212
+
213
+ with self._get_cursor(txn_conn) as cursor:
214
+ cursor.executemany(sql, formatted_params)
215
+
216
+ if statement is None:
217
+ statement = SQL(sql, _dialect=self.dialect)
218
+
219
+ return SQLResult(
220
+ statement=statement,
221
+ data=[],
222
+ rows_affected=cursor.rowcount,
223
+ operation_type="EXECUTE",
224
+ metadata={"status_message": "OK"},
225
+ )
187
226
 
188
227
  def _execute_script(
189
- self, script: str, connection: Optional[SqliteConnection] = None, **kwargs: Any
190
- ) -> ScriptResultDict:
228
+ self, script: str, connection: Optional[SqliteConnection] = None, statement: Optional[SQL] = None, **kwargs: Any
229
+ ) -> SQLResult[RowT]:
191
230
  """Execute a script on the SQLite connection."""
192
- conn = self._connection(connection)
231
+ # Use provided connection or driver's default connection
232
+ conn = connection if connection is not None else self._connection(None)
193
233
  with self._get_cursor(conn) as cursor:
194
234
  cursor.executescript(script)
195
- # executescript doesn't auto-commit in some cases
235
+ # executescript doesn't auto-commit in some cases - force commit
196
236
  conn.commit()
197
- result: ScriptResultDict = {"statements_executed": -1, "status_message": "SCRIPT EXECUTED"}
198
- return result
237
+
238
+ if statement is None:
239
+ statement = SQL(script, _dialect=self.dialect).as_script()
240
+
241
+ return SQLResult(
242
+ statement=statement,
243
+ data=[],
244
+ rows_affected=-1, # Unknown for scripts
245
+ operation_type="SCRIPT",
246
+ total_statements=-1, # SQLite doesn't provide this info
247
+ successful_statements=-1,
248
+ metadata={"status_message": "SCRIPT EXECUTED"},
249
+ )
199
250
 
200
251
  def _ingest_arrow_table(self, table: Any, table_name: str, mode: str = "create", **options: Any) -> int:
201
252
  """SQLite-specific Arrow table ingestion using CSV conversion.
@@ -208,12 +259,10 @@ class SqliteDriver(
208
259
 
209
260
  import pyarrow.csv as pa_csv
210
261
 
211
- # Convert Arrow table to CSV in memory
212
262
  csv_buffer = io.BytesIO()
213
263
  pa_csv.write_csv(table, csv_buffer)
214
264
  csv_content = csv_buffer.getvalue()
215
265
 
216
- # Create a temporary file path
217
266
  temp_filename = f"sqlspec_temp_{table_name}_{id(self)}.csv"
218
267
  temp_path = Path(tempfile.gettempdir()) / temp_filename
219
268
 
@@ -258,46 +307,3 @@ class SqliteDriver(
258
307
  data_iter = list(reader) # Read all data into memory
259
308
  cursor.executemany(sql, data_iter)
260
309
  return cursor.rowcount
261
-
262
- def _wrap_select_result(
263
- self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
264
- ) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
265
- rows_as_dicts = [dict(row) for row in result["data"]]
266
- if schema_type:
267
- return SQLResult[ModelDTOT](
268
- statement=statement,
269
- data=list(self.to_schema(data=rows_as_dicts, schema_type=schema_type)),
270
- column_names=result["column_names"],
271
- rows_affected=result["rows_affected"],
272
- operation_type="SELECT",
273
- )
274
-
275
- return SQLResult[RowT](
276
- statement=statement,
277
- data=rows_as_dicts,
278
- column_names=result["column_names"],
279
- rows_affected=result["rows_affected"],
280
- operation_type="SELECT",
281
- )
282
-
283
- def _wrap_execute_result(
284
- self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
285
- ) -> SQLResult[RowT]:
286
- if is_dict_with_field(result, "statements_executed"):
287
- return SQLResult[RowT](
288
- statement=statement,
289
- data=[],
290
- rows_affected=0,
291
- operation_type="SCRIPT",
292
- metadata={
293
- "status_message": result.get("status_message", ""),
294
- "statements_executed": result.get("statements_executed", -1),
295
- },
296
- )
297
- return SQLResult[RowT](
298
- statement=statement,
299
- data=[],
300
- rows_affected=cast("int", result.get("rows_affected", -1)),
301
- operation_type=statement.expression.key.upper() if statement.expression else "UNKNOWN",
302
- metadata={"status_message": result.get("status_message", "")},
303
- )
sqlspec/config.py CHANGED
@@ -97,7 +97,6 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
97
97
  Returns:
98
98
  The SQL dialect type.
99
99
  """
100
- # Get dialect from driver_class (all drivers must have a dialect attribute)
101
100
  return self.driver_type.dialect
102
101
 
103
102
  @abstractmethod
@@ -154,17 +153,14 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
154
153
  """
155
154
  namespace: dict[str, type[Any]] = {}
156
155
 
157
- # Add the driver and config types
158
156
  if hasattr(self, "driver_type") and self.driver_type:
159
157
  namespace[self.driver_type.__name__] = self.driver_type
160
158
 
161
159
  namespace[self.__class__.__name__] = self.__class__
162
160
 
163
- # Add connection type(s)
164
161
  if hasattr(self, "connection_type") and self.connection_type:
165
162
  connection_type = self.connection_type
166
163
 
167
- # Handle Union types (like AsyncPG's Union[Connection, PoolConnectionProxy])
168
164
  if hasattr(connection_type, "__args__"):
169
165
  # It's a generic type, extract the actual types
170
166
  for arg_type in connection_type.__args__: # type: ignore[attr-defined]
sqlspec/driver/_async.py CHANGED
@@ -1,17 +1,22 @@
1
1
  """Asynchronous driver protocol implementation."""
2
2
 
3
3
  from abc import ABC, abstractmethod
4
- from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload
4
+ from dataclasses import replace
5
+ from typing import TYPE_CHECKING, Any, Optional, Union, overload
5
6
 
6
7
  from sqlspec.driver._common import CommonDriverAttributesMixin
7
- from sqlspec.statement.builder import DeleteBuilder, InsertBuilder, QueryBuilder, SelectBuilder, UpdateBuilder
8
- from sqlspec.statement.filters import StatementFilter
8
+ from sqlspec.driver.parameters import process_execute_many_parameters
9
+ from sqlspec.statement.builder import Delete, Insert, QueryBuilder, Select, Update
9
10
  from sqlspec.statement.result import SQLResult
10
11
  from sqlspec.statement.sql import SQL, SQLConfig, Statement
11
12
  from sqlspec.typing import ConnectionT, DictRow, ModelDTOT, RowT, StatementParameters
13
+ from sqlspec.utils.logging import get_logger
14
+ from sqlspec.utils.type_guards import can_convert_to_schema
12
15
 
13
16
  if TYPE_CHECKING:
14
- from sqlspec.statement.result import DMLResultDict, ScriptResultDict, SelectResultDict
17
+ from sqlspec.statement.filters import StatementFilter
18
+
19
+ logger = get_logger("sqlspec")
15
20
 
16
21
  __all__ = ("AsyncDriverAdapterProtocol",)
17
22
 
@@ -49,42 +54,64 @@ class AsyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
49
54
 
50
55
  if isinstance(statement, QueryBuilder):
51
56
  return statement.to_statement(config=_config)
52
- # If statement is already a SQL object, return it as-is
57
+ # If statement is already a SQL object, handle additional parameters
53
58
  if isinstance(statement, SQL):
59
+ if parameters or kwargs:
60
+ new_config = _config
61
+ if self.dialect and not new_config.dialect:
62
+ new_config = replace(new_config, dialect=self.dialect)
63
+ # Use raw SQL if available to ensure proper parsing with dialect
64
+ sql_source = statement._raw_sql or statement._statement
65
+ # Preserve filters and state when creating new SQL object
66
+ existing_state = {
67
+ "is_many": statement._is_many,
68
+ "is_script": statement._is_script,
69
+ "original_parameters": statement._original_parameters,
70
+ "filters": statement._filters,
71
+ "positional_params": statement._positional_params,
72
+ "named_params": statement._named_params,
73
+ }
74
+ return SQL(sql_source, *parameters, config=new_config, _existing_state=existing_state, **kwargs)
75
+ # Even without additional parameters, ensure dialect is set
76
+ if self.dialect and (not statement._config.dialect or statement._config.dialect != self.dialect):
77
+ new_config = replace(statement._config, dialect=self.dialect)
78
+ # Use raw SQL if available to ensure proper parsing with dialect
79
+ sql_source = statement._raw_sql or statement._statement
80
+ # Preserve parameters and state when creating new SQL object
81
+ # Use the public parameters property which always has the right value
82
+ existing_state = {
83
+ "is_many": statement._is_many,
84
+ "is_script": statement._is_script,
85
+ "original_parameters": statement._original_parameters,
86
+ "filters": statement._filters,
87
+ "positional_params": statement._positional_params,
88
+ "named_params": statement._named_params,
89
+ }
90
+ if statement.parameters:
91
+ return SQL(
92
+ sql_source, parameters=statement.parameters, config=new_config, _existing_state=existing_state
93
+ )
94
+ return SQL(sql_source, config=new_config, _existing_state=existing_state)
54
95
  return statement
55
- return SQL(statement, *parameters, _dialect=self.dialect, _config=_config, **kwargs)
96
+ new_config = _config
97
+ if self.dialect and not new_config.dialect:
98
+ new_config = replace(new_config, dialect=self.dialect)
99
+ return SQL(statement, *parameters, config=new_config, **kwargs)
56
100
 
57
101
  @abstractmethod
58
102
  async def _execute_statement(
59
103
  self, statement: "SQL", connection: "Optional[ConnectionT]" = None, **kwargs: Any
60
- ) -> "Union[SelectResultDict, DMLResultDict, ScriptResultDict]":
104
+ ) -> "SQLResult[RowT]":
61
105
  """Actual execution implementation by concrete drivers, using the raw connection.
62
106
 
63
- Returns one of the standardized result dictionaries based on the statement type.
107
+ Returns SQLResult directly based on the statement type.
64
108
  """
65
109
  raise NotImplementedError
66
110
 
67
- @abstractmethod
68
- async def _wrap_select_result(
69
- self,
70
- statement: "SQL",
71
- result: "SelectResultDict",
72
- schema_type: "Optional[type[ModelDTOT]]" = None,
73
- **kwargs: Any,
74
- ) -> "Union[SQLResult[ModelDTOT], SQLResult[RowT]]":
75
- raise NotImplementedError
76
-
77
- @abstractmethod
78
- async def _wrap_execute_result(
79
- self, statement: "SQL", result: "Union[DMLResultDict, ScriptResultDict]", **kwargs: Any
80
- ) -> "SQLResult[RowT]":
81
- raise NotImplementedError
82
-
83
- # Type-safe overloads based on the refactor plan pattern
84
111
  @overload
85
112
  async def execute(
86
113
  self,
87
- statement: "SelectBuilder",
114
+ statement: "Select",
88
115
  /,
89
116
  *parameters: "Union[StatementParameters, StatementFilter]",
90
117
  schema_type: "type[ModelDTOT]",
@@ -96,7 +123,7 @@ class AsyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
96
123
  @overload
97
124
  async def execute(
98
125
  self,
99
- statement: "SelectBuilder",
126
+ statement: "Select",
100
127
  /,
101
128
  *parameters: "Union[StatementParameters, StatementFilter]",
102
129
  schema_type: None = None,
@@ -108,7 +135,7 @@ class AsyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
108
135
  @overload
109
136
  async def execute(
110
137
  self,
111
- statement: "Union[InsertBuilder, UpdateBuilder, DeleteBuilder]",
138
+ statement: "Union[Insert, Update, Delete]",
112
139
  /,
113
140
  *parameters: "Union[StatementParameters, StatementFilter]",
114
141
  _connection: "Optional[ConnectionT]" = None,
@@ -155,51 +182,45 @@ class AsyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
155
182
  statement=sql_statement, connection=self._connection(_connection), **kwargs
156
183
  )
157
184
 
158
- if self.returns_rows(sql_statement.expression):
159
- return await self._wrap_select_result(
160
- sql_statement, cast("SelectResultDict", result), schema_type=schema_type, **kwargs
185
+ # If schema_type is provided and we have data, convert it
186
+ if schema_type and result.data and can_convert_to_schema(self):
187
+ converted_data = list(self.to_schema(data=result.data, schema_type=schema_type))
188
+ return SQLResult[ModelDTOT](
189
+ statement=result.statement,
190
+ data=converted_data,
191
+ column_names=result.column_names,
192
+ rows_affected=result.rows_affected,
193
+ operation_type=result.operation_type,
194
+ last_inserted_id=result.last_inserted_id,
195
+ execution_time=result.execution_time,
196
+ metadata=result.metadata,
161
197
  )
162
- return await self._wrap_execute_result(
163
- sql_statement, cast("Union[DMLResultDict, ScriptResultDict]", result), **kwargs
164
- )
198
+
199
+ return result
165
200
 
166
201
  async def execute_many(
167
202
  self,
168
- statement: "Union[SQL, Statement, QueryBuilder[Any]]", # QueryBuilder for DMLs will likely not return rows.
203
+ statement: "Union[SQL, Statement, QueryBuilder[Any]]",
169
204
  /,
170
205
  *parameters: "Union[StatementParameters, StatementFilter]",
171
206
  _connection: "Optional[ConnectionT]" = None,
172
207
  _config: "Optional[SQLConfig]" = None,
173
208
  **kwargs: Any,
174
209
  ) -> "SQLResult[RowT]":
175
- # Separate parameters from filters
176
- param_sequences = []
177
- filters = []
178
- for param in parameters:
179
- if isinstance(param, StatementFilter):
180
- filters.append(param)
181
- else:
182
- param_sequences.append(param)
183
-
184
- # Use first parameter as the sequence for execute_many
185
- param_sequence = param_sequences[0] if param_sequences else None
186
- # Convert tuple to list if needed
187
- if isinstance(param_sequence, tuple):
188
- param_sequence = list(param_sequence)
189
- # Ensure param_sequence is a list or None
190
- if param_sequence is not None and not isinstance(param_sequence, list):
191
- param_sequence = list(param_sequence) if hasattr(param_sequence, "__iter__") else None
192
- sql_statement = self._build_statement(statement, _config=_config or self.config, **kwargs)
193
- sql_statement = sql_statement.as_many(param_sequence)
194
- result = await self._execute_statement(
195
- statement=sql_statement,
196
- connection=self._connection(_connection),
197
- parameters=param_sequence,
198
- is_many=True,
199
- **kwargs,
200
- )
201
- return await self._wrap_execute_result(
202
- sql_statement, cast("Union[DMLResultDict, ScriptResultDict]", result), **kwargs
210
+ _filters, param_sequence = process_execute_many_parameters(parameters)
211
+
212
+ # For execute_many, disable transformations to prevent literal extraction
213
+ # since the SQL already has placeholders for bulk operations
214
+ many_config = _config or self.config
215
+ if many_config.enable_transformations:
216
+ from dataclasses import replace
217
+
218
+ many_config = replace(many_config, enable_transformations=False)
219
+
220
+ sql_statement = self._build_statement(statement, _config=many_config, **kwargs).as_many(param_sequence)
221
+
222
+ return await self._execute_statement(
223
+ statement=sql_statement, connection=self._connection(_connection), **kwargs
203
224
  )
204
225
 
205
226
  async def execute_script(
@@ -211,42 +232,12 @@ class AsyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
211
232
  _config: "Optional[SQLConfig]" = None,
212
233
  **kwargs: Any,
213
234
  ) -> "SQLResult[RowT]":
214
- param_values = []
215
- filters = []
216
- for param in parameters:
217
- if isinstance(param, StatementFilter):
218
- filters.append(param)
219
- else:
220
- param_values.append(param)
221
-
222
- # Use first parameter as the primary parameter value, or None if no parameters
223
- primary_params = param_values[0] if param_values else None
224
-
225
235
  script_config = _config or self.config
226
236
  if script_config.enable_validation:
227
- script_config = SQLConfig(
228
- enable_parsing=script_config.enable_parsing,
229
- enable_validation=False,
230
- enable_transformations=script_config.enable_transformations,
231
- enable_analysis=script_config.enable_analysis,
232
- strict_mode=False,
233
- cache_parsed_expression=script_config.cache_parsed_expression,
234
- parameter_converter=script_config.parameter_converter,
235
- parameter_validator=script_config.parameter_validator,
236
- analysis_cache_size=script_config.analysis_cache_size,
237
- allowed_parameter_styles=script_config.allowed_parameter_styles,
238
- target_parameter_style=script_config.target_parameter_style,
239
- allow_mixed_parameter_styles=script_config.allow_mixed_parameter_styles,
240
- )
241
- sql_statement = SQL(statement, primary_params, *filters, _dialect=self.dialect, _config=script_config, **kwargs)
237
+ script_config = replace(script_config, enable_validation=False, strict_mode=False)
238
+
239
+ sql_statement = self._build_statement(statement, *parameters, _config=script_config, **kwargs)
242
240
  sql_statement = sql_statement.as_script()
243
- script_output = await self._execute_statement(
244
- statement=sql_statement, connection=self._connection(_connection), is_script=True, **kwargs
241
+ return await self._execute_statement(
242
+ statement=sql_statement, connection=self._connection(_connection), **kwargs
245
243
  )
246
- if isinstance(script_output, str):
247
- result = SQLResult[RowT](statement=sql_statement, data=[], operation_type="SCRIPT")
248
- result.total_statements = 1
249
- result.successful_statements = 1
250
- return result
251
- # Wrap the ScriptResultDict using the driver's wrapper
252
- return await self._wrap_execute_result(sql_statement, cast("ScriptResultDict", script_output), **kwargs)