sqlspec 0.12.2__py3-none-any.whl → 0.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (113)
  1. sqlspec/_sql.py +21 -180
  2. sqlspec/adapters/adbc/config.py +10 -12
  3. sqlspec/adapters/adbc/driver.py +120 -118
  4. sqlspec/adapters/aiosqlite/config.py +16 -3
  5. sqlspec/adapters/aiosqlite/driver.py +100 -130
  6. sqlspec/adapters/asyncmy/config.py +17 -4
  7. sqlspec/adapters/asyncmy/driver.py +123 -135
  8. sqlspec/adapters/asyncpg/config.py +17 -29
  9. sqlspec/adapters/asyncpg/driver.py +98 -140
  10. sqlspec/adapters/bigquery/config.py +4 -5
  11. sqlspec/adapters/bigquery/driver.py +125 -167
  12. sqlspec/adapters/duckdb/config.py +3 -6
  13. sqlspec/adapters/duckdb/driver.py +114 -111
  14. sqlspec/adapters/oracledb/config.py +32 -5
  15. sqlspec/adapters/oracledb/driver.py +242 -259
  16. sqlspec/adapters/psqlpy/config.py +18 -9
  17. sqlspec/adapters/psqlpy/driver.py +118 -93
  18. sqlspec/adapters/psycopg/config.py +44 -31
  19. sqlspec/adapters/psycopg/driver.py +283 -236
  20. sqlspec/adapters/sqlite/config.py +3 -3
  21. sqlspec/adapters/sqlite/driver.py +103 -97
  22. sqlspec/config.py +0 -4
  23. sqlspec/driver/_async.py +89 -98
  24. sqlspec/driver/_common.py +52 -17
  25. sqlspec/driver/_sync.py +81 -105
  26. sqlspec/driver/connection.py +207 -0
  27. sqlspec/driver/mixins/_csv_writer.py +91 -0
  28. sqlspec/driver/mixins/_pipeline.py +38 -49
  29. sqlspec/driver/mixins/_result_utils.py +27 -9
  30. sqlspec/driver/mixins/_storage.py +67 -181
  31. sqlspec/driver/mixins/_type_coercion.py +3 -4
  32. sqlspec/driver/parameters.py +138 -0
  33. sqlspec/exceptions.py +10 -2
  34. sqlspec/extensions/aiosql/adapter.py +0 -10
  35. sqlspec/extensions/litestar/handlers.py +0 -1
  36. sqlspec/extensions/litestar/plugin.py +0 -3
  37. sqlspec/extensions/litestar/providers.py +0 -14
  38. sqlspec/loader.py +25 -90
  39. sqlspec/protocols.py +542 -0
  40. sqlspec/service/__init__.py +3 -2
  41. sqlspec/service/_util.py +147 -0
  42. sqlspec/service/base.py +1116 -9
  43. sqlspec/statement/builder/__init__.py +42 -32
  44. sqlspec/statement/builder/_ddl_utils.py +0 -10
  45. sqlspec/statement/builder/_parsing_utils.py +10 -4
  46. sqlspec/statement/builder/base.py +67 -22
  47. sqlspec/statement/builder/column.py +283 -0
  48. sqlspec/statement/builder/ddl.py +91 -67
  49. sqlspec/statement/builder/delete.py +23 -7
  50. sqlspec/statement/builder/insert.py +29 -15
  51. sqlspec/statement/builder/merge.py +4 -4
  52. sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
  53. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
  54. sqlspec/statement/builder/mixins/_delete_from.py +1 -1
  55. sqlspec/statement/builder/mixins/_from.py +10 -8
  56. sqlspec/statement/builder/mixins/_group_by.py +0 -1
  57. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
  58. sqlspec/statement/builder/mixins/_insert_values.py +0 -2
  59. sqlspec/statement/builder/mixins/_join.py +20 -13
  60. sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
  61. sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
  62. sqlspec/statement/builder/mixins/_order_by.py +2 -2
  63. sqlspec/statement/builder/mixins/_pivot.py +4 -7
  64. sqlspec/statement/builder/mixins/_select_columns.py +6 -5
  65. sqlspec/statement/builder/mixins/_unpivot.py +6 -9
  66. sqlspec/statement/builder/mixins/_update_from.py +2 -1
  67. sqlspec/statement/builder/mixins/_update_set.py +11 -8
  68. sqlspec/statement/builder/mixins/_where.py +61 -34
  69. sqlspec/statement/builder/select.py +32 -17
  70. sqlspec/statement/builder/update.py +25 -11
  71. sqlspec/statement/filters.py +39 -14
  72. sqlspec/statement/parameter_manager.py +220 -0
  73. sqlspec/statement/parameters.py +210 -79
  74. sqlspec/statement/pipelines/__init__.py +166 -23
  75. sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
  76. sqlspec/statement/pipelines/context.py +35 -39
  77. sqlspec/statement/pipelines/transformers/__init__.py +2 -3
  78. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
  79. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
  80. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
  81. sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
  82. sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
  83. sqlspec/statement/pipelines/validators/_performance.py +38 -23
  84. sqlspec/statement/pipelines/validators/_security.py +39 -62
  85. sqlspec/statement/result.py +37 -129
  86. sqlspec/statement/splitter.py +0 -12
  87. sqlspec/statement/sql.py +863 -391
  88. sqlspec/statement/sql_compiler.py +140 -0
  89. sqlspec/storage/__init__.py +10 -2
  90. sqlspec/storage/backends/fsspec.py +53 -8
  91. sqlspec/storage/backends/obstore.py +15 -19
  92. sqlspec/storage/capabilities.py +101 -0
  93. sqlspec/storage/registry.py +56 -83
  94. sqlspec/typing.py +6 -434
  95. sqlspec/utils/cached_property.py +25 -0
  96. sqlspec/utils/correlation.py +0 -2
  97. sqlspec/utils/logging.py +0 -6
  98. sqlspec/utils/sync_tools.py +0 -4
  99. sqlspec/utils/text.py +0 -5
  100. sqlspec/utils/type_guards.py +892 -0
  101. {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/METADATA +1 -1
  102. sqlspec-0.13.1.dist-info/RECORD +150 -0
  103. sqlspec/statement/builder/protocols.py +0 -20
  104. sqlspec/statement/pipelines/base.py +0 -315
  105. sqlspec/statement/pipelines/result_types.py +0 -41
  106. sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
  107. sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
  108. sqlspec/statement/pipelines/validators/base.py +0 -67
  109. sqlspec/storage/protocol.py +0 -173
  110. sqlspec-0.12.2.dist-info/RECORD +0 -145
  111. {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/WHEEL +0 -0
  112. {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/licenses/LICENSE +0 -0
  113. {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/licenses/NOTICE +0 -0
@@ -2,12 +2,14 @@ import contextlib
2
2
  import logging
3
3
  from collections.abc import Iterator
4
4
  from contextlib import contextmanager
5
+ from dataclasses import replace
5
6
  from decimal import Decimal
6
- from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
7
+ from typing import TYPE_CHECKING, Any, ClassVar, Optional, cast
7
8
 
8
9
  from adbc_driver_manager.dbapi import Connection, Cursor
9
10
 
10
11
  from sqlspec.driver import SyncDriverAdapterProtocol
12
+ from sqlspec.driver.connection import managed_transaction_sync
11
13
  from sqlspec.driver.mixins import (
12
14
  SQLTranslatorMixin,
13
15
  SyncPipelinedExecutionMixin,
@@ -15,11 +17,12 @@ from sqlspec.driver.mixins import (
15
17
  ToSchemaMixin,
16
18
  TypeCoercionMixin,
17
19
  )
20
+ from sqlspec.driver.parameters import normalize_parameter_sequence
18
21
  from sqlspec.exceptions import wrap_exceptions
19
22
  from sqlspec.statement.parameters import ParameterStyle
20
- from sqlspec.statement.result import ArrowResult, DMLResultDict, ScriptResultDict, SelectResultDict, SQLResult
23
+ from sqlspec.statement.result import ArrowResult, SQLResult
21
24
  from sqlspec.statement.sql import SQL, SQLConfig
22
- from sqlspec.typing import DictRow, ModelDTOT, RowT, is_dict_with_field
25
+ from sqlspec.typing import DictRow, RowT
23
26
  from sqlspec.utils.serializers import to_json
24
27
 
25
28
  if TYPE_CHECKING:
@@ -65,8 +68,15 @@ class AdbcDriver(
65
68
  config: "Optional[SQLConfig]" = None,
66
69
  default_row_type: "type[DictRow]" = DictRow,
67
70
  ) -> None:
71
+ dialect = self._get_dialect(connection)
72
+ if config and not config.dialect:
73
+ config = replace(config, dialect=dialect)
74
+ elif not config:
75
+ # Create config with dialect
76
+ config = SQLConfig(dialect=dialect)
77
+
68
78
  super().__init__(connection=connection, config=config, default_row_type=default_row_type)
69
- self.dialect: DialectType = self._get_dialect(connection)
79
+ self.dialect: DialectType = dialect
70
80
  self.default_parameter_style = self._get_parameter_style_for_dialect(self.dialect)
71
81
  # Override supported parameter styles based on actual dialect capabilities
72
82
  self.supported_parameter_styles = self._get_supported_parameter_styles_for_dialect(self.dialect)
@@ -169,13 +179,13 @@ class AdbcDriver(
169
179
 
170
180
  def _execute_statement(
171
181
  self, statement: SQL, connection: Optional["AdbcConnection"] = None, **kwargs: Any
172
- ) -> Union[SelectResultDict, DMLResultDict, ScriptResultDict]:
182
+ ) -> SQLResult[RowT]:
173
183
  if statement.is_script:
174
184
  sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
175
185
  return self._execute_script(sql, connection=connection, **kwargs)
176
186
 
177
- # Determine if we need to convert parameter style
178
187
  detected_styles = {p.style for p in statement.parameter_info}
188
+
179
189
  target_style = self.default_parameter_style
180
190
  unsupported_styles = detected_styles - set(self.supported_parameter_styles)
181
191
 
@@ -196,69 +206,107 @@ class AdbcDriver(
196
206
 
197
207
  def _execute(
198
208
  self, sql: str, parameters: Any, statement: SQL, connection: Optional["AdbcConnection"] = None, **kwargs: Any
199
- ) -> Union[SelectResultDict, DMLResultDict]:
200
- conn = self._connection(connection)
201
- with self._get_cursor(conn) as cursor:
202
- # ADBC expects parameters as a list for most drivers
203
- if parameters is not None and not isinstance(parameters, (list, tuple)):
204
- cursor_params = [parameters]
209
+ ) -> SQLResult[RowT]:
210
+ # Use provided connection or driver's default connection
211
+ conn = connection if connection is not None else self._connection(None)
212
+
213
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
214
+ normalized_params = normalize_parameter_sequence(parameters)
215
+ if normalized_params is not None and not isinstance(normalized_params, (list, tuple)):
216
+ cursor_params = [normalized_params]
205
217
  else:
206
- cursor_params = parameters # type: ignore[assignment]
207
-
208
- try:
209
- cursor.execute(sql, cursor_params or [])
210
- except Exception as e:
211
- # Rollback transaction on error for PostgreSQL to avoid "current transaction is aborted" errors
212
- if self.dialect == "postgres":
213
- with contextlib.suppress(Exception):
214
- cursor.execute("ROLLBACK")
215
- raise e from e
216
-
217
- if self.returns_rows(statement.expression):
218
- fetched_data = cursor.fetchall()
219
- column_names = [col[0] for col in cursor.description or []]
220
- result: SelectResultDict = {
221
- "data": fetched_data,
222
- "column_names": column_names,
223
- "rows_affected": len(fetched_data),
224
- }
225
- return result
226
-
227
- dml_result: DMLResultDict = {"rows_affected": cursor.rowcount, "status_message": "OK"}
228
- return dml_result
218
+ cursor_params = normalized_params
219
+
220
+ with self._get_cursor(txn_conn) as cursor:
221
+ try:
222
+ cursor.execute(sql, cursor_params or [])
223
+ except Exception as e:
224
+ # Rollback transaction on error for PostgreSQL to avoid "current transaction is aborted" errors
225
+ if self.dialect == "postgres":
226
+ with contextlib.suppress(Exception):
227
+ cursor.execute("ROLLBACK")
228
+ raise e from e
229
+
230
+ if self.returns_rows(statement.expression):
231
+ fetched_data = cursor.fetchall()
232
+ column_names = [col[0] for col in cursor.description or []]
233
+
234
+ if fetched_data and isinstance(fetched_data[0], tuple):
235
+ dict_data: list[dict[Any, Any]] = [dict(zip(column_names, row)) for row in fetched_data]
236
+ else:
237
+ dict_data = fetched_data # type: ignore[assignment]
238
+
239
+ return SQLResult(
240
+ statement=statement,
241
+ data=cast("list[RowT]", dict_data),
242
+ column_names=column_names,
243
+ rows_affected=len(dict_data),
244
+ operation_type="SELECT",
245
+ )
246
+
247
+ operation_type = self._determine_operation_type(statement)
248
+ return SQLResult(
249
+ statement=statement,
250
+ data=cast("list[RowT]", []),
251
+ rows_affected=cursor.rowcount,
252
+ operation_type=operation_type,
253
+ metadata={"status_message": "OK"},
254
+ )
229
255
 
230
256
  def _execute_many(
231
257
  self, sql: str, param_list: Any, connection: Optional["AdbcConnection"] = None, **kwargs: Any
232
- ) -> DMLResultDict:
233
- conn = self._connection(connection)
234
- with self._get_cursor(conn) as cursor:
235
- try:
236
- cursor.executemany(sql, param_list or [])
237
- except Exception as e:
238
- if self.dialect == "postgres":
239
- with contextlib.suppress(Exception):
240
- cursor.execute("ROLLBACK")
241
- # Always re-raise the original exception
242
- raise e from e
243
-
244
- result: DMLResultDict = {"rows_affected": cursor.rowcount, "status_message": "OK"}
245
- return result
258
+ ) -> SQLResult[RowT]:
259
+ # Use provided connection or driver's default connection
260
+ conn = connection if connection is not None else self._connection(None)
261
+
262
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
263
+ # Normalize parameter list using consolidated utility
264
+ normalized_param_list = normalize_parameter_sequence(param_list)
265
+
266
+ with self._get_cursor(txn_conn) as cursor:
267
+ try:
268
+ cursor.executemany(sql, normalized_param_list or [])
269
+ except Exception as e:
270
+ if self.dialect == "postgres":
271
+ with contextlib.suppress(Exception):
272
+ cursor.execute("ROLLBACK")
273
+ # Always re-raise the original exception
274
+ raise e from e
275
+
276
+ return SQLResult(
277
+ statement=SQL(sql, _dialect=self.dialect),
278
+ data=[],
279
+ rows_affected=cursor.rowcount,
280
+ operation_type="EXECUTE",
281
+ metadata={"status_message": "OK"},
282
+ )
246
283
 
247
284
  def _execute_script(
248
285
  self, script: str, connection: Optional["AdbcConnection"] = None, **kwargs: Any
249
- ) -> ScriptResultDict:
250
- conn = self._connection(connection)
251
- # ADBC drivers don't support multiple statements in a single execute
252
- # Use the shared implementation to split the script
253
- statements = self._split_script_statements(script)
254
-
255
- executed_count = 0
256
- with self._get_cursor(conn) as cursor:
257
- for statement in statements:
258
- executed_count += self._execute_single_script_statement(cursor, statement)
259
-
260
- result: ScriptResultDict = {"statements_executed": executed_count, "status_message": "SCRIPT EXECUTED"}
261
- return result
286
+ ) -> SQLResult[RowT]:
287
+ # Use provided connection or driver's default connection
288
+ conn = connection if connection is not None else self._connection(None)
289
+
290
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
291
+ # ADBC drivers don't support multiple statements in a single execute
292
+ statements = self._split_script_statements(script)
293
+
294
+ executed_count = 0
295
+ with self._get_cursor(txn_conn) as cursor:
296
+ for statement in statements:
297
+ if statement.strip():
298
+ self._execute_single_script_statement(cursor, statement)
299
+ executed_count += 1
300
+
301
+ return SQLResult(
302
+ statement=SQL(script, _dialect=self.dialect).as_script(),
303
+ data=[],
304
+ rows_affected=0,
305
+ operation_type="SCRIPT",
306
+ metadata={"status_message": "SCRIPT EXECUTED"},
307
+ total_statements=executed_count,
308
+ successful_statements=executed_count,
309
+ )
262
310
 
263
311
  def _execute_single_script_statement(self, cursor: "Cursor", statement: str) -> int:
264
312
  """Execute a single statement from a script and handle errors.
@@ -273,7 +321,7 @@ class AdbcDriver(
273
321
  try:
274
322
  cursor.execute(statement)
275
323
  except Exception as e:
276
- # Rollback transaction on error for PostgreSQL
324
+ # Rollback transaction on error for PostgreSQL to avoid "current transaction is aborted" errors
277
325
  if self.dialect == "postgres":
278
326
  with contextlib.suppress(Exception):
279
327
  cursor.execute("ROLLBACK")
@@ -281,59 +329,6 @@ class AdbcDriver(
281
329
  else:
282
330
  return 1
283
331
 
284
- def _wrap_select_result(
285
- self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
286
- ) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
287
- # result must be a dict with keys: data, column_names, rows_affected
288
-
289
- rows_as_dicts = [dict(zip(result["column_names"], row)) for row in result["data"]]
290
-
291
- if schema_type:
292
- return SQLResult[ModelDTOT](
293
- statement=statement,
294
- data=list(self.to_schema(data=rows_as_dicts, schema_type=schema_type)),
295
- column_names=result["column_names"],
296
- rows_affected=result["rows_affected"],
297
- operation_type="SELECT",
298
- )
299
- return SQLResult[RowT](
300
- statement=statement,
301
- data=rows_as_dicts,
302
- column_names=result["column_names"],
303
- rows_affected=result["rows_affected"],
304
- operation_type="SELECT",
305
- )
306
-
307
- def _wrap_execute_result(
308
- self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
309
- ) -> SQLResult[RowT]:
310
- operation_type = (
311
- str(statement.expression.key).upper()
312
- if statement.expression and hasattr(statement.expression, "key")
313
- else "UNKNOWN"
314
- )
315
-
316
- # Handle TypedDict results
317
- if is_dict_with_field(result, "statements_executed"):
318
- return SQLResult[RowT](
319
- statement=statement,
320
- data=[],
321
- rows_affected=0,
322
- total_statements=result["statements_executed"],
323
- operation_type="SCRIPT", # Scripts always have operation_type SCRIPT
324
- metadata={"status_message": result["status_message"]},
325
- )
326
- if is_dict_with_field(result, "rows_affected"):
327
- return SQLResult[RowT](
328
- statement=statement,
329
- data=[],
330
- rows_affected=result["rows_affected"],
331
- operation_type=operation_type,
332
- metadata={"status_message": result["status_message"]},
333
- )
334
- msg = f"Unexpected result type: {type(result)}"
335
- raise ValueError(msg)
336
-
337
332
  def _fetch_arrow_table(self, sql: SQL, connection: "Optional[Any]" = None, **kwargs: Any) -> "ArrowResult":
338
333
  """ADBC native Arrow table fetching.
339
334
 
@@ -379,10 +374,17 @@ class AdbcDriver(
379
374
 
380
375
  conn = self._connection(None)
381
376
  with self._get_cursor(conn) as cursor:
382
- # Handle different modes
383
377
  if mode == "replace":
384
- cursor.execute(SQL(f"TRUNCATE TABLE {table_name}").to_sql(placeholder_style=ParameterStyle.STATIC))
378
+ cursor.execute(
379
+ SQL(f"TRUNCATE TABLE {table_name}", _dialect=self.dialect).to_sql(
380
+ placeholder_style=ParameterStyle.STATIC
381
+ )
382
+ )
385
383
  elif mode == "create":
386
384
  msg = "'create' mode is not supported for ADBC ingestion"
387
385
  raise NotImplementedError(msg)
388
386
  return cursor.adbc_ingest(table_name, table, mode=mode, **options) # type: ignore[arg-type]
387
+
388
+ def _connection(self, connection: Optional["AdbcConnection"] = None) -> "AdbcConnection":
389
+ """Get the connection to use for the operation."""
390
+ return connection or self.connection
@@ -3,7 +3,6 @@
3
3
  import logging
4
4
  from collections.abc import AsyncGenerator
5
5
  from contextlib import asynccontextmanager
6
- from dataclasses import replace
7
6
  from typing import TYPE_CHECKING, Any, ClassVar, Optional
8
7
 
9
8
  import aiosqlite
@@ -172,17 +171,31 @@ class AiosqliteConfig(AsyncDatabaseConfig[AiosqliteConnection, None, AiosqliteDr
172
171
  An AiosqliteDriver instance.
173
172
  """
174
173
  async with self.provide_connection(*args, **kwargs) as connection:
175
- # Create statement config with parameter style info if not already set
176
174
  statement_config = self.statement_config
175
+ # Inject parameter style info if not already set
177
176
  if statement_config.allowed_parameter_styles is None:
177
+ from dataclasses import replace
178
+
178
179
  statement_config = replace(
179
180
  statement_config,
180
181
  allowed_parameter_styles=self.supported_parameter_styles,
181
182
  target_parameter_style=self.preferred_parameter_style,
182
183
  )
183
-
184
184
  yield self.driver_type(connection=connection, config=statement_config)
185
185
 
186
186
  async def provide_pool(self, *args: Any, **kwargs: Any) -> None:
187
187
  """Aiosqlite doesn't support pooling."""
188
188
  return
189
+
190
+ def get_signature_namespace(self) -> "dict[str, type[Any]]":
191
+ """Get the signature namespace for Aiosqlite types.
192
+
193
+ This provides all Aiosqlite-specific types that Litestar needs to recognize
194
+ to avoid serialization attempts.
195
+
196
+ Returns:
197
+ Dictionary mapping type names to types.
198
+ """
199
+ namespace = super().get_signature_namespace()
200
+ namespace.update({"AiosqliteConnection": AiosqliteConnection})
201
+ return namespace
@@ -3,11 +3,12 @@ import logging
3
3
  from collections.abc import AsyncGenerator, Sequence
4
4
  from contextlib import asynccontextmanager
5
5
  from pathlib import Path
6
- from typing import TYPE_CHECKING, Any, Optional, Union, cast
6
+ from typing import TYPE_CHECKING, Any, Optional
7
7
 
8
8
  import aiosqlite
9
9
 
10
10
  from sqlspec.driver import AsyncDriverAdapterProtocol
11
+ from sqlspec.driver.connection import managed_transaction_async
11
12
  from sqlspec.driver.mixins import (
12
13
  AsyncPipelinedExecutionMixin,
13
14
  AsyncStorageMixin,
@@ -15,10 +16,11 @@ from sqlspec.driver.mixins import (
15
16
  ToSchemaMixin,
16
17
  TypeCoercionMixin,
17
18
  )
18
- from sqlspec.statement.parameters import ParameterStyle
19
- from sqlspec.statement.result import DMLResultDict, ScriptResultDict, SelectResultDict, SQLResult
19
+ from sqlspec.driver.parameters import normalize_parameter_sequence
20
+ from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
21
+ from sqlspec.statement.result import SQLResult
20
22
  from sqlspec.statement.sql import SQL, SQLConfig
21
- from sqlspec.typing import DictRow, ModelDTOT, RowT
23
+ from sqlspec.typing import DictRow, RowT
22
24
  from sqlspec.utils.serializers import to_json
23
25
 
24
26
  if TYPE_CHECKING:
@@ -97,22 +99,24 @@ class AiosqliteDriver(
97
99
 
98
100
  async def _execute_statement(
99
101
  self, statement: SQL, connection: Optional[AiosqliteConnection] = None, **kwargs: Any
100
- ) -> Union[SelectResultDict, DMLResultDict, ScriptResultDict]:
102
+ ) -> SQLResult[RowT]:
101
103
  if statement.is_script:
102
104
  sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
103
105
  return await self._execute_script(sql, connection=connection, **kwargs)
104
106
 
105
- # Determine if we need to convert parameter style
106
- detected_styles = {p.style for p in statement.parameter_info}
107
+ detected_styles = set()
108
+ sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
109
+ validator = self.config.parameter_validator if self.config else ParameterValidator()
110
+ param_infos = validator.extract_parameters(sql_str)
111
+ if param_infos:
112
+ detected_styles = {p.style for p in param_infos}
113
+
107
114
  target_style = self.default_parameter_style
108
115
 
109
- # Check if any detected style is not supported
110
116
  unsupported_styles = detected_styles - set(self.supported_parameter_styles)
111
117
  if unsupported_styles:
112
- # Convert to default style if we have unsupported styles
113
118
  target_style = self.default_parameter_style
114
119
  elif detected_styles:
115
- # Use the first detected style if all are supported
116
120
  # Prefer the first supported style found
117
121
  for style in detected_styles:
118
122
  if style in self.supported_parameter_styles:
@@ -122,85 +126,111 @@ class AiosqliteDriver(
122
126
  if statement.is_many:
123
127
  sql, params = statement.compile(placeholder_style=target_style)
124
128
 
125
- # Process parameter list through type coercion
126
129
  params = self._process_parameters(params)
127
130
 
128
131
  return await self._execute_many(sql, params, connection=connection, **kwargs)
129
132
 
130
133
  sql, params = statement.compile(placeholder_style=target_style)
131
134
 
132
- # Process parameters through type coercion
133
135
  params = self._process_parameters(params)
134
136
 
135
137
  return await self._execute(sql, params, statement, connection=connection, **kwargs)
136
138
 
137
139
  async def _execute(
138
140
  self, sql: str, parameters: Any, statement: SQL, connection: Optional[AiosqliteConnection] = None, **kwargs: Any
139
- ) -> Union[SelectResultDict, DMLResultDict]:
141
+ ) -> SQLResult[RowT]:
140
142
  conn = self._connection(connection)
141
- # Convert parameters to the format expected by the SQL
142
- # Note: SQL was already rendered with appropriate placeholder style in _execute_statement
143
- if ":param_" in sql or (parameters and isinstance(parameters, dict)):
144
- # SQL has named placeholders, ensure params are dict
145
- converted_params = self._convert_parameters_to_driver_format(
146
- sql, parameters, target_style=ParameterStyle.NAMED_COLON
147
- )
148
- else:
149
- # SQL has positional placeholders, ensure params are list/tuple
150
- converted_params = self._convert_parameters_to_driver_format(
151
- sql, parameters, target_style=ParameterStyle.QMARK
152
- )
153
- async with self._get_cursor(conn) as cursor:
154
- # Aiosqlite handles both dict and tuple parameters
155
- await cursor.execute(sql, converted_params or ())
156
- if self.returns_rows(statement.expression):
157
- fetched_data = await cursor.fetchall()
158
- column_names = [desc[0] for desc in cursor.description or []]
159
- # Convert to list of dicts or tuples as expected by TypedDict
160
- data_list: list[Any] = list(fetched_data) if fetched_data else []
161
- result: SelectResultDict = {
162
- "data": data_list,
163
- "column_names": column_names,
164
- "rows_affected": len(data_list),
165
- }
166
- return result
167
- dml_result: DMLResultDict = {"rows_affected": cursor.rowcount, "status_message": "OK"}
168
- return dml_result
143
+
144
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
145
+ normalized_params = normalize_parameter_sequence(parameters)
146
+
147
+ # Extract the actual parameters from the normalized list
148
+ if normalized_params and len(normalized_params) == 1:
149
+ actual_params = normalized_params[0]
150
+ else:
151
+ actual_params = normalized_params
152
+
153
+ # AIOSQLite expects tuple or dict - handle parameter conversion
154
+ if ":param_" in sql or (isinstance(actual_params, dict)):
155
+ # SQL has named placeholders, ensure params are dict
156
+ converted_params = self._convert_parameters_to_driver_format(
157
+ sql, actual_params, target_style=ParameterStyle.NAMED_COLON
158
+ )
159
+ else:
160
+ # SQL has positional placeholders, ensure params are list/tuple
161
+ converted_params = self._convert_parameters_to_driver_format(
162
+ sql, actual_params, target_style=ParameterStyle.QMARK
163
+ )
164
+
165
+ async with self._get_cursor(txn_conn) as cursor:
166
+ # Aiosqlite handles both dict and tuple parameters
167
+ await cursor.execute(sql, converted_params or ())
168
+ if self.returns_rows(statement.expression):
169
+ fetched_data = await cursor.fetchall()
170
+ column_names = [desc[0] for desc in cursor.description or []]
171
+ data_list: list[Any] = list(fetched_data) if fetched_data else []
172
+ return SQLResult(
173
+ statement=statement,
174
+ data=data_list,
175
+ column_names=column_names,
176
+ rows_affected=len(data_list),
177
+ operation_type="SELECT",
178
+ )
179
+
180
+ return SQLResult(
181
+ statement=statement,
182
+ data=[],
183
+ rows_affected=cursor.rowcount,
184
+ operation_type=self._determine_operation_type(statement),
185
+ metadata={"status_message": "OK"},
186
+ )
169
187
 
170
188
  async def _execute_many(
171
189
  self, sql: str, param_list: Any, connection: Optional[AiosqliteConnection] = None, **kwargs: Any
172
- ) -> DMLResultDict:
173
- conn = self._connection(connection)
174
- logger.debug("Executing SQL (executemany): %s", sql)
175
- if param_list:
176
- logger.debug("Query parameters (batch): %s", param_list)
177
-
178
- # Convert parameter list to proper format for executemany
179
- params_list: list[tuple[Any, ...]] = []
180
- if param_list and isinstance(param_list, Sequence):
181
- for param_set in param_list:
182
- param_set = cast("Any", param_set)
183
- if isinstance(param_set, (list, tuple)):
184
- params_list.append(tuple(param_set))
185
- elif param_set is None:
186
- params_list.append(())
187
-
188
- async with self._get_cursor(conn) as cursor:
189
- await cursor.executemany(sql, params_list)
190
- result: DMLResultDict = {"rows_affected": cursor.rowcount, "status_message": "OK"}
191
- return result
190
+ ) -> SQLResult[RowT]:
191
+ # Use provided connection or driver's default connection
192
+ conn = connection if connection is not None else self._connection(None)
193
+
194
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
195
+ # Normalize parameter list using consolidated utility
196
+ normalized_param_list = normalize_parameter_sequence(param_list)
197
+
198
+ params_list: list[tuple[Any, ...]] = []
199
+ if normalized_param_list and isinstance(normalized_param_list, Sequence):
200
+ for param_set in normalized_param_list:
201
+ if isinstance(param_set, (list, tuple)):
202
+ params_list.append(tuple(param_set))
203
+ elif param_set is None:
204
+ params_list.append(())
205
+
206
+ async with self._get_cursor(txn_conn) as cursor:
207
+ await cursor.executemany(sql, params_list)
208
+ return SQLResult(
209
+ statement=SQL(sql, _dialect=self.dialect),
210
+ data=[],
211
+ rows_affected=cursor.rowcount,
212
+ operation_type="EXECUTE",
213
+ metadata={"status_message": "OK"},
214
+ )
192
215
 
193
216
  async def _execute_script(
194
217
  self, script: str, connection: Optional[AiosqliteConnection] = None, **kwargs: Any
195
- ) -> ScriptResultDict:
196
- conn = self._connection(connection)
197
- async with self._get_cursor(conn) as cursor:
198
- await cursor.executescript(script)
199
- result: ScriptResultDict = {
200
- "statements_executed": -1, # AIOSQLite doesn't provide this info
201
- "status_message": "SCRIPT EXECUTED",
202
- }
203
- return result
218
+ ) -> SQLResult[RowT]:
219
+ # Use provided connection or driver's default connection
220
+ conn = connection if connection is not None else self._connection(None)
221
+
222
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
223
+ async with self._get_cursor(txn_conn) as cursor:
224
+ await cursor.executescript(script)
225
+ return SQLResult(
226
+ statement=SQL(script, _dialect=self.dialect).as_script(),
227
+ data=[],
228
+ rows_affected=0,
229
+ operation_type="SCRIPT",
230
+ metadata={"status_message": "SCRIPT EXECUTED"},
231
+ total_statements=-1, # AIOSQLite doesn't provide this info
232
+ successful_statements=-1,
233
+ )
204
234
 
205
235
  async def _bulk_load_file(self, file_path: Path, table_name: str, format: str, mode: str, **options: Any) -> int:
206
236
  """Database-specific bulk load implementation using storage backend."""
@@ -234,66 +264,6 @@ class AiosqliteDriver(
234
264
  finally:
235
265
  await conn.close()
236
266
 
237
- async def _wrap_select_result(
238
- self, statement: SQL, result: SelectResultDict, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any
239
- ) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
240
- fetched_data = result["data"]
241
- column_names = result["column_names"]
242
- rows_affected = result["rows_affected"]
243
-
244
- rows_as_dicts: list[dict[str, Any]] = [dict(row) for row in fetched_data]
245
-
246
- if self.returns_rows(statement.expression):
247
- converted_data_seq = self.to_schema(data=rows_as_dicts, schema_type=schema_type)
248
- return SQLResult[ModelDTOT](
249
- statement=statement,
250
- data=list(converted_data_seq),
251
- column_names=column_names,
252
- rows_affected=rows_affected,
253
- operation_type="SELECT",
254
- )
255
- return SQLResult[RowT](
256
- statement=statement,
257
- data=rows_as_dicts,
258
- column_names=column_names,
259
- rows_affected=rows_affected,
260
- operation_type="SELECT",
261
- )
262
-
263
- async def _wrap_execute_result(
264
- self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
265
- ) -> SQLResult[RowT]:
266
- operation_type = "UNKNOWN"
267
- if statement.expression:
268
- operation_type = str(statement.expression.key).upper()
269
-
270
- if "statements_executed" in result:
271
- script_result = cast("ScriptResultDict", result)
272
- return SQLResult[RowT](
273
- statement=statement,
274
- data=[],
275
- rows_affected=0,
276
- operation_type="SCRIPT",
277
- total_statements=script_result.get("statements_executed", -1),
278
- metadata={"status_message": script_result.get("status_message", "")},
279
- )
280
-
281
- if "rows_affected" in result:
282
- dml_result = cast("DMLResultDict", result)
283
- rows_affected = dml_result["rows_affected"]
284
- status_message = dml_result["status_message"]
285
- return SQLResult[RowT](
286
- statement=statement,
287
- data=[],
288
- rows_affected=rows_affected,
289
- operation_type=operation_type,
290
- metadata={"status_message": status_message},
291
- )
292
-
293
- # This shouldn't happen with TypedDict approach
294
- msg = f"Unexpected result type: {type(result)}"
295
- raise ValueError(msg)
296
-
297
267
  def _connection(self, connection: Optional[AiosqliteConnection] = None) -> AiosqliteConnection:
298
268
  """Get the connection to use for the operation."""
299
269
  return connection or self.connection